]> git.ipfire.org Git - people/ms/libloc.git/commitdiff
python: Add database driver for PostgreSQL
authorMichael Tremer <michael.tremer@ipfire.org>
Tue, 12 May 2020 11:42:54 +0000 (11:42 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Tue, 12 May 2020 11:42:54 +0000 (11:42 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
src/python/database.py [new file with mode: 0644]
src/python/location-importer.in
src/python/logger.py

index efb8288a455a5a5f6e4ebea4f3a4bf49270cf9ec..af70a4ea732a5953b9a271eabfd8c9d21c4ae309 100644 (file)
@@ -154,6 +154,7 @@ CLEANFILES += \
 
 dist_pkgpython_PYTHON = \
        src/python/__init__.py \
+       src/python/database.py \
        src/python/i18n.py \
        src/python/importer.py \
        src/python/logger.py
diff --git a/src/python/database.py b/src/python/database.py
new file mode 100644 (file)
index 0000000..4aec8cf
--- /dev/null
@@ -0,0 +1,212 @@
+#!/usr/bin/env python
+
+"""
+       A lightweight wrapper around psycopg2.
+
+       Originally part of the Tornado framework.  The tornado.database module
+       is slated for removal in Tornado 3.0, and it is now available separately
+       as torndb.
+"""
+
+import logging
+import psycopg2
+
+log = logging.getLogger("location.database")
+log.propagate = 1
+
+class Connection(object):
+       """
+               A lightweight wrapper around MySQLdb DB-API connections.
+
+               The main value we provide is wrapping rows in a dict/object so that
+               columns can be accessed by name. Typical usage::
+
+                       db = torndb.Connection("localhost", "mydatabase")
+                       for article in db.query("SELECT * FROM articles"):
+                               print article.title
+
+               Cursors are hidden by the implementation, but other than that, the methods
+               are very similar to the DB-API.
+
+               We explicitly set the timezone to UTC and the character encoding to
+               UTF-8 on all connections to avoid time zone and encoding errors.
+       """
+       def __init__(self, host, database, user=None, password=None):
+               self.host = host
+               self.database = database
+
+               self._db = None
+               self._db_args = {
+                       "host"     : host,
+                       "database" : database,
+                       "user"     : user,
+                       "password" : password,
+               }
+
+               try:
+                       self.reconnect()
+               except Exception:
+                       log.error("Cannot connect to database on %s", self.host, exc_info=True)
+
+       def __del__(self):
+               self.close()
+
+       def close(self):
+               """
+                       Closes this database connection.
+               """
+               if getattr(self, "_db", None) is not None:
+                       self._db.close()
+                       self._db = None
+
+       def reconnect(self):
+               """
+                       Closes the existing database connection and re-opens it.
+               """
+               self.close()
+
+               self._db = psycopg2.connect(**self._db_args)
+               self._db.autocommit = True
+
+               # Initialize the timezone setting.
+               self.execute("SET TIMEZONE TO 'UTC'")
+
+       def query(self, query, *parameters, **kwparameters):
+               """
+                       Returns a row list for the given query and parameters.
+               """
+               cursor = self._cursor()
+               try:
+                       self._execute(cursor, query, parameters, kwparameters)
+                       column_names = [d[0] for d in cursor.description]
+                       return [Row(zip(column_names, row)) for row in cursor]
+               finally:
+                       cursor.close()
+
+       def get(self, query, *parameters, **kwparameters):
+               """
+                       Returns the first row returned for the given query.
+               """
+               rows = self.query(query, *parameters, **kwparameters)
+               if not rows:
+                       return None
+               elif len(rows) > 1:
+                       raise Exception("Multiple rows returned for Database.get() query")
+               else:
+                       return rows[0]
+
+       def execute(self, query, *parameters, **kwparameters):
+               """
+                       Executes the given query, returning the lastrowid from the query.
+               """
+               return self.execute_lastrowid(query, *parameters, **kwparameters)
+
+       def execute_lastrowid(self, query, *parameters, **kwparameters):
+               """
+                       Executes the given query, returning the lastrowid from the query.
+               """
+               cursor = self._cursor()
+               try:
+                       self._execute(cursor, query, parameters, kwparameters)
+                       return cursor.lastrowid
+               finally:
+                       cursor.close()
+
+       def execute_rowcount(self, query, *parameters, **kwparameters):
+               """
+                       Executes the given query, returning the rowcount from the query.
+               """
+               cursor = self._cursor()
+               try:
+                       self._execute(cursor, query, parameters, kwparameters)
+                       return cursor.rowcount
+               finally:
+                       cursor.close()
+
+       def executemany(self, query, parameters):
+               """
+                       Executes the given query against all the given param sequences.
+
+                       We return the lastrowid from the query.
+               """
+               return self.executemany_lastrowid(query, parameters)
+
+       def executemany_lastrowid(self, query, parameters):
+               """
+                       Executes the given query against all the given param sequences.
+
+                       We return the lastrowid from the query.
+               """
+               cursor = self._cursor()
+               try:
+                       cursor.executemany(query, parameters)
+                       return cursor.lastrowid
+               finally:
+                       cursor.close()
+
+       def executemany_rowcount(self, query, parameters):
+               """
+                       Executes the given query against all the given param sequences.
+
+                       We return the rowcount from the query.
+               """
+               cursor = self._cursor()
+
+               try:
+                       cursor.executemany(query, parameters)
+                       return cursor.rowcount
+               finally:
+                       cursor.close()
+
+       def _ensure_connected(self):
+               if self._db is None:
+                       log.warning("Database connection was lost...")
+
+                       self.reconnect()
+
+       def _cursor(self):
+               self._ensure_connected()
+               return self._db.cursor()
+
+       def _execute(self, cursor, query, parameters, kwparameters):
+               log.debug("SQL Query: %s" % (query % (kwparameters or parameters)))
+
+               try:
+                       return cursor.execute(query, kwparameters or parameters)
+               except (OperationalError, psycopg2.ProgrammingError):
+                       log.error("Error connecting to database on %s", self.host)
+                       self.close()
+                       raise
+
+       def transaction(self):
+               return Transaction(self)
+
+
+class Row(dict):
+       """A dict that allows for object-like property access syntax."""
+       def __getattr__(self, name):
+               try:
+                       return self[name]
+               except KeyError:
+                       raise AttributeError(name)
+
+
+class Transaction(object):
+       def __init__(self, db):
+               self.db = db
+
+               self.db.execute("START TRANSACTION")
+
+       def __enter__(self):
+               return self
+
+       def __exit__(self, exctype, excvalue, traceback):
+               if exctype is not None:
+                       self.db.execute("ROLLBACK")
+               else:
+                       self.db.execute("COMMIT")
+
+
+# Alias some common exceptions
+IntegrityError = psycopg2.IntegrityError
+OperationalError = psycopg2.OperationalError
index 1ced867db8cb6f7b3acaae4c1944be683af15d13..564ada9700cc7b9222cc96b19e6f5009dd24bd45 100644 (file)
@@ -23,6 +23,7 @@ import sys
 
 # Load our location module
 import location
+import location.database
 import location.importer
 from location.i18n import _
 
@@ -44,6 +45,16 @@ class CLI(object):
                parser.add_argument("--version", action="version",
                        version="%(prog)s @VERSION@")
 
+               # Database
+               parser.add_argument("--database-host", required=True,
+                       help=_("Database Hostname"), metavar=_("HOST"))
+               parser.add_argument("--database-name", required=True,
+                       help=_("Database Name"), metavar=_("NAME"))
+               parser.add_argument("--database-username", required=True,
+                       help=_("Database Username"), metavar=_("USERNAME"))
+               parser.add_argument("--database-password", required=True,
+                       help=_("Database Password"), metavar=_("PASSWORD"))
+
                args = parser.parse_args()
 
                # Enable debug logging
@@ -56,8 +67,11 @@ class CLI(object):
                # Parse command line arguments
                args = self.parse_cli()
 
+               # Initialise database
+               db = self._setup_database(args)
+
                # Call function
-               ret = self.handle_import(args)
+               ret = self.handle_import(db, args)
 
                # Return with exit code
                if ret:
@@ -66,7 +80,25 @@ class CLI(object):
                # Otherwise just exit
                sys.exit(0)
 
-       def handle_import(self, ns):
+       def _setup_database(self, ns):
+               """
+                       Initialise the database
+               """
+               # Connect to database
+               db = location.database.Connection(
+                       host=ns.database_host, database=ns.database_name,
+                       user=ns.database_username, password=ns.database_password,
+               )
+
+               with db.transaction():
+                       db.execute("""
+                               CREATE TABLE IF NOT EXISTS asnums(number integer, name text);
+                               CREATE UNIQUE INDEX IF NOT EXISTS asnums_number ON asnums(number);
+                       """)
+
+               return db
+
+       def handle_import(self, db, ns):
                pass
 
 
index a8f3b59a607baa434e3f746df2389df25796be0d..18d8123a2850d9308ba745fb736fa8f38f6c76f2 100644 (file)
@@ -22,7 +22,7 @@ import logging.handlers
 
 # Initialise root logger
 log = logging.getLogger("location")
-log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
 
 # Log to console
 handler = logging.StreamHandler()