From: Michael Tremer Date: Fri, 5 Dec 2025 16:31:22 +0000 (+0000) Subject: dnsbl: Add our own database abstraction X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=de587ec3ed8e240233190be535648b28d5e12955;p=dbl.git dnsbl: Add our own database abstraction Otherwise we would have to carry session objects around which makes the code incredibly messy. Signed-off-by: Michael Tremer --- diff --git a/Makefile.am b/Makefile.am index 14a72ad..de0e722 100644 --- a/Makefile.am +++ b/Makefile.am @@ -51,6 +51,7 @@ SED_PROCESS = \ dist_pkgpython_PYTHON = \ src/dnsbl/__init__.py \ + src/dnsbl/database.py \ src/dnsbl/i18n.py \ src/dnsbl/lists.py \ src/dnsbl/logger.py \ diff --git a/src/dnsbl/__init__.py b/src/dnsbl/__init__.py index e4240b0..5ba2a84 100644 --- a/src/dnsbl/__init__.py +++ b/src/dnsbl/__init__.py @@ -21,7 +21,6 @@ import configparser import functools import logging -import sqlmodel # Initialize logging as early as possible from . import logger @@ -30,6 +29,7 @@ from . import logger log = logging.getLogger(__name__) # Import sub-modules +from . import database from . import lists from . import sources @@ -64,22 +64,8 @@ class Backend(object): """ uri = self.config.get("database", "uri") - # Create the database engine - return sqlmodel.create_engine( - uri, - - # Log more if we are running in debug mode - echo=self.debug, - - # Use our own logger - logging_name=log.name, - ) - - def session(self): - """ - Returns a new database session - """ - return sqlmodel.Session(self.db) + # Create a new database connection + return database.Database(self, uri) @functools.cached_property def lists(self): diff --git a/src/dnsbl/database.py b/src/dnsbl/database.py new file mode 100644 index 0000000..6319e8e --- /dev/null +++ b/src/dnsbl/database.py @@ -0,0 +1,105 @@ +############################################################################### +# # +# dnsbl - A DNS Blacklist Compositor For IPFire # +# Copyright (C) 2025 IPFire Development Team # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +############################################################################### + +import logging +import sqlmodel + +# Setup logging +log = logging.getLogger(__name__) + +class Database(object): + def __init__(self, backend, uri): + self.backend = backend + + # Connect to the database + self.engine = self.connect(uri) + + # Session + self.__session = None + + def connect(self, uri): + """ + Connects to the database + """ + # Create the database engine + engine = sqlmodel.create_engine( + uri, + + # Log more if we are running in debug mode + echo=self.backend.debug, + + # Use our own logger + logging_name=log.name, + ) + + return engine + + def session(self): + """ + Returns the current database session + """ + if self.__session is None: + self.__session = sqlmodel.Session(self.engine) + + return self.__session + + def __enter__(self): + return self.session() + + def __exit__(self, type, exception, traceback): + session = self.session() + + if exception is None: + session.commit() + + def execute(self, stmt): + """ + Executes a statement and returns a result object + """ + # Fetch our session + session = self.session() + + # Execute the statement + return session.execute(stmt) + + def insert(self, cls, **kwargs): + """ + Inserts a new object into the database + """ + # Fetch our session + session = self.session() + + # Create a new object + object = cls(**kwargs) + + # Add it to the database + session.add(object) + + # Return the object + return object + + def fetch_one(self, stmt): + result = self.execute(stmt) + + # Apply unique filtering + result = result.unique() + + # Return exactly one object or none, but fail otherwise + return result.scalar_one_or_none() diff --git a/src/dnsbl/lists.py b/src/dnsbl/lists.py index 4bd8820..40269ac 100644 --- a/src/dnsbl/lists.py +++ b/src/dnsbl/lists.py @@ -41,10 +41,7 @@ class Lists(object): ) ) - with self.backend.session() as session: - result = session.execute(stmt) - - return result.scalar_one_or_none() + return self.backend.db.fetch_one(stmt) def _make_slug(self, name): i = 0 @@ -66,17 +63,13 @@ class Lists(object): slug = self._make_slug(name) # Create a new list - with self.backend.session() as session: - list = List( - name = name, - slug = slug, - created_by = created_by, - license = license, - ) - session.add(list) - session.commit() - - return list + return self.backend.db.insert( + List, + name = name, + slug = slug, + created_by = created_by, + license = license, + ) class List(sqlmodel.SQLModel, table=True): diff --git a/src/scripts/dnsbl.in b/src/scripts/dnsbl.in index eda4423..a771ba7 100644 --- a/src/scripts/dnsbl.in +++ b/src/scripts/dnsbl.in @@ -100,7 +100,8 @@ class CLI(object): ) # Call the handler function - ret = args.func(backend, args) + with backend.db: + ret = args.func(backend, args) # Exit with the returned error code if ret: