How to Use Inheritance with Dataclass in Python

Updated: March 1, 2023 By: Khue

The dataclasses module offers a concise way to create simple, lightweight data structures in Python. Sometimes, however, you need richer data structures that build on one another through inheritance. Inheritance is a powerful feature of object-oriented programming: you define a base class with common fields and methods, then create subclasses that inherit them while adding fields and methods of their own.

The Fundamentals

To use inheritance with dataclasses, define a base class with the @dataclass decorator, then define a subclass that inherits from it using the standard Python syntax for inheritance. Decorate the subclass with @dataclass as well, so that the fields it adds are included in the generated __init__() and __repr__() methods.

In the example below, we are going to define a base class named Animal and two subclasses named Cat and Dog:

from dataclasses import dataclass

# This is the base class
@dataclass
class Animal:
    name: str
    sound: str

    def make_sound(self):
        print(self.sound)


# This is a subclass that inherits from Animal
@dataclass
class Cat(Animal):
    breed: str

# This is another subclass that inherits from Animal
@dataclass
class Dog(Animal):
    age: int

# Create some instances of the subclasses
cat = Cat("Fluffy", "Meow", "Persian")
dog = Dog("Spot", "Woof", 5)

print(cat)
print(dog)

Output:

Cat(name='Fluffy', sound='Meow', breed='Persian')
Dog(name='Spot', sound='Woof', age=5)

When you define a subclass, it automatically inherits all the fields and methods of its parent class. In our example, the Cat and Dog subclasses both inherit the name and sound fields, as well as the make_sound() method, from the Animal class.
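To check this for yourself, here is a small sketch that repeats the Animal and Cat classes and inspects the result with dataclasses.fields(). The variable name field_names is only an illustrative choice.

```python
from dataclasses import dataclass, fields


@dataclass
class Animal:
    name: str
    sound: str

    def make_sound(self):
        print(self.sound)


@dataclass
class Cat(Animal):
    breed: str


# Inherited fields come first, in the order the base class declares them,
# followed by the fields that the subclass adds.
field_names = [f.name for f in fields(Cat)]
print(field_names)  # ['name', 'sound', 'breed']

cat = Cat("Fluffy", "Meow", "Persian")
cat.make_sound()                # prints: Meow (method inherited from Animal)
print(isinstance(cat, Animal))  # True
```

The field order matters because it is also the order of the positional arguments in the generated __init__() method.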

Abstract Base Class with Abstract Methods

This example shows you how to use inheritance with dataclasses to implement a simple class hierarchy for shapes. We are going to use the ABC base class and the @abstractmethod decorator in the base class.
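As a quick preliminary, the following minimal sketch shows what an abstract method enforces: neither the abstract base class nor a subclass that leaves an abstract method unimplemented can be instantiated. Blob and can_instantiate() are hypothetical names used only for this illustration.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class Shape(ABC):
    @abstractmethod
    def area(self) -> float:
        pass


# Hypothetical subclass that forgets to implement area()
@dataclass
class Blob(Shape):
    size: float


def can_instantiate(cls, *args) -> bool:
    """Return True if cls(*args) succeeds, False if it raises TypeError."""
    try:
        cls(*args)
        return True
    except TypeError as error:
        print(f"{cls.__name__}: {error}")
        return False


shape_ok = can_instantiate(Shape)      # the abstract base class itself
blob_ok = can_instantiate(Blob, 1.0)   # a subclass missing area()
print(shape_ok, blob_ok)  # False False
```

The exact wording of the error message varies between Python versions, so the sketch only checks that a TypeError is raised.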

The code:

from dataclasses import dataclass
from abc import ABC, abstractmethod
from math import pi
from typing import Optional

# abstract base class
@dataclass
class Shape(ABC):
    @abstractmethod
    def area(self) -> float:
        pass

    @abstractmethod
    def perimeter(self) -> float:
        pass


# concrete class implementations
# note that every concrete subclass must implement the abstract methods
# declared in the base class (or inherit an implementation from a parent)

@dataclass
class Circle(Shape):
    radius: float

    def area(self) -> float:
        return pi * self.radius ** 2

    def perimeter(self) -> float:
        return 2 * pi * self.radius


@dataclass
class Rectangle(Shape):
    width: float
    height: float

    def area(self) -> float:
        return self.width * self.height

    def perimeter(self) -> float:
        return 2 * (self.width + self.height)


@dataclass
class Square(Rectangle):
    width: float
    height: Optional[float] = None

    def __post_init__(self):
        self.height = self.width

    def perimeter(self) -> float:
        return 4 * self.width


# try it out
circle = Circle(5)
print("Circle area:", circle.area())
print("Circle perimeter:", circle.perimeter())

rectangle = Rectangle(5, 10)
print("Rectangle area:", rectangle.area())
print("Rectangle perimeter:", rectangle.perimeter())

square = Square(5)
print("Square area:", square.area())
print("Square perimeter:", square.perimeter())

Output:

Circle area: 78.53981633974483
Circle perimeter: 31.41592653589793
Rectangle area: 50
Rectangle perimeter: 30
Square area: 25
Square perimeter: 20
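The Square class above redeclares the width and height fields it inherits from Rectangle. A redeclared field keeps its original position in the field order but takes the new type and default. The sketch below leaves out the Shape base class for brevity and inspects the result; the names order and overridden are illustrative only.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class Rectangle:
    width: float
    height: float


@dataclass
class Square(Rectangle):
    width: float
    height: Optional[float] = None  # redeclared with a default value

    def __post_init__(self):
        # Always force the height to match the width
        self.height = self.width


order = [f.name for f in fields(Square)]
print(order)  # ['width', 'height']

square = Square(5)
print(square)  # Square(width=5, height=5)

# An explicit height is accepted by __init__ but overwritten in __post_init__
overridden = Square(5, 10)
print(overridden.height)  # 5
```

One design consequence is that Square(5, 10) does not raise an error; the second argument is silently ignored. If that is not what you want, you could add a check in __post_init__() instead.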

That’s it. Happy coding. If you have any questions, please comment!