From: Michael Tremer Date: Wed, 26 Jul 2023 14:00:41 +0000 (+0000) Subject: database: Import wrapper module from PBS X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=3082f0e9ed5e5aa02a2e2bcd1ee747169f314bd0;p=ipfire.org.git database: Import wrapper module from PBS Signed-off-by: Michael Tremer --- diff --git a/src/backend/base.py b/src/backend/base.py index 971c3910..40e1c7a5 100644 --- a/src/backend/base.py +++ b/src/backend/base.py @@ -97,7 +97,7 @@ class Backend(object): "password" : self.config.get("database", "password"), } - self.db = database.Connection(**credentials) + self.db = database.Connection(self, **credentials) @lazy_property def ssl_context(self): diff --git a/src/backend/database.py b/src/backend/database.py index f79cf128..bf3cf108 100644 --- a/src/backend/database.py +++ b/src/backend/database.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/python """ A lightweight wrapper around psycopg2. @@ -8,8 +8,17 @@ as torndb. """ +import asyncio +import itertools import logging -import psycopg2 +import psycopg +import psycopg_pool +import time + +from . import misc + +# Setup logging +log = logging.getLogger("pbs.database") class Connection(object): """ @@ -28,57 +37,111 @@ class Connection(object): We explicitly set the timezone to UTC and the character encoding to UTF-8 on all connections to avoid time zone and encoding errors. """ - def __init__(self, host, database, user=None, password=None): - self.host = host - self.database = database - - self._db = None - self._db_args = { - "host" : host, - "database" : database, - "user" : user, - "password" : password, - } + def __init__(self, backend, host, database, user=None, password=None): + self.backend = backend - try: - self.reconnect() - except Exception: - logging.error("Cannot connect to database on %s", self.host, exc_info=True) + # Stores connections assigned to tasks + self.__connections = {} + + # Create a connection pool + self.pool = psycopg_pool.ConnectionPool( + "postgresql://%s:%s@%s/%s" % (user, password, host, database), + + # Callback to configure any new connections + configure=self.__configure, + + # Set limits for min/max connections in the pool + min_size=4, + max_size=128, + + # Give clients up to one minute to retrieve a connection + timeout=60, - def __del__(self): - self.close() + # Close connections after they have been idle for one minute + max_idle=60, + ) - def close(self): + def __configure(self, conn): """ - Closes this database connection. + Configures any newly opened connections """ - if getattr(self, "_db", None) is not None: - self._db.close() - self._db = None + # Enable autocommit + conn.autocommit = True - def reconnect(self): + # Return any rows as dicts + conn.row_factory = psycopg.rows.dict_row + + # Automatically convert DataObjects + conn.adapters.register_dumper(misc.Object, misc.ObjectDumper) + + def connection(self, *args, **kwargs): """ - Closes the existing database connection and re-opens it. + Returns a connection from the pool """ - self.close() + # Fetch the current task + task = asyncio.current_task() + + assert task, "Could not determine task" + + # Try returning the same connection to the same task + try: + return self.__connections[task] + except KeyError: + pass + + # Fetch a new connection from the pool + conn = self.__connections[task] = self.pool.getconn(*args, **kwargs) + + log.debug("Assigning database connection %s to %s" % (conn, task)) - self._db = psycopg2.connect(**self._db_args) - self._db.autocommit = True + # When the task finishes, release the connection + task.add_done_callback(self.__release_connection) - # Initialize the timezone setting. - self.execute("SET TIMEZONE TO 'UTC'") + return conn + + def __release_connection(self, task): + # Retrieve the connection + try: + conn = self.__connections[task] + except KeyError: + return + + log.debug("Releasing database connection %s of %s" % (conn, task)) + + # Delete it + del self.__connections[task] + + # Return the connection back into the pool + self.pool.putconn(conn) + + def _execute(self, cursor, execute, query, parameters): + # Store the time we started this query + t = time.monotonic() + + try: + log.debug("Running SQL query %s" % (query % parameters)) + except Exception: + pass + + # Execute the query + execute(query, parameters) + + # How long did this take? + elapsed = time.monotonic() - t + + # Log the query time + log.debug(" Query time: %.2fms" % (elapsed * 1000)) def query(self, query, *parameters, **kwparameters): """ Returns a row list for the given query and parameters. """ - cursor = self._cursor() - try: - self._execute(cursor, query, parameters, kwparameters) - column_names = [d[0] for d in cursor.description] - return [Row(zip(column_names, row)) for row in cursor] - finally: - cursor.close() + conn = self.connection() + + with conn.cursor() as cursor: + self._execute(cursor, cursor.execute, query, parameters or kwparameters) + + return [Row(row) for row in cursor] def get(self, query, *parameters, **kwparameters): """ @@ -94,89 +157,48 @@ class Connection(object): def execute(self, query, *parameters, **kwparameters): """ - Executes the given query, returning the lastrowid from the query. + Executes the given query. """ - return self.execute_lastrowid(query, *parameters, **kwparameters) + conn = self.connection() - def execute_lastrowid(self, query, *parameters, **kwparameters): - """ - Executes the given query, returning the lastrowid from the query. - """ - cursor = self._cursor() - try: - self._execute(cursor, query, parameters, kwparameters) - return cursor.lastrowid - finally: - cursor.close() - - def execute_rowcount(self, query, *parameters, **kwparameters): - """ - Executes the given query, returning the rowcount from the query. - """ - cursor = self._cursor() - try: - self._execute(cursor, query, parameters, kwparameters) - return cursor.rowcount - finally: - cursor.close() + with conn.cursor() as cursor: + self._execute(cursor, cursor.execute, query, parameters or kwparameters) def executemany(self, query, parameters): """ Executes the given query against all the given param sequences. - - We return the lastrowid from the query. """ - return self.executemany_lastrowid(query, parameters) + conn = self.connection() - def executemany_lastrowid(self, query, parameters): - """ - Executes the given query against all the given param sequences. + with conn.cursor() as cursor: + self._execute(cursor, cursor.executemany, query, parameters) - We return the lastrowid from the query. + def transaction(self): """ - cursor = self._cursor() - try: - cursor.executemany(query, parameters) - return cursor.lastrowid - finally: - cursor.close() - - def executemany_rowcount(self, query, parameters): + Creates a new transaction on the current tasks' connection """ - Executes the given query against all the given param sequences. + conn = self.connection() - We return the rowcount from the query. - """ - cursor = self._cursor() + return conn.transaction() - try: - cursor.executemany(query, parameters) - return cursor.rowcount - finally: - cursor.close() - - def _ensure_connected(self): - if self._db is None: - logging.warning("Database connection was lost...") - - self.reconnect() - - def _cursor(self): - self._ensure_connected() - return self._db.cursor() + def fetch_one(self, cls, query, *args, **kwargs): + """ + Takes a class and a query and will return one object of that class + """ + # Execute the query + res = self.get(query, *args) - def _execute(self, cursor, query, parameters, kwparameters): - logging.debug("SQL Query: %s" % (query % (kwparameters or parameters))) + # Return an object (if possible) + if res: + return cls(self.backend, res.id, res, **kwargs) - try: - return cursor.execute(query, kwparameters or parameters) - except (OperationalError, psycopg2.ProgrammingError): - logging.error("Error connecting to database on %s", self.host) - self.close() - raise + def fetch_many(self, cls, query, *args, **kwargs): + # Execute the query + res = self.query(query, *args) - def transaction(self): - return Transaction(self) + # Return a generator with objects + for row in res: + yield cls(self.backend, row.id, row, **kwargs) class Row(dict): @@ -186,24 +208,3 @@ class Row(dict): return self[name] except KeyError: raise AttributeError(name) - - -class Transaction(object): - def __init__(self, db): - self.db = db - - self.db.execute("START TRANSACTION") - - def __enter__(self): - return self - - def __exit__(self, exctype, excvalue, traceback): - if exctype is not None: - self.db.execute("ROLLBACK") - else: - self.db.execute("COMMIT") - - -# Alias some common exceptions -IntegrityError = psycopg2.IntegrityError -OperationalError = psycopg2.OperationalError diff --git a/src/backend/misc.py b/src/backend/misc.py index f2f2e753..4960e307 100644 --- a/src/backend/misc.py +++ b/src/backend/misc.py @@ -1,5 +1,7 @@ #!/usr/bin/python +import psycopg.adapt + class Object(object): def __init__(self, backend, *args, **kwargs): self.backend = backend @@ -39,3 +41,11 @@ class Object(object): @property def settings(self): return self.backend.settings + + +# SQL Integration + +class ObjectDumper(psycopg.adapt.Dumper): + def dump(self, obj): + # Return the ID (as bytes) + return bytes("%s" % obj.id, "utf-8")