How to Use Inheritance to Write Pretty Code and Supercharge Development


Do you remember learning about inheritance in school? It may have seemed needlessly complex for the projects you worked on at the time. However, as you begin working on bigger, more complex projects, making wise use of this object-oriented design principle can vastly cut down the size of your codebase, making maintenance and development much easier. Read on to see an example of how to use inheritance to reduce code duplication in your project.

An Example: Automating Database Queries

I am the main developer and maintainer of a GraphQL API at the company I work for. One thing my API needs to do is interact with a database. Every table in the database needs create, read, update, and delete (CRUD) operations. As you can probably imagine, the implementation of these operations is very similar across tables; in fact, the only real difference is the schema of the data in each table. Because the implementations are otherwise the same, I can use inheritance to implement each operation once and reuse that implementation for all of my tables.

Quick Review

Inheritance is a handy tool that allows us to implement a single parent class and then have other child classes inherit from it. When they do, the child classes get all the functionality of the parent, and they can extend and modify the parent's behavior. Here is a very simple example:

class Animal:
    def __init__(self):
        self.sound = ""

    def speak(self):
        return self.sound

class Dog(Animal):
    def __init__(self):
        super().__init__()  # run Animal's constructor first
        self.sound = "Woof!"

    def growl(self):
        return "Grrr!"

class Cat(Animal):
    def __init__(self):
        super().__init__()  # run Animal's constructor first
        self.sound = "Meow!"

    def purr(self):
        return "Purr!"

dog = Dog()
cat = Cat()

print(dog.speak())  # Outputs: Woof!
print(cat.speak())  # Outputs: Meow!
print(dog.growl())  # Outputs: Grrr!
print(cat.purr())   # Outputs: Purr!

In this incredibly simple example, we see both the power of inheritance to reduce code duplication and its power to maintain flexibility and extensibility. The speak method is implemented once, in the Animal class, and both Dog and Cat have access to it. They can, however, change the sound they make by setting the self.sound attribute.

At the same time, the child classes maintain the ability to implement new functionality specific to themselves, as can be seen in the Dog class’s growl method, and the Cat class’s purr method. So, this is a simple example of the power of inheritance to maintain flexible code with minimal duplication. Now, let’s look at a more useful example: the database interactions mentioned above.
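The example above shows children adding new methods; to show a child modifying an inherited method without rewriting it, here is a small sketch that overrides speak and calls the parent's version through super(). The LoudDog class is hypothetical and exists only for illustration; Animal and Dog are repeated so the block runs on its own.

```python
class Animal:
    def __init__(self):
        self.sound = ""

    def speak(self):
        return self.sound


class Dog(Animal):
    def __init__(self):
        super().__init__()  # run Animal's setup first
        self.sound = "Woof!"


class LoudDog(Dog):
    """Hypothetical subclass that builds on, rather than replaces, inherited behavior."""

    def speak(self):
        # Reuse the parent's implementation, then transform its result
        return super().speak().upper() + "!!"


print(Dog().speak())      # Woof!
print(LoudDog().speak())  # WOOF!!!
```

Because LoudDog calls super().speak(), any future change to how Dog or Animal produces its sound flows through to LoudDog automatically.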

Inheritance in Action

To use inheritance for creating objects to interact with a database, we need to begin by declaring the parent class. But first, a few prerequisites for this code to make sense:

  • Firstly, this is a Python class, but the principle applies in any object-oriented language.
  • Secondly, this class uses SQLModel under the hood for the actual database interactions. It is a wrapper around them that provides a smooth interface.
  • In the constructor, you will notice an attribute, self.db, set equal to DB. This DB object is a SQLModel Session.
  • When the object is instantiated, a series of SQLModel models is passed to it. These models are not defined here; instead, each child class of Database supplies its own. These passed-in models are the key to reducing code duplication through inheritance while still letting us interact with many tables using the same validation logic.
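To make the role of the passed-in models concrete without a database, here is a stdlib-only sketch that substitutes Python dataclasses for SQLModel models. The names Repository and User, and the email rule, are hypothetical; the point is that the parent's create method never mentions a specific schema, it simply calls whichever model it was given, so the model decides what counts as valid data.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Type


@dataclass
class User:
    """Hypothetical model standing in for a SQLModel table model."""
    username: str
    email: str

    def __post_init__(self):
        # Validation lives on the model, so each table can enforce its own rules
        if "@" not in self.email:
            raise ValueError(f"invalid email: {self.email!r}")


class Repository:
    """Parent class: storage logic written once, schema supplied by the caller."""

    def __init__(self, model: Type[Any]):
        self.model = model
        self.rows: List[Any] = []

    def create(self, data: Dict[str, Any]) -> Any:
        # Unknown or missing fields raise TypeError; bad values raise ValueError
        record = self.model(**data)
        self.rows.append(record)
        return record


users = Repository(User)
users.create({"username": "ada", "email": "ada@example.com"})
```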

Take a look at the simple Database class below:

from typing import Any, List

from fastapi import HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlmodel import select

# DB is a SQLModel Session created elsewhere in the application (not shown).


class Database:
    def __init__(
        self,
        base_model,
        create_model=None,
        get_model=None,
        get_with_relations_model=None,
        update_model=None,
        delete_model=None,
    ):
        """Set the database session and models for the class."""
        # shared SQLModel Session
        self.db = DB

        # set models; each child class supplies its own
        self.base_model = base_model
        self.create_model = create_model
        self.get_model = get_model
        self.get_with_relations_model = get_with_relations_model
        self.update_model = update_model
        self.delete_model = delete_model

    def get_all(self) -> List[Any]:
        """Get all records."""
        if self.base_model is None:
            raise Exception("base_model is not set")

        return self.db.exec(select(self.base_model)).all()

    def create(self, user, data: Any) -> Any:
        """Create a record, stamping it with the acting user."""
        if self.base_model is None:
            raise Exception("base_model is not set")

        # convert the incoming data (e.g. a create_model instance) into a table row
        record = self.base_model.model_validate(data)
        record.created_by = user.username
        record.updated_by = user.username
        try:
            self.db.add(record)
            self.db.commit()
            self.db.refresh(record)
        except IntegrityError:
            # undo the failed transaction so the session stays usable
            self.db.rollback()
            raise
        return record

    def get(self, id: int) -> Any:
        """Get a record by id, or raise a 404."""
        if self.base_model is None:
            raise Exception("base_model is not set")

        record = self.db.get(self.base_model, id)
        if record is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Record with id {id} not found",
            )
        return record

    def update(self, user, id: int, data: Any) -> Any:
        """Update a record with only the fields that were provided."""
        record = self.get(id)  # reuses the base_model check and 404 handling
        changes = data.model_dump(exclude_unset=True, exclude_none=True)
        record.sqlmodel_update(changes)
        record.updated_by = user.username
        try:
            self.db.add(record)
            self.db.commit()
            self.db.refresh(record)
        except IntegrityError:
            self.db.rollback()
            raise
        return record

    def delete(self, id: int) -> Any:
        """Delete a record by id."""
        record = self.get(id)  # reuses the base_model check and 404 handling
        try:
            self.db.delete(record)
            self.db.commit()
        except IntegrityError:
            self.db.rollback()
            raise
        return record

You will notice that this class contains implementations of create, update, delete, get, and get-all queries. It is crucial to understand that Database is the parent class, so it is never instantiated on its own. Instead, other classes for interacting with specific tables, such as a table of users, inherit from it and use their own unique schema when calling these queries. This means that the Database class is where we should put logic that we want to run on interactions with every table.

Some examples of this logic include rollbacks of faulty commits that raise integrity errors, setting the updated_by field on records, and raising 404s if the record is not found (examples of all can be found above). This is ideal because by including these common actions in the parent class, we no longer need to define them independently in the object for interacting with each table. This is a concrete example where inheritance helps us reduce code duplication.
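To illustrate that shared logic outside of SQLModel, here is a sketch that swaps in Python's built-in sqlite3 module. The Table and Users classes, the table_name attribute, and the schema are hypothetical. The parent fills in created_by and updated_by and rolls back on an integrity error; the child only names its table.

```python
import sqlite3
from typing import Any, Dict


class Table:
    """Parent class: audit-field and rollback logic shared by every table (hypothetical)."""

    table_name = ""  # each child class sets this

    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn

    def create(self, username: str, data: Dict[str, Any]) -> int:
        # Audit fields are filled in here, once, for every table
        row = {**data, "created_by": username, "updated_by": username}
        cols = ", ".join(row)
        marks = ", ".join("?" for _ in row)
        # Table and column names are interpolated, so they must come from trusted code
        sql = f"INSERT INTO {self.table_name} ({cols}) VALUES ({marks})"
        try:
            cur = self.conn.execute(sql, tuple(row.values()))
            self.conn.commit()
        except sqlite3.IntegrityError:
            # Undo the failed transaction before re-raising, like Database.create does
            self.conn.rollback()
            raise
        return cur.lastrowid


class Users(Table):
    table_name = "users"


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT UNIQUE, "
    "created_by TEXT, updated_by TEXT)"
)
users = Users(conn)
users.create("admin", {"username": "ada"})
```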

Now, let’s take a look at what a child class of Database might look like.

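class DBUser(Database):
    def __init__(self):
        """Initialize the database object and models for the users table."""
        # User, UserCreate, etc. are SQLModel models defined elsewhere.
        # Keyword arguments keep each model matched to the right parameter.
        super().__init__(
            base_model=User,
            create_model=UserCreate,
            get_model=UserGet,
            get_with_relations_model=UserGetWithRelations,
            update_model=UserUpdate,
        )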

I know what you are probably thinking: “That’s it?” Well, yes, that is it. Because the DBUser class inherits from Database, it has access to all the functionality we gave Database. This class simply specifies the models (which in turn define the data schema) the object should use, and leaves the rest to the superclass. This way, we don’t have to write all the boilerplate logic over and over again for each table, since it is defined once in Database, the parent class. And here’s the real kicker: we can define a class for interacting with every single table in our database in exactly the same way. All we have to do is change the models passed to the constructor.
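Here is a compact sketch of that "one class per table" pattern, using an in-memory stand-in for the Database class so it runs without SQLModel or a live session. The stand-in, the Post model, and DBPost are hypothetical; LookupError plays the role of the HTTP 404 in the real class.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Type


class Database:
    """In-memory stand-in for the parent class, for illustration only."""

    def __init__(self, base_model: Type[Any]):
        self.base_model = base_model
        self._rows: Dict[int, Any] = {}

    def create(self, data: Dict[str, Any]) -> Any:
        record = self.base_model(id=len(self._rows) + 1, **data)
        self._rows[record.id] = record
        return record

    def get(self, id: int) -> Any:
        # Shared "not found" handling, analogous to the 404 in the real class
        if id not in self._rows:
            raise LookupError(f"Record with id {id} not found")
        return self._rows[id]

    def get_all(self) -> List[Any]:
        return list(self._rows.values())


@dataclass
class User:
    id: int
    username: str


@dataclass
class Post:
    id: int
    title: str


# Each child class only says which model to use; everything else is inherited
class DBUser(Database):
    def __init__(self):
        super().__init__(User)


class DBPost(Database):
    def __init__(self):
        super().__init__(Post)


users, posts = DBUser(), DBPost()
users.create({"username": "ada"})
posts.create({"title": "Hello"})
```

Adding a table is then a matter of defining one model and one three-line child class.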

Conclusion

You can probably see some of the benefits of this inheritance-based approach. For one thing, all the logic that should run for every table when a create, update, or delete occurs (setting the updated_by field, for example) can be defined once, making your codebase smaller and changes easier to make. Additionally, flexibility is preserved: if one particular table needs custom logic when a record is updated (logic other tables don’t need), its child class can override the superclass’s update method to include that logic. We can also add a new type of query for all tables by adding a single method to the superclass. Finally, and most importantly, the ability to inherit from the superclass and simply pass in the expected data schema makes it quick and easy for our app to integrate with new tables in our database, without blowing up the size of the codebase.
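Both kinds of flexibility described above can be sketched with another in-memory stand-in (hypothetical names throughout; records are plain dicts here rather than SQLModel objects). DBUser overrides update to normalize emails before deferring to the parent via super(), DBPost keeps the inherited behavior, and a count method added once to the parent is immediately available to both.

```python
from typing import Any, Dict


class Database:
    """In-memory stand-in parent with a shared update method."""

    def __init__(self):
        self.rows: Dict[int, Dict[str, Any]] = {}

    def update(self, username: str, id: int, data: Dict[str, Any]) -> Dict[str, Any]:
        record = self.rows[id]
        record.update(data)
        record["updated_by"] = username  # common logic for every table
        return record

    def count(self) -> int:
        # A method added once here is immediately available to every child class
        return len(self.rows)


class DBUser(Database):
    def update(self, username: str, id: int, data: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical table-specific rule: normalize emails, then defer to the parent
        if "email" in data:
            data = {**data, "email": data["email"].strip().lower()}
        return super().update(username, id, data)


class DBPost(Database):
    pass  # uses the inherited update unchanged


users = DBUser()
users.rows[1] = {"email": "ada@example.com"}
users.update("admin", 1, {"email": "  Ada@Example.COM "})
```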

So, now we have seen a concrete example of how inheritance is used in real-world applications to heavily reduce the size and complexity of projects. In fact, this principle allowed me to cut roughly 80% of the lines from the API I maintain at my company. By making wise use of principles like inheritance, you can greatly boost your development powers and make your applications far more scalable and maintainable. Happy coding!
