--- /dev/null
+#!/usr/bin/env python
+
+"""
+ A lightweight wrapper around psycopg2.
+
+ Originally part of the Tornado framework. The tornado.database module
+ is slated for removal in Tornado 3.0, and it is now available separately
+ as torndb.
+"""
+
+import logging
+import psycopg2
+
+log = logging.getLogger("location.database")
+log.propagate = 1
+
+class Connection(object):
+ """
+ A lightweight wrapper around MySQLdb DB-API connections.
+
+ The main value we provide is wrapping rows in a dict/object so that
+ columns can be accessed by name. Typical usage::
+
+ db = torndb.Connection("localhost", "mydatabase")
+ for article in db.query("SELECT * FROM articles"):
+ print article.title
+
+ Cursors are hidden by the implementation, but other than that, the methods
+ are very similar to the DB-API.
+
+ We explicitly set the timezone to UTC and the character encoding to
+ UTF-8 on all connections to avoid time zone and encoding errors.
+ """
+ def __init__(self, host, database, user=None, password=None):
+ self.host = host
+ self.database = database
+
+ self._db = None
+ self._db_args = {
+ "host" : host,
+ "database" : database,
+ "user" : user,
+ "password" : password,
+ }
+
+ try:
+ self.reconnect()
+ except Exception:
+ log.error("Cannot connect to database on %s", self.host, exc_info=True)
+
+ def __del__(self):
+ self.close()
+
+ def close(self):
+ """
+ Closes this database connection.
+ """
+ if getattr(self, "_db", None) is not None:
+ self._db.close()
+ self._db = None
+
+ def reconnect(self):
+ """
+ Closes the existing database connection and re-opens it.
+ """
+ self.close()
+
+ self._db = psycopg2.connect(**self._db_args)
+ self._db.autocommit = True
+
+ # Initialize the timezone setting.
+ self.execute("SET TIMEZONE TO 'UTC'")
+
+ def query(self, query, *parameters, **kwparameters):
+ """
+ Returns a row list for the given query and parameters.
+ """
+ cursor = self._cursor()
+ try:
+ self._execute(cursor, query, parameters, kwparameters)
+ column_names = [d[0] for d in cursor.description]
+ return [Row(zip(column_names, row)) for row in cursor]
+ finally:
+ cursor.close()
+
+ def get(self, query, *parameters, **kwparameters):
+ """
+ Returns the first row returned for the given query.
+ """
+ rows = self.query(query, *parameters, **kwparameters)
+ if not rows:
+ return None
+ elif len(rows) > 1:
+ raise Exception("Multiple rows returned for Database.get() query")
+ else:
+ return rows[0]
+
+ def execute(self, query, *parameters, **kwparameters):
+ """
+ Executes the given query, returning the lastrowid from the query.
+ """
+ return self.execute_lastrowid(query, *parameters, **kwparameters)
+
+ def execute_lastrowid(self, query, *parameters, **kwparameters):
+ """
+ Executes the given query, returning the lastrowid from the query.
+ """
+ cursor = self._cursor()
+ try:
+ self._execute(cursor, query, parameters, kwparameters)
+ return cursor.lastrowid
+ finally:
+ cursor.close()
+
+ def execute_rowcount(self, query, *parameters, **kwparameters):
+ """
+ Executes the given query, returning the rowcount from the query.
+ """
+ cursor = self._cursor()
+ try:
+ self._execute(cursor, query, parameters, kwparameters)
+ return cursor.rowcount
+ finally:
+ cursor.close()
+
+ def executemany(self, query, parameters):
+ """
+ Executes the given query against all the given param sequences.
+
+ We return the lastrowid from the query.
+ """
+ return self.executemany_lastrowid(query, parameters)
+
+ def executemany_lastrowid(self, query, parameters):
+ """
+ Executes the given query against all the given param sequences.
+
+ We return the lastrowid from the query.
+ """
+ cursor = self._cursor()
+ try:
+ cursor.executemany(query, parameters)
+ return cursor.lastrowid
+ finally:
+ cursor.close()
+
+ def executemany_rowcount(self, query, parameters):
+ """
+ Executes the given query against all the given param sequences.
+
+ We return the rowcount from the query.
+ """
+ cursor = self._cursor()
+
+ try:
+ cursor.executemany(query, parameters)
+ return cursor.rowcount
+ finally:
+ cursor.close()
+
+ def _ensure_connected(self):
+ if self._db is None:
+ log.warning("Database connection was lost...")
+
+ self.reconnect()
+
+ def _cursor(self):
+ self._ensure_connected()
+ return self._db.cursor()
+
+ def _execute(self, cursor, query, parameters, kwparameters):
+ log.debug("SQL Query: %s" % (query % (kwparameters or parameters)))
+
+ try:
+ return cursor.execute(query, kwparameters or parameters)
+ except (OperationalError, psycopg2.ProgrammingError):
+ log.error("Error connecting to database on %s", self.host)
+ self.close()
+ raise
+
+ def transaction(self):
+ return Transaction(self)
+
+
+class Row(dict):
+ """A dict that allows for object-like property access syntax."""
+ def __getattr__(self, name):
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+
+class Transaction(object):
+ def __init__(self, db):
+ self.db = db
+
+ self.db.execute("START TRANSACTION")
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exctype, excvalue, traceback):
+ if exctype is not None:
+ self.db.execute("ROLLBACK")
+ else:
+ self.db.execute("COMMIT")
+
+
+# Alias some common exceptions
+IntegrityError = psycopg2.IntegrityError
+OperationalError = psycopg2.OperationalError
# Load our location module
import location
+import location.database
import location.importer
from location.i18n import _
parser.add_argument("--version", action="version",
version="%(prog)s @VERSION@")
+ # Database
+ parser.add_argument("--database-host", required=True,
+ help=_("Database Hostname"), metavar=_("HOST"))
+ parser.add_argument("--database-name", required=True,
+ help=_("Database Name"), metavar=_("NAME"))
+ parser.add_argument("--database-username", required=True,
+ help=_("Database Username"), metavar=_("USERNAME"))
+ parser.add_argument("--database-password", required=True,
+ help=_("Database Password"), metavar=_("PASSWORD"))
+
args = parser.parse_args()
# Enable debug logging
# Parse command line arguments
args = self.parse_cli()
+ # Initialise database
+ db = self._setup_database(args)
+
# Call function
- ret = self.handle_import(args)
+ ret = self.handle_import(db, args)
# Return with exit code
if ret:
# Otherwise just exit
sys.exit(0)
- def handle_import(self, ns):
+ def _setup_database(self, ns):
+ """
+ Initialise the database
+ """
+ # Connect to database
+ db = location.database.Connection(
+ host=ns.database_host, database=ns.database_name,
+ user=ns.database_username, password=ns.database_password,
+ )
+
+ with db.transaction():
+ db.execute("""
+ CREATE TABLE IF NOT EXISTS asnums(number integer, name text);
+ CREATE UNIQUE INDEX IF NOT EXISTS asnums_number ON asnums(number);
+ """)
+
+ return db
+
+ def handle_import(self, db, ns):
pass