+++ /dev/null
-#!/usr/bin/env python
-#
-# Copyright 2009 Facebook
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""A lightweight wrapper around MySQLdb."""
-
-from __future__ import absolute_import, division, with_statement
-
-import copy
-import itertools
-import time
-
-from tornado.log import gen_log
-
-try:
- import MySQLdb.constants
- import MySQLdb.converters
- import MySQLdb.cursors
-except ImportError:
- # If MySQLdb isn't available this module won't actually be useable,
- # but we want it to at least be importable (mainly for readthedocs.org,
- # which has limitations on third-party modules)
- MySQLdb = None
-
-
-class Connection(object):
- """A lightweight wrapper around MySQLdb DB-API connections.
-
- The main value we provide is wrapping rows in a dict/object so that
- columns can be accessed by name. Typical usage::
-
- db = database.Connection("localhost", "mydatabase")
- for article in db.query("SELECT * FROM articles"):
- print article.title
-
- Cursors are hidden by the implementation, but other than that, the methods
- are very similar to the DB-API.
-
- We explicitly set the timezone to UTC and the character encoding to
- UTF-8 on all connections to avoid time zone and encoding errors.
- """
- def __init__(self, host, database, user=None, password=None,
- max_idle_time=7 * 3600):
- self.host = host
- self.database = database
- self.max_idle_time = max_idle_time
-
- args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8",
- db=database, init_command='SET time_zone = "+0:00"',
- sql_mode="TRADITIONAL")
- if user is not None:
- args["user"] = user
- if password is not None:
- args["passwd"] = password
-
- # We accept a path to a MySQL socket file or a host(:port) string
- if "/" in host:
- args["unix_socket"] = host
- else:
- self.socket = None
- pair = host.split(":")
- if len(pair) == 2:
- args["host"] = pair[0]
- args["port"] = int(pair[1])
- else:
- args["host"] = host
- args["port"] = 3306
-
- self._db = None
- self._db_args = args
- self._last_use_time = time.time()
- try:
- self.reconnect()
- except Exception:
- gen_log.error("Cannot connect to MySQL on %s", self.host,
- exc_info=True)
-
- def __del__(self):
- self.close()
-
- def close(self):
- """Closes this database connection."""
- if getattr(self, "_db", None) is not None:
- self._db.close()
- self._db = None
-
- def reconnect(self):
- """Closes the existing database connection and re-opens it."""
- self.close()
- self._db = MySQLdb.connect(**self._db_args)
- self._db.autocommit(True)
-
- def iter(self, query, *parameters):
- """Returns an iterator for the given query and parameters."""
- self._ensure_connected()
- cursor = MySQLdb.cursors.SSCursor(self._db)
- try:
- self._execute(cursor, query, parameters)
- column_names = [d[0] for d in cursor.description]
- for row in cursor:
- yield Row(zip(column_names, row))
- finally:
- cursor.close()
-
- def query(self, query, *parameters):
- """Returns a row list for the given query and parameters."""
- cursor = self._cursor()
- try:
- self._execute(cursor, query, parameters)
- column_names = [d[0] for d in cursor.description]
- return [Row(itertools.izip(column_names, row)) for row in cursor]
- finally:
- cursor.close()
-
- def get(self, query, *parameters):
- """Returns the first row returned for the given query."""
- rows = self.query(query, *parameters)
- if not rows:
- return None
- elif len(rows) > 1:
- raise Exception("Multiple rows returned for Database.get() query")
- else:
- return rows[0]
-
- # rowcount is a more reasonable default return value than lastrowid,
- # but for historical compatibility execute() must return lastrowid.
- def execute(self, query, *parameters):
- """Executes the given query, returning the lastrowid from the query."""
- return self.execute_lastrowid(query, *parameters)
-
- def execute_lastrowid(self, query, *parameters):
- """Executes the given query, returning the lastrowid from the query."""
- cursor = self._cursor()
- try:
- self._execute(cursor, query, parameters)
- return cursor.lastrowid
- finally:
- cursor.close()
-
- def execute_rowcount(self, query, *parameters):
- """Executes the given query, returning the rowcount from the query."""
- cursor = self._cursor()
- try:
- self._execute(cursor, query, parameters)
- return cursor.rowcount
- finally:
- cursor.close()
-
- def executemany(self, query, parameters):
- """Executes the given query against all the given param sequences.
-
- We return the lastrowid from the query.
- """
- return self.executemany_lastrowid(query, parameters)
-
- def executemany_lastrowid(self, query, parameters):
- """Executes the given query against all the given param sequences.
-
- We return the lastrowid from the query.
- """
- cursor = self._cursor()
- try:
- cursor.executemany(query, parameters)
- return cursor.lastrowid
- finally:
- cursor.close()
-
- def executemany_rowcount(self, query, parameters):
- """Executes the given query against all the given param sequences.
-
- We return the rowcount from the query.
- """
- cursor = self._cursor()
- try:
- cursor.executemany(query, parameters)
- return cursor.rowcount
- finally:
- cursor.close()
-
- def _ensure_connected(self):
- # Mysql by default closes client connections that are idle for
- # 8 hours, but the client library does not report this fact until
- # you try to perform a query and it fails. Protect against this
- # case by preemptively closing and reopening the connection
- # if it has been idle for too long (7 hours by default).
- if (self._db is None or
- (time.time() - self._last_use_time > self.max_idle_time)):
- self.reconnect()
- self._last_use_time = time.time()
-
- def _cursor(self):
- self._ensure_connected()
- return self._db.cursor()
-
- def _execute(self, cursor, query, parameters):
- try:
- return cursor.execute(query, parameters)
- except OperationalError:
- gen_log.error("Error connecting to MySQL on %s", self.host)
- self.close()
- raise
-
-
-class Row(dict):
- """A dict that allows for object-like property access syntax."""
- def __getattr__(self, name):
- try:
- return self[name]
- except KeyError:
- raise AttributeError(name)
-
-if MySQLdb is not None:
- # Fix the access conversions to properly recognize unicode/binary
- FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
- FLAG = MySQLdb.constants.FLAG
- CONVERSIONS = copy.copy(MySQLdb.converters.conversions)
-
- field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
- if 'VARCHAR' in vars(FIELD_TYPE):
- field_types.append(FIELD_TYPE.VARCHAR)
-
- for field_type in field_types:
- CONVERSIONS[field_type] = [(FLAG.BINARY, str)] + CONVERSIONS[field_type]
-
- # Alias some common MySQL exceptions
- IntegrityError = MySQLdb.IntegrityError
- OperationalError = MySQLdb.OperationalError