The dataclasses module offers a concise way to create simple, lightweight data structures in Python. Sometimes, however, you need related data structures that share fields and behavior, and that calls for inheritance. Inheritance is a core feature of object-oriented programming: you define a base class with common fields and methods, then create subclasses that inherit them while adding fields and methods of their own.
To use inheritance with dataclasses, define a base class with the @dataclass decorator. Then define a subclass that inherits from it using Python's standard inheritance syntax, and decorate the subclass with @dataclass as well so that its own fields are added to the generated methods.
In the example below, we are going to define a base class named Animal and two subclasses named Cat and Dog:
```python
from dataclasses import dataclass


# This is the base class
@dataclass
class Animal:
    name: str
    sound: str

    def make_sound(self):
        print(self.sound)


# This is a subclass that inherits from Animal
@dataclass
class Cat(Animal):
    breed: str


# This is another subclass that inherits from Animal
@dataclass
class Dog(Animal):
    age: int


# Create some instances of the subclasses
cat = Cat("Fluffy", "Meow", "Persian")
dog = Dog("Spot", "Woof", 5)

print(cat)
print(dog)
```
```
Cat(name='Fluffy', sound='Meow', breed='Persian')
Dog(name='Spot', sound='Woof', age=5)
```
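When you define a subclass, it automatically inherits all the fields and methods of its parent class. In our example, the Cat and Dog subclasses both inherit the name and sound fields, as well as the make_sound() method, from the Animal class. The generated __init__() takes the parent's fields first, followed by the subclass's own fields. That is why Cat("Fluffy", "Meow", "Persian") assigns "Persian" to breed.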
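A short standalone check makes this behavior visible. The sketch below repeats the Animal and Cat definitions so it runs on its own. It calls the inherited make_sound() method and lists the generated fields with dataclasses.fields(). It also illustrates a related rule that applies across the hierarchy: a field without a default cannot follow a field with one. The Pet and Parrot classes are hypothetical and exist only for this illustration. Giving the subclass field a default avoids the error.

```python
from dataclasses import dataclass, fields


@dataclass
class Animal:
    name: str
    sound: str

    def make_sound(self) -> None:
        print(self.sound)


@dataclass
class Cat(Animal):
    breed: str


cat = Cat("Fluffy", "Meow", "Persian")

# make_sound() is defined only on Animal, but Cat inherits it
cat.make_sound()  # prints "Meow"

# Base-class fields come first, then the subclass's own field
print([f.name for f in fields(cat)])  # ['name', 'sound', 'breed']

# A Cat is still an Animal
print(isinstance(cat, Animal))  # True


# Hypothetical classes showing the field-ordering rule across inheritance
@dataclass
class Pet:
    name: str
    sound: str = "..."


parrot_error = None
try:
    @dataclass
    class Parrot(Pet):
        words: int  # no default, yet it would follow the defaulted 'sound'
except TypeError as exc:
    # The decorator rejects the class when it builds __init__
    parrot_error = exc
    print("TypeError:", exc)
```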
Abstract Base Class with Abstract Methods
This example shows how to use inheritance with dataclasses to implement a simple class hierarchy for shapes. The base class inherits from ABC and marks its methods with the @abstractmethod decorator, so every concrete subclass must provide its own implementation of those methods.
```python
from dataclasses import dataclass
from abc import ABC, abstractmethod
from math import pi
from typing import Optional


# Abstract base class: declares the interface but implements nothing
@dataclass
class Shape(ABC):
    @abstractmethod
    def area(self) -> float:
        pass

    @abstractmethod
    def perimeter(self) -> float:
        pass


# Concrete class implementations
# Each concrete subclass must implement every abstract method
# declared in Shape; otherwise it cannot be instantiated
@dataclass
class Circle(Shape):
    radius: float

    def area(self) -> float:
        return pi * self.radius ** 2

    def perimeter(self) -> float:
        return 2 * pi * self.radius


@dataclass
class Rectangle(Shape):
    width: float
    height: float

    def area(self) -> float:
        return self.width * self.height

    def perimeter(self) -> float:
        return 2 * (self.width + self.height)


# A square is a rectangle whose height always equals its width
@dataclass
class Square(Rectangle):
    width: float
    height: Optional[float] = None

    def __post_init__(self):
        # Derive height from width after the generated __init__ runs
        self.height = self.width

    def perimeter(self) -> float:
        return 4 * self.width


# Try it out
circle = Circle(5)
print("Circle area:", circle.area())
print("Circle perimeter:", circle.perimeter())

rectangle = Rectangle(5, 10)
print("Rectangle area:", rectangle.area())
print("Rectangle perimeter:", rectangle.perimeter())

square = Square(5)
print("Square area:", square.area())
print("Square perimeter:", square.perimeter())
```
```
Circle area: 78.53981633974483
Circle perimeter: 31.41592653589793
Rectangle area: 50
Rectangle perimeter: 30
Square area: 25
Square perimeter: 20
```
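The payoff of making Shape an abstract base class is that Python refuses to create incomplete shapes. The sketch below repeats a trimmed Shape so it runs on its own. It then defines a hypothetical Triangle subclass that implements area() but forgets perimeter(). The try_create() helper is also hypothetical: it returns the instance, or the TypeError message if the class is still abstract.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class Shape(ABC):
    @abstractmethod
    def area(self) -> float: ...

    @abstractmethod
    def perimeter(self) -> float: ...


# Hypothetical subclass that implements area() but not perimeter()
@dataclass
class Triangle(Shape):
    base: float
    height: float

    def area(self) -> float:
        return 0.5 * self.base * self.height


def try_create(cls, *args):
    """Return an instance of cls, or the TypeError message if cls is abstract."""
    try:
        return cls(*args)
    except TypeError as exc:
        return f"TypeError: {exc}"


# Both calls fail: Shape has two abstract methods, Triangle still has one
print(try_create(Shape))
print(try_create(Triangle, 3, 4))
```

The exact wording of the error message varies between Python versions. In each case it names the abstract methods that are still missing.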
That’s it. Happy coding. If you have any questions, please comment!