]> git.ipfire.org Git - people/ms/libloc.git/blobdiff - src/python/location/database.py
database: Migrate to psycopg3
[people/ms/libloc.git] / src / python / location / database.py
index 2c93ed092dd1c1085c937663a8a97ff20592b03e..82b2bfcca412c903e72581f53c9d14fcf54c660a 100644 (file)
@@ -1,87 +1,53 @@
 """
-       A lightweight wrapper around psycopg2.
-
-       Originally part of the Tornado framework.  The tornado.database module
-       is slated for removal in Tornado 3.0, and it is now available separately
-       as torndb.
+       A lightweight wrapper around psycopg3.
 """
 
 import logging
-import psycopg2
+import psycopg
 import time
 
+# Setup logging
 log = logging.getLogger("location.database")
-log.propagate = 1
 
 class Connection(object):
-       """
-               A lightweight wrapper around MySQLdb DB-API connections.
-
-               The main value we provide is wrapping rows in a dict/object so that
-               columns can be accessed by name. Typical usage::
-
-                       db = torndb.Connection("localhost", "mydatabase")
-                       for article in db.query("SELECT * FROM articles"):
-                               print article.title
-
-               Cursors are hidden by the implementation, but other than that, the methods
-               are very similar to the DB-API.
-
-               We explicitly set the timezone to UTC and the character encoding to
-               UTF-8 on all connections to avoid time zone and encoding errors.
-       """
        def __init__(self, host, database, user=None, password=None):
-               self.host = host
-               self.database = database
-
-               self._db = None
-               self._db_args = {
-                       "host"     : host,
-                       "database" : database,
-                       "user"     : user,
-                       "password" : password,
-                       "sslmode"  : "require",
-               }
+               # Create a connection pool
+               self.connection = psycopg.connect(
+                       "postgresql://%s:%s@%s/%s" % (user, password, host, database),
 
-               try:
-                       self.reconnect()
-               except Exception:
-                       log.error("Cannot connect to database on %s", self.host, exc_info=True)
+                       # Enable autocommit
+                       autocommit=True,
+
+                       # Return any rows as dicts
+                       row_factory = psycopg.rows.dict_row,
+               )
 
-       def __del__(self):
-               self.close()
+       def _execute(self, cursor, execute, query, parameters):
+               # Store the time we started this query
+               #t = time.monotonic()
 
-       def close(self):
-               """
-                       Closes this database connection.
-               """
-               if getattr(self, "_db", None) is not None:
-                       self._db.close()
-                       self._db = None
+               #try:
+               #       log.debug("Running SQL query %s" % (query % parameters))
+               #except Exception:
+               #       pass
 
-       def reconnect(self):
-               """
-                       Closes the existing database connection and re-opens it.
-               """
-               self.close()
+               # Execute the query
+               execute(query, parameters)
 
-               self._db = psycopg2.connect(**self._db_args)
-               self._db.autocommit = True
+               # How long did this take?
+               #elapsed = time.monotonic() - t
 
-               # Initialize the timezone setting.
-               self.execute("SET TIMEZONE TO 'UTC'")
+               # Log the query time
+               #log.debug("  Query time: %.2fms" % (elapsed * 1000))
 
        def query(self, query, *parameters, **kwparameters):
                """
                        Returns a row list for the given query and parameters.
                """
-               cursor = self._cursor()
-               try:
-                       self._execute(cursor, query, parameters, kwparameters)
-                       column_names = [d[0] for d in cursor.description]
-                       return [Row(zip(column_names, row)) for row in cursor]
-               finally:
-                       cursor.close()
+               with self.connection.cursor() as cursor:
+                       self._execute(cursor, cursor.execute, query, parameters or kwparameters)
+
+                       return [Row(row) for row in cursor]
 
        def get(self, query, *parameters, **kwparameters):
                """
@@ -97,104 +63,23 @@ class Connection(object):
 
        def execute(self, query, *parameters, **kwparameters):
                """
-                       Executes the given query, returning the lastrowid from the query.
+                       Executes the given query.
                """
-               return self.execute_lastrowid(query, *parameters, **kwparameters)
-
-       def execute_lastrowid(self, query, *parameters, **kwparameters):
-               """
-                       Executes the given query, returning the lastrowid from the query.
-               """
-               cursor = self._cursor()
-               try:
-                       self._execute(cursor, query, parameters, kwparameters)
-                       return cursor.lastrowid
-               finally:
-                       cursor.close()
-
-       def execute_rowcount(self, query, *parameters, **kwparameters):
-               """
-                       Executes the given query, returning the rowcount from the query.
-               """
-               cursor = self._cursor()
-               try:
-                       self._execute(cursor, query, parameters, kwparameters)
-                       return cursor.rowcount
-               finally:
-                       cursor.close()
+               with self.connection.cursor() as cursor:
+                       self._execute(cursor, cursor.execute, query, parameters or kwparameters)
 
        def executemany(self, query, parameters):
                """
                        Executes the given query against all the given param sequences.
-
-                       We return the lastrowid from the query.
                """
-               return self.executemany_lastrowid(query, parameters)
+               with self.connection.cursor() as cursor:
+                       self._execute(cursor, cursor.executemany, query, parameters)
 
-       def executemany_lastrowid(self, query, parameters):
-               """
-                       Executes the given query against all the given param sequences.
-
-                       We return the lastrowid from the query.
-               """
-               cursor = self._cursor()
-               try:
-                       cursor.executemany(query, parameters)
-                       return cursor.lastrowid
-               finally:
-                       cursor.close()
-
-       def executemany_rowcount(self, query, parameters):
+       def transaction(self):
                """
-                       Executes the given query against all the given param sequences.
-
-                       We return the rowcount from the query.
+                       Creates a new transaction on the current tasks' connection
                """
-               cursor = self._cursor()
-
-               try:
-                       cursor.executemany(query, parameters)
-                       return cursor.rowcount
-               finally:
-                       cursor.close()
-
-       def _ensure_connected(self):
-               if self._db is None:
-                       log.warning("Database connection was lost...")
-
-                       self.reconnect()
-
-       def _cursor(self):
-               self._ensure_connected()
-               return self._db.cursor()
-
-       def _execute(self, cursor, query, parameters, kwparameters):
-               log.debug(
-                               "Executing query: %s" % \
-                                               cursor.mogrify(query, kwparameters or parameters).decode(),
-               )
-
-               # Store the time when the query started
-               t = time.monotonic()
-
-               try:
-                       return cursor.execute(query, kwparameters or parameters)
-
-               # Catch any errors
-               except OperationalError:
-                       log.error("Error connecting to database on %s", self.host)
-                       self.close()
-                       raise
-
-               # Log how long the query took
-               finally:
-                       # Determine duration the query took
-                       d = time.monotonic() - t
-
-                       log.debug("Query took %.2fms" % (d * 1000.0))
-
-       def transaction(self):
-               return Transaction(self)
+               return self.connection.transaction()
 
 
 class Row(dict):
@@ -204,24 +89,3 @@ class Row(dict):
                        return self[name]
                except KeyError:
                        raise AttributeError(name)
-
-
-class Transaction(object):
-       def __init__(self, db):
-               self.db = db
-
-               self.db.execute("START TRANSACTION")
-
-       def __enter__(self):
-               return self
-
-       def __exit__(self, exctype, excvalue, traceback):
-               if exctype is not None:
-                       self.db.execute("ROLLBACK")
-               else:
-                       self.db.execute("COMMIT")
-
-
-# Alias some common exceptions
-IntegrityError = psycopg2.IntegrityError
-OperationalError = psycopg2.OperationalError