]>
| Commit | Line | Data |
|---|---|---|
| 1 | #!/usr/bin/python | |
| 2 | ||
"""
A lightweight wrapper around psycopg (version 3) and its connection pool.

Derived from torndb, which was originally part of the Tornado framework as
the tornado.database module before being split out into its own package.
"""
| 10 | ||
| 11 | import asyncio | |
| 12 | import itertools | |
| 13 | import logging | |
| 14 | import psycopg | |
| 15 | import psycopg_pool | |
| 16 | import time | |
| 17 | ||
| 18 | from . import misc | |
| 19 | ||
# Module-level logger. It is named after this module rather than being the
# root logger so that the per-query SQL debug output produced below can be
# enabled or silenced independently of the application's own logging.
log = logging.getLogger(__name__)
| 22 | ||
class Connection(object):
    """
    A lightweight wrapper around a psycopg (v3) connection pool.

    The main value we provide is returning rows as Row objects (dicts whose
    columns can also be accessed as attributes). Typical usage::

        db = Connection(backend, "localhost", "mydatabase")
        for article in db.query("SELECT * FROM articles"):
            print(article.title)

    Each asyncio task is bound to a single pooled connection for as long as
    it runs, so consecutive calls from the same task (including those made
    inside transaction()) share one connection. The connection is handed back
    to the pool when the task finishes.

    Connections run in autocommit mode; use transaction() to group statements.
    Cursors are hidden by the implementation, but other than that, the methods
    are very similar to the DB-API.
    """
    def __init__(self, backend, host, database, user=None, password=None):
        """
        Create the connection pool.

        :param backend: opaque object passed as the first argument to every
            object built by fetch_one()/fetch_many().
        :param host: server host; inserted into the connection URL unchanged
            (so a "host:port" value is passed through as well).
        :param database: name of the database to connect to.
        :param user: role to connect as; left out of the URL when None.
        :param password: password; left out of the URL when None.
        """
        self.backend = backend

        # Maps each asyncio task to the pooled connection it has been given
        self.__connections = {}

        # Create a connection pool
        self.pool = psycopg_pool.ConnectionPool(
            self._make_url(host, database, user=user, password=password),

            # Callback to configure any new connections
            configure=self.__configure,

            # Set limits for min/max connections in the pool
            min_size=8,
            max_size=512,

            # Give clients up to one minute to retrieve a connection
            timeout=60,

            # Close connections after they have been idle for a few seconds
            max_idle=5,
        )

    @staticmethod
    def _make_url(host, database, user=None, password=None):
        """
        Build a "postgresql://" connection URL.

        Credentials are percent-encoded so that characters such as "@", ":"
        or "/" in a password cannot corrupt the URL. A missing user or
        password is omitted instead of being rendered as the string "None".
        The host and database are inserted unchanged.
        """
        from urllib.parse import quote

        userinfo = ""
        if user is not None:
            userinfo = quote(str(user), safe="")
        if password is not None:
            userinfo += ":" + quote(str(password), safe="")
        if userinfo:
            userinfo += "@"

        return "postgresql://%s%s/%s" % (userinfo, host, database)

    def __configure(self, conn):
        """
        Configure a newly opened connection (called by the pool).
        """
        # Enable autocommit
        conn.autocommit = True

        # Return any rows as dicts
        conn.row_factory = psycopg.rows.dict_row

        # Automatically convert DataObjects
        conn.adapters.register_dumper(misc.Object, misc.ObjectDumper)

    def connection(self, *args, **kwargs):
        """
        Return the connection bound to the current asyncio task.

        The first call from a task checks a connection out of the pool
        (forwarding *args/**kwargs to pool.getconn()) and registers a
        done-callback that returns it once the task finishes. Later calls
        from the same task get the same connection back.

        :raises RuntimeError: when not called from inside an asyncio task.
        """
        # Fetch the current task
        task = asyncio.current_task()

        # Checked explicitly rather than with assert so the check survives
        # "python -O"; otherwise a connection would be checked out under the
        # key None and never returned to the pool.
        if task is None:
            raise RuntimeError("Could not determine task")

        # Try returning the same connection to the same task
        try:
            return self.__connections[task]
        except KeyError:
            pass

        # Fetch a new connection from the pool
        conn = self.__connections[task] = self.pool.getconn(*args, **kwargs)

        log.debug("Assigning database connection %s to %s", conn, task)

        # When the task finishes, release the connection
        task.add_done_callback(self.__release_connection)

        return conn

    def __release_connection(self, task):
        """
        Done-callback: return the connection held by *task* to the pool.
        """
        conn = self.__connections.pop(task, None)

        # Nothing to do if this task holds no connection
        if conn is None:
            return

        log.debug("Releasing database connection %s of %s", conn, task)

        # Return the connection back into the pool
        self.pool.putconn(conn)

    def _execute(self, cursor, execute, query, parameters):
        """
        Run execute(query, parameters), logging the query and its duration.

        :param cursor: the cursor that *execute* belongs to (not used here).
        :param execute: bound cursor.execute or cursor.executemany.
        :param query: SQL text with placeholders.
        :param parameters: parameters for the placeholders.
        """
        # Avoid building log strings for every query unless they will be used
        debug = log.isEnabledFor(logging.DEBUG)

        if debug:
            # Best-effort preview only: plain %-interpolation is not how psycopg
            # binds the parameters, and it can fail (e.g. for executemany).
            try:
                log.debug("Running SQL query %s", query % parameters)
            except Exception:
                log.debug("Running SQL query %s with parameters %r",
                    query, parameters)

        # Store the time we started this query
        t = time.monotonic()

        # Execute the query
        execute(query, parameters)

        # How long did this take?
        elapsed = time.monotonic() - t

        # Log the query time
        if debug:
            log.debug(" Query time: %.2fms", elapsed * 1000)

    def query(self, query, *parameters, **kwparameters):
        """
        Return a list of Row objects for the given query and parameters.

        Positional parameters take precedence; keyword parameters are only
        used when no positional ones are given.
        """
        conn = self.connection()

        with conn.cursor() as cursor:
            self._execute(cursor, cursor.execute, query, parameters or kwparameters)

            return [Row(row) for row in cursor]

    def get(self, query, *parameters, **kwparameters):
        """
        Return the single row produced by the given query, or None.

        :raises Exception: when the query returns more than one row.
        """
        rows = self.query(query, *parameters, **kwparameters)
        if not rows:
            return None
        elif len(rows) > 1:
            raise Exception("Multiple rows returned for Database.get() query")
        else:
            return rows[0]

    def execute(self, query, *parameters, **kwparameters):
        """
        Execute the given query, discarding any result.
        """
        conn = self.connection()

        with conn.cursor() as cursor:
            self._execute(cursor, cursor.execute, query, parameters or kwparameters)

    def executemany(self, query, parameters):
        """
        Execute the given query once for each parameter set in *parameters*.
        """
        conn = self.connection()

        with conn.cursor() as cursor:
            self._execute(cursor, cursor.executemany, query, parameters)

    def transaction(self):
        """
        Return conn.transaction() for the current task's connection.

        Because the connection is bound to the task, queries issued through
        this object inside the returned context run in that transaction.
        """
        conn = self.connection()

        return conn.transaction()

    def fetch_one(self, cls, query, *args, **kwargs):
        """
        Run *query* and wrap its single result row in *cls*.

        :param cls: called as cls(self.backend, row.id, row, **kwargs).
        :param args: query parameters. **kwargs are forwarded to *cls*, not
            to the query.
        :returns: the new object, or None when the query returned no row.
        :raises Exception: when the query returned more than one row.
        """
        # Execute the query
        res = self.get(query, *args)

        # Return an object (if possible)
        if not res:
            return None

        return cls(self.backend, res.id, res, **kwargs)

    def fetch_many(self, cls, query, *args, **kwargs):
        """
        Run *query* and return a generator wrapping each row in *cls*.

        The query runs immediately, on the calling task's connection (and
        within any transaction currently open there), rather than being
        deferred until the result is first iterated. Only the construction
        of the objects is lazy.

        :param cls: called as cls(self.backend, row.id, row, **kwargs).
        :param args: query parameters. **kwargs are forwarded to *cls*.
        """
        # Execute the query now
        rows = self.query(query, *args)

        # Return a generator with objects
        return (cls(self.backend, row.id, row, **kwargs) for row in rows)
| 202 | ||
| 203 | ||
class Row(dict):
    """
    Dictionary whose keys can also be read as attributes, e.g. ``row.name``.
    """
    def __getattr__(self, name):
        # Python only calls this after regular attribute lookup has failed,
        # so real dict methods (keys, items, ...) always take precedence.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value