]> git.ipfire.org Git - ipfire.org.git/commitdiff
database: Import wrapper module from PBS
authorMichael Tremer <michael.tremer@ipfire.org>
Wed, 26 Jul 2023 14:00:41 +0000 (14:00 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Wed, 26 Jul 2023 14:00:41 +0000 (14:00 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/backend/base.py
src/backend/database.py
src/backend/misc.py

index 971c3910099b63250c1710444a433e0bfb4ca47b..40e1c7a5a39933f9c1ddfb576432586bb6c3bf1b 100644 (file)
@@ -97,7 +97,7 @@ class Backend(object):
                        "password" : self.config.get("database", "password"),
                }
 
-               self.db = database.Connection(**credentials)
+               self.db = database.Connection(self, **credentials)
 
        @lazy_property
        def ssl_context(self):
index f79cf1283c68dd994aa2fe1fdc5cf9eaf1852b98..bf3cf108c8775dc9638cba47792037802f176246 100644 (file)
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/python
 
 """
        A lightweight wrapper around psycopg2.
@@ -8,8 +8,17 @@
        as torndb.
 """
 
+import asyncio
+import itertools
 import logging
-import psycopg2
+import psycopg
+import psycopg_pool
+import time
+
+from . import misc
+
+# Setup logging
+log = logging.getLogger("pbs.database")
 
 class Connection(object):
        """
@@ -28,57 +37,111 @@ class Connection(object):
                We explicitly set the timezone to UTC and the character encoding to
                UTF-8 on all connections to avoid time zone and encoding errors.
        """
-       def __init__(self, host, database, user=None, password=None):
-               self.host = host
-               self.database = database
-
-               self._db = None
-               self._db_args = {
-                       "host"     : host,
-                       "database" : database,
-                       "user"     : user,
-                       "password" : password,
-               }
+       def __init__(self, backend, host, database, user=None, password=None):
+               self.backend = backend
 
-               try:
-                       self.reconnect()
-               except Exception:
-                       logging.error("Cannot connect to database on %s", self.host, exc_info=True)
+               # Stores connections assigned to tasks
+               self.__connections = {}
+
+               # Create a connection pool
+               self.pool = psycopg_pool.ConnectionPool(
+                       "postgresql://%s:%s@%s/%s" % (user, password, host, database),
+
+                       # Callback to configure any new connections
+                       configure=self.__configure,
+
+                       # Set limits for min/max connections in the pool
+                       min_size=4,
+                       max_size=128,
+
+                       # Give clients up to one minute to retrieve a connection
+                       timeout=60,
 
-       def __del__(self):
-               self.close()
+                       # Close connections after they have been idle for one minute
+                       max_idle=60,
+               )
 
-       def close(self):
+       def __configure(self, conn):
                """
-                       Closes this database connection.
+                       Configures any newly opened connections
                """
-               if getattr(self, "_db", None) is not None:
-                       self._db.close()
-                       self._db = None
+               # Enable autocommit
+               conn.autocommit = True
 
-       def reconnect(self):
+               # Return any rows as dicts
+               conn.row_factory = psycopg.rows.dict_row
+
+               # Automatically convert DataObjects
+               conn.adapters.register_dumper(misc.Object, misc.ObjectDumper)
+
+       def connection(self, *args, **kwargs):
                """
-                       Closes the existing database connection and re-opens it.
+                       Returns a connection from the pool
                """
-               self.close()
+               # Fetch the current task
+               task = asyncio.current_task()
+
+               assert task, "Could not determine task"
+
+               # Try returning the same connection to the same task
+               try:
+                       return self.__connections[task]
+               except KeyError:
+                       pass
+
+               # Fetch a new connection from the pool
+               conn = self.__connections[task] = self.pool.getconn(*args, **kwargs)
+
+               log.debug("Assigning database connection %s to %s" % (conn, task))
 
-               self._db = psycopg2.connect(**self._db_args)
-               self._db.autocommit = True
+               # When the task finishes, release the connection
+               task.add_done_callback(self.__release_connection)
 
-               # Initialize the timezone setting.
-               self.execute("SET TIMEZONE TO 'UTC'")
+               return conn
+
+       def __release_connection(self, task):
+               # Retrieve the connection
+               try:
+                       conn = self.__connections[task]
+               except KeyError:
+                       return
+
+               log.debug("Releasing database connection %s of %s" % (conn, task))
+
+               # Delete it
+               del self.__connections[task]
+
+               # Return the connection back into the pool
+               self.pool.putconn(conn)
+
+       def _execute(self, cursor, execute, query, parameters):
+               # Store the time we started this query
+               t = time.monotonic()
+
+               try:
+                       log.debug("Running SQL query %s" % (query % parameters))
+               except Exception:
+                       pass
+
+               # Execute the query
+               execute(query, parameters)
+
+               # How long did this take?
+               elapsed = time.monotonic() - t
+
+               # Log the query time
+               log.debug("  Query time: %.2fms" % (elapsed * 1000))
 
        def query(self, query, *parameters, **kwparameters):
                """
                        Returns a row list for the given query and parameters.
                """
-               cursor = self._cursor()
-               try:
-                       self._execute(cursor, query, parameters, kwparameters)
-                       column_names = [d[0] for d in cursor.description]
-                       return [Row(zip(column_names, row)) for row in cursor]
-               finally:
-                       cursor.close()
+               conn = self.connection()
+
+               with conn.cursor() as cursor:
+                       self._execute(cursor, cursor.execute, query, parameters or kwparameters)
+
+                       return [Row(row) for row in cursor]
 
        def get(self, query, *parameters, **kwparameters):
                """
@@ -94,89 +157,48 @@ class Connection(object):
 
        def execute(self, query, *parameters, **kwparameters):
                """
-                       Executes the given query, returning the lastrowid from the query.
+                       Executes the given query.
                """
-               return self.execute_lastrowid(query, *parameters, **kwparameters)
+               conn = self.connection()
 
-       def execute_lastrowid(self, query, *parameters, **kwparameters):
-               """
-                       Executes the given query, returning the lastrowid from the query.
-               """
-               cursor = self._cursor()
-               try:
-                       self._execute(cursor, query, parameters, kwparameters)
-                       return cursor.lastrowid
-               finally:
-                       cursor.close()
-
-       def execute_rowcount(self, query, *parameters, **kwparameters):
-               """
-                       Executes the given query, returning the rowcount from the query.
-               """
-               cursor = self._cursor()
-               try:
-                       self._execute(cursor, query, parameters, kwparameters)
-                       return cursor.rowcount
-               finally:
-                       cursor.close()
+               with conn.cursor() as cursor:
+                       self._execute(cursor, cursor.execute, query, parameters or kwparameters)
 
        def executemany(self, query, parameters):
                """
                        Executes the given query against all the given param sequences.
-
-                       We return the lastrowid from the query.
                """
-               return self.executemany_lastrowid(query, parameters)
+               conn = self.connection()
 
-       def executemany_lastrowid(self, query, parameters):
-               """
-                       Executes the given query against all the given param sequences.
+               with conn.cursor() as cursor:
+                       self._execute(cursor, cursor.executemany, query, parameters)
 
-                       We return the lastrowid from the query.
+       def transaction(self):
                """
-               cursor = self._cursor()
-               try:
-                       cursor.executemany(query, parameters)
-                       return cursor.lastrowid
-               finally:
-                       cursor.close()
-
-       def executemany_rowcount(self, query, parameters):
+                       Creates a new transaction on the current tasks' connection
                """
-                       Executes the given query against all the given param sequences.
+               conn = self.connection()
 
-                       We return the rowcount from the query.
-               """
-               cursor = self._cursor()
+               return conn.transaction()
 
-               try:
-                       cursor.executemany(query, parameters)
-                       return cursor.rowcount
-               finally:
-                       cursor.close()
-
-       def _ensure_connected(self):
-               if self._db is None:
-                       logging.warning("Database connection was lost...")
-
-                       self.reconnect()
-
-       def _cursor(self):
-               self._ensure_connected()
-               return self._db.cursor()
+       def fetch_one(self, cls, query, *args, **kwargs):
+               """
+                       Takes a class and a query and will return one object of that class
+               """
+               # Execute the query
+               res = self.get(query, *args)
 
-       def _execute(self, cursor, query, parameters, kwparameters):
-               logging.debug("SQL Query: %s" % (query % (kwparameters or parameters)))
+               # Return an object (if possible)
+               if res:
+                       return cls(self.backend, res.id, res, **kwargs)
 
-               try:
-                       return cursor.execute(query, kwparameters or parameters)
-               except (OperationalError, psycopg2.ProgrammingError):
-                       logging.error("Error connecting to database on %s", self.host)
-                       self.close()
-                       raise
+       def fetch_many(self, cls, query, *args, **kwargs):
+               # Execute the query
+               res = self.query(query, *args)
 
-       def transaction(self):
-               return Transaction(self)
+               # Return a generator with objects
+               for row in res:
+                       yield cls(self.backend, row.id, row, **kwargs)
 
 
 class Row(dict):
@@ -186,24 +208,3 @@ class Row(dict):
                        return self[name]
                except KeyError:
                        raise AttributeError(name)
-
-
-class Transaction(object):
-       def __init__(self, db):
-               self.db = db
-
-               self.db.execute("START TRANSACTION")
-
-       def __enter__(self):
-               return self
-
-       def __exit__(self, exctype, excvalue, traceback):
-               if exctype is not None:
-                       self.db.execute("ROLLBACK")
-               else:
-                       self.db.execute("COMMIT")
-
-
-# Alias some common exceptions
-IntegrityError = psycopg2.IntegrityError
-OperationalError = psycopg2.OperationalError
index f2f2e7532619fdb3872382fee74b61150f0bd0bb..4960e3070b7bd26cea93310a99fad76f2112de82 100644 (file)
@@ -1,5 +1,7 @@
 #!/usr/bin/python
 
+import psycopg.adapt
+
 class Object(object):
        def __init__(self, backend, *args, **kwargs):
                self.backend = backend
@@ -39,3 +41,11 @@ class Object(object):
        @property
        def settings(self):
                return self.backend.settings
+
+
+# SQL Integration
+
+class ObjectDumper(psycopg.adapt.Dumper):
+       def dump(self, obj):
+               # Return the ID (as bytes)
+               return bytes("%s" % obj.id, "utf-8")