From: Michael Tremer Date: Fri, 5 Dec 2025 17:24:36 +0000 (+0000) Subject: database: Make the backend available to all objects X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3496193522142d89df0b853e29ed3b0b2bb1664e;p=dbl.git database: Make the backend available to all objects Signed-off-by: Michael Tremer --- diff --git a/src/dnsbl/database.py b/src/dnsbl/database.py index 6dc8b44..b97eb89 100644 --- a/src/dnsbl/database.py +++ b/src/dnsbl/database.py @@ -18,7 +18,9 @@ # # ############################################################################### +import functools import logging +import sqlalchemy.orm import sqlmodel # Setup logging @@ -31,6 +33,15 @@ class Database(object): # Connect to the database self.engine = self.connect(uri) + # Create a session maker + self.sessionmaker = sqlalchemy.orm.sessionmaker( + self.engine, + expire_on_commit = False, + info = { + "backend" : self.backend, + }, + ) + # Session self.__session = None @@ -56,7 +67,7 @@ class Database(object): Returns the current database session """ if self.__session is None: - self.__session = sqlmodel.Session(self.engine) + self.__session = self.sessionmaker() return self.__session @@ -138,3 +149,17 @@ class Database(object): # Return as set return set([o for o in objects]) + + +class BackendMixin: + @functools.cached_property + def backend(self): + # Fetch the session + session = sqlalchemy.orm.object_session(self) + + # Return the backend + return session.info.get("backend") + + @functools.cached_property + def db(self): + return self.backend.db diff --git a/src/dnsbl/lists.py b/src/dnsbl/lists.py index 1396785..479ec87 100644 --- a/src/dnsbl/lists.py +++ b/src/dnsbl/lists.py @@ -22,6 +22,7 @@ import datetime import sqlmodel import typing +from . import database from . import sources from . import util @@ -92,7 +93,7 @@ class Lists(object): ) -class List(sqlmodel.SQLModel, table=True): +class List(sqlmodel.SQLModel, database.BackendMixin, table=True): __tablename__ = "lists" # ID diff --git a/src/dnsbl/sources.py b/src/dnsbl/sources.py index ed30291..8465b47 100644 --- a/src/dnsbl/sources.py +++ b/src/dnsbl/sources.py @@ -21,6 +21,8 @@ import datetime import sqlmodel +from . import database + class Sources(object): def __init__(self, backend): self.backend = backend @@ -60,7 +62,7 @@ class Sources(object): return source -class Source(sqlmodel.SQLModel, table=True): +class Source(sqlmodel.SQLModel, database.BackendMixin, table=True): __tablename__ = "sources" # ID