]> git.ipfire.org Git - dbl.git/commitdiff
dnsbl: Add our own database abstraction
authorMichael Tremer <michael.tremer@ipfire.org>
Fri, 5 Dec 2025 16:31:22 +0000 (16:31 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Fri, 5 Dec 2025 16:31:22 +0000 (16:31 +0000)
Otherwise we would have to carry session objects around which makes the
code incredibly messy.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
src/dnsbl/__init__.py
src/dnsbl/database.py [new file with mode: 0644]
src/dnsbl/lists.py
src/scripts/dnsbl.in

index 14a72ad75d3bdb956ef38e72c12c22c2daa3021d..de0e722dc46c5ef24b47f1c567476000ad6be8b7 100644 (file)
@@ -51,6 +51,7 @@ SED_PROCESS = \
 
 dist_pkgpython_PYTHON = \
        src/dnsbl/__init__.py \
+       src/dnsbl/database.py \
        src/dnsbl/i18n.py \
        src/dnsbl/lists.py \
        src/dnsbl/logger.py \
index e4240b0b42880b2565e8a5a000d2f24264ad68e8..5ba2a84824afb17662a6d036db82ba5968d68b45 100644 (file)
@@ -21,7 +21,6 @@
 import configparser
 import functools
 import logging
-import sqlmodel
 
 # Initialize logging as early as possible
 from . import logger
@@ -30,6 +29,7 @@ from . import logger
 log = logging.getLogger(__name__)
 
 # Import sub-modules
+from . import database
 from . import lists
 from . import sources
 
@@ -64,22 +64,8 @@ class Backend(object):
                """
                uri = self.config.get("database", "uri")
 
-               # Create the database engine
-               return sqlmodel.create_engine(
-                       uri,
-
-                       # Log more if we are running in debug mode
-                       echo=self.debug,
-
-                       # Use our own logger
-                       logging_name=log.name,
-               )
-
-       def session(self):
-               """
-                       Returns a new database session
-               """
-               return sqlmodel.Session(self.db)
+               # Create a new database connection
+               return database.Database(self, uri)
 
        @functools.cached_property
        def lists(self):
diff --git a/src/dnsbl/database.py b/src/dnsbl/database.py
new file mode 100644 (file)
index 0000000..6319e8e
--- /dev/null
@@ -0,0 +1,105 @@
+###############################################################################
+#                                                                             #
+# dnsbl - A DNS Blacklist Compositor For IPFire                               #
+# Copyright (C) 2025 IPFire Development Team                                  #
+#                                                                             #
+# This program is free software: you can redistribute it and/or modify        #
+# it under the terms of the GNU General Public License as published by        #
+# the Free Software Foundation, either version 3 of the License, or           #
+# (at your option) any later version.                                         #
+#                                                                             #
+# This program is distributed in the hope that it will be useful,             #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
+# GNU General Public License for more details.                                #
+#                                                                             #
+# You should have received a copy of the GNU General Public License           #
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
+#                                                                             #
+###############################################################################
+
+import logging
+import sqlmodel
+
+# Setup logging
+log = logging.getLogger(__name__)
+
+class Database(object):
+       def __init__(self, backend, uri):
+               self.backend = backend
+
+               # Connect to the database
+               self.engine = self.connect(uri)
+
+               # Session
+               self.__session = None
+
+       def connect(self, uri):
+               """
+                       Connects to the database
+               """
+               # Create the database engine
+               engine = sqlmodel.create_engine(
+                       uri,
+
+                       # Log more if we are running in debug mode
+                       echo=self.backend.debug,
+
+                       # Use our own logger
+                       logging_name=log.name,
+               )
+
+               return engine
+
+       def session(self):
+               """
+                       Returns the current database session
+               """
+               if self.__session is None:
+                       self.__session = sqlmodel.Session(self.engine)
+
+               return self.__session
+
+       def __enter__(self):
+               return self.session()
+
+       def __exit__(self, type, exception, traceback):
+               session = self.session()
+
+               if exception is None:
+                       session.commit()
+
+       def execute(self, stmt):
+               """
+                       Executes a statement and returns a result object
+               """
+               # Fetch our session
+               session = self.session()
+
+               # Execute the statement
+               return session.execute(stmt)
+
+       def insert(self, cls, **kwargs):
+               """
+                       Inserts a new object into the database
+               """
+               # Fetch our session
+               session = self.session()
+
+               # Create a new object
+               object = cls(**kwargs)
+
+               # Add it to the database
+               session.add(object)
+
+               # Return the object
+               return object
+
+       def fetch_one(self, stmt):
+               result = self.execute(stmt)
+
+               # Apply unique filtering
+               result = result.unique()
+
+               # Return exactly one object or none, but fail otherwise
+               return result.scalar_one_or_none()
index 4bd882013e37e64312d1cc020b6a607830d5b043..40269ac46e501a80a740c3c2138478ee0602ace4 100644 (file)
@@ -41,10 +41,7 @@ class Lists(object):
                        )
                )
 
-               with self.backend.session() as session:
-                       result = session.execute(stmt)
-
-                       return result.scalar_one_or_none()
+               return self.backend.db.fetch_one(stmt)
 
        def _make_slug(self, name):
                i = 0
@@ -66,17 +63,13 @@ class Lists(object):
                slug = self._make_slug(name)
 
                # Create a new list
-               with self.backend.session() as session:
-                       list = List(
-                               name       = name,
-                               slug       = slug,
-                               created_by = created_by,
-                               license    = license,
-                       )
-                       session.add(list)
-                       session.commit()
-
-               return list
+               return self.backend.db.insert(
+                       List,
+                       name       = name,
+                       slug       = slug,
+                       created_by = created_by,
+                       license    = license,
+               )
 
 
 class List(sqlmodel.SQLModel, table=True):
index eda44235aebc38660e0ce58aa4e0eb99e830954e..a771ba76237e816761ccb23d97195bae99e7b994 100644 (file)
@@ -100,7 +100,8 @@ class CLI(object):
                )
 
                # Call the handler function
-               ret = args.func(backend, args)
+               with backend.db:
+                       ret = args.func(backend, args)
 
                # Exit with the returned error code
                if ret: