import configparser
import functools
import logging
-import sqlmodel
# Initialize logging as early as possible
from . import logger
log = logging.getLogger(__name__)
# Import sub-modules
+from . import database
from . import lists
from . import sources
"""
uri = self.config.get("database", "uri")
- # Create the database engine
- return sqlmodel.create_engine(
- uri,
-
- # Log more if we are running in debug mode
- echo=self.debug,
-
- # Use our own logger
- logging_name=log.name,
- )
-
- def session(self):
- """
- Returns a new database session
- """
- return sqlmodel.Session(self.db)
+ # Create a new database connection
+ return database.Database(self, uri)
@functools.cached_property
def lists(self):
--- /dev/null
+###############################################################################
+# #
+# dnsbl - A DNS Blacklist Compositor For IPFire #
+# Copyright (C) 2025 IPFire Development Team #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <http://www.gnu.org/licenses/>. #
+# #
+###############################################################################
+
+import logging
+import sqlmodel
+
+# Setup logging
+log = logging.getLogger(__name__)
+
+class Database(object):
+ def __init__(self, backend, uri):
+ self.backend = backend
+
+ # Connect to the database
+ self.engine = self.connect(uri)
+
+ # Session
+ self.__session = None
+
+ def connect(self, uri):
+ """
+ Connects to the database
+ """
+ # Create the database engine
+ engine = sqlmodel.create_engine(
+ uri,
+
+ # Log more if we are running in debug mode
+ echo=self.backend.debug,
+
+ # Use our own logger
+ logging_name=log.name,
+ )
+
+ return engine
+
+ def session(self):
+ """
+ Returns the current database session
+ """
+ if self.__session is None:
+ self.__session = sqlmodel.Session(self.engine)
+
+ return self.__session
+
+ def __enter__(self):
+ return self.session()
+
+ def __exit__(self, type, exception, traceback):
+ session = self.session()
+
+ if exception is None:
+ session.commit()
+
+ def execute(self, stmt):
+ """
+ Executes a statement and returns a result object
+ """
+ # Fetch our session
+ session = self.session()
+
+ # Execute the statement
+ return session.execute(stmt)
+
+ def insert(self, cls, **kwargs):
+ """
+ Inserts a new object into the database
+ """
+ # Fetch our session
+ session = self.session()
+
+ # Create a new object
+ object = cls(**kwargs)
+
+ # Add it to the database
+ session.add(object)
+
+ # Return the object
+ return object
+
+ def fetch_one(self, stmt):
+ result = self.execute(stmt)
+
+ # Apply unique filtering
+ result = result.unique()
+
+ # Return exactly one object or none, but fail otherwise
+ return result.scalar_one_or_none()
)
)
- with self.backend.session() as session:
- result = session.execute(stmt)
-
- return result.scalar_one_or_none()
+ return self.backend.db.fetch_one(stmt)
def _make_slug(self, name):
i = 0
slug = self._make_slug(name)
# Create a new list
- with self.backend.session() as session:
- list = List(
- name = name,
- slug = slug,
- created_by = created_by,
- license = license,
- )
- session.add(list)
- session.commit()
-
- return list
+ return self.backend.db.insert(
+ List,
+ name = name,
+ slug = slug,
+ created_by = created_by,
+ license = license,
+ )
class List(sqlmodel.SQLModel, table=True):