-#!/usr/bin/env python
+#!/usr/bin/python
"""
A lightweight wrapper around psycopg2.
as torndb.
"""
+import asyncio
+import itertools
import logging
-import psycopg2
+import psycopg
+import psycopg_pool
+import time
+
+from . import misc
+
+# Setup logging
+log = logging.getLogger("pbs.database")
class Connection(object):
"""
We explicitly set the timezone to UTC and the character encoding to
UTF-8 on all connections to avoid time zone and encoding errors.
"""
- def __init__(self, host, database, user=None, password=None):
- self.host = host
- self.database = database
-
- self._db = None
- self._db_args = {
- "host" : host,
- "database" : database,
- "user" : user,
- "password" : password,
- }
+ def __init__(self, backend, host, database, user=None, password=None):
+ self.backend = backend
- try:
- self.reconnect()
- except Exception:
- logging.error("Cannot connect to database on %s", self.host, exc_info=True)
+ # Stores connections assigned to tasks
+ self.__connections = {}
+
+ # Create a connection pool
+ self.pool = psycopg_pool.ConnectionPool(
+ "postgresql://%s:%s@%s/%s" % (user, password, host, database),
+
+ # Callback to configure any new connections
+ configure=self.__configure,
+
+ # Set limits for min/max connections in the pool
+ min_size=4,
+ max_size=128,
+
+ # Give clients up to one minute to retrieve a connection
+ timeout=60,
- def __del__(self):
- self.close()
+ # Close connections after they have been idle for one minute
+ max_idle=60,
+ )
- def close(self):
+ def __configure(self, conn):
"""
- Closes this database connection.
+ Configures any newly opened connections
"""
- if getattr(self, "_db", None) is not None:
- self._db.close()
- self._db = None
+ # Enable autocommit
+ conn.autocommit = True
- def reconnect(self):
+ # Return any rows as dicts
+ conn.row_factory = psycopg.rows.dict_row
+
+ # Automatically convert DataObjects
+ conn.adapters.register_dumper(misc.Object, misc.ObjectDumper)
+
+ def connection(self, *args, **kwargs):
"""
- Closes the existing database connection and re-opens it.
+ Returns a connection from the pool
"""
- self.close()
+ # Fetch the current task
+ task = asyncio.current_task()
+
+ assert task, "Could not determine task"
+
+ # Try returning the same connection to the same task
+ try:
+ return self.__connections[task]
+ except KeyError:
+ pass
+
+ # Fetch a new connection from the pool
+ conn = self.__connections[task] = self.pool.getconn(*args, **kwargs)
+
+ log.debug("Assigning database connection %s to %s" % (conn, task))
- self._db = psycopg2.connect(**self._db_args)
- self._db.autocommit = True
+ # When the task finishes, release the connection
+ task.add_done_callback(self.__release_connection)
- # Initialize the timezone setting.
- self.execute("SET TIMEZONE TO 'UTC'")
+ return conn
+
+ def __release_connection(self, task):
+ # Retrieve the connection
+ try:
+ conn = self.__connections[task]
+ except KeyError:
+ return
+
+ log.debug("Releasing database connection %s of %s" % (conn, task))
+
+ # Delete it
+ del self.__connections[task]
+
+ # Return the connection back into the pool
+ self.pool.putconn(conn)
+
+ def _execute(self, cursor, execute, query, parameters):
+ # Store the time we started this query
+ t = time.monotonic()
+
+ try:
+ log.debug("Running SQL query %s" % (query % parameters))
+ except Exception:
+ pass
+
+ # Execute the query
+ execute(query, parameters)
+
+ # How long did this take?
+ elapsed = time.monotonic() - t
+
+ # Log the query time
+ log.debug(" Query time: %.2fms" % (elapsed * 1000))
def query(self, query, *parameters, **kwparameters):
"""
Returns a row list for the given query and parameters.
"""
- cursor = self._cursor()
- try:
- self._execute(cursor, query, parameters, kwparameters)
- column_names = [d[0] for d in cursor.description]
- return [Row(zip(column_names, row)) for row in cursor]
- finally:
- cursor.close()
+ conn = self.connection()
+
+ with conn.cursor() as cursor:
+ self._execute(cursor, cursor.execute, query, parameters or kwparameters)
+
+ return [Row(row) for row in cursor]
def get(self, query, *parameters, **kwparameters):
"""
def execute(self, query, *parameters, **kwparameters):
"""
- Executes the given query, returning the lastrowid from the query.
+ Executes the given query.
"""
- return self.execute_lastrowid(query, *parameters, **kwparameters)
+ conn = self.connection()
- def execute_lastrowid(self, query, *parameters, **kwparameters):
- """
- Executes the given query, returning the lastrowid from the query.
- """
- cursor = self._cursor()
- try:
- self._execute(cursor, query, parameters, kwparameters)
- return cursor.lastrowid
- finally:
- cursor.close()
-
- def execute_rowcount(self, query, *parameters, **kwparameters):
- """
- Executes the given query, returning the rowcount from the query.
- """
- cursor = self._cursor()
- try:
- self._execute(cursor, query, parameters, kwparameters)
- return cursor.rowcount
- finally:
- cursor.close()
+ with conn.cursor() as cursor:
+ self._execute(cursor, cursor.execute, query, parameters or kwparameters)
def executemany(self, query, parameters):
"""
Executes the given query against all the given param sequences.
-
- We return the lastrowid from the query.
"""
- return self.executemany_lastrowid(query, parameters)
+ conn = self.connection()
- def executemany_lastrowid(self, query, parameters):
- """
- Executes the given query against all the given param sequences.
+ with conn.cursor() as cursor:
+ self._execute(cursor, cursor.executemany, query, parameters)
- We return the lastrowid from the query.
+ def transaction(self):
"""
- cursor = self._cursor()
- try:
- cursor.executemany(query, parameters)
- return cursor.lastrowid
- finally:
- cursor.close()
-
- def executemany_rowcount(self, query, parameters):
+ Creates a new transaction on the current tasks' connection
"""
- Executes the given query against all the given param sequences.
+ conn = self.connection()
- We return the rowcount from the query.
- """
- cursor = self._cursor()
+ return conn.transaction()
- try:
- cursor.executemany(query, parameters)
- return cursor.rowcount
- finally:
- cursor.close()
-
- def _ensure_connected(self):
- if self._db is None:
- logging.warning("Database connection was lost...")
-
- self.reconnect()
-
- def _cursor(self):
- self._ensure_connected()
- return self._db.cursor()
+ def fetch_one(self, cls, query, *args, **kwargs):
+ """
+ Takes a class and a query and will return one object of that class
+ """
+ # Execute the query
+ res = self.get(query, *args)
- def _execute(self, cursor, query, parameters, kwparameters):
- logging.debug("SQL Query: %s" % (query % (kwparameters or parameters)))
+ # Return an object (if possible)
+ if res:
+ return cls(self.backend, res.id, res, **kwargs)
- try:
- return cursor.execute(query, kwparameters or parameters)
- except (OperationalError, psycopg2.ProgrammingError):
- logging.error("Error connecting to database on %s", self.host)
- self.close()
- raise
+ def fetch_many(self, cls, query, *args, **kwargs):
+ # Execute the query
+ res = self.query(query, *args)
- def transaction(self):
- return Transaction(self)
+ # Return a generator with objects
+ for row in res:
+ yield cls(self.backend, row.id, row, **kwargs)
class Row(dict):
return self[name]
except KeyError:
raise AttributeError(name)
-
-
-class Transaction(object):
- def __init__(self, db):
- self.db = db
-
- self.db.execute("START TRANSACTION")
-
- def __enter__(self):
- return self
-
- def __exit__(self, exctype, excvalue, traceback):
- if exctype is not None:
- self.db.execute("ROLLBACK")
- else:
- self.db.execute("COMMIT")
-
-
-# Alias some common exceptions
-IntegrityError = psycopg2.IntegrityError
-OperationalError = psycopg2.OperationalError