]> git.ipfire.org Git - location/libloc.git/blame - src/python/database.py
database: Always require SSL
[location/libloc.git] / src / python / database.py
CommitLineData
29c6fa22
MT
1#!/usr/bin/env python
2
3"""
    A lightweight wrapper around psycopg2.

    Adapted from torndb, a MySQLdb wrapper that was originally part of the
    Tornado framework (tornado.database). That module was slated for removal
    in Tornado 3.0 and is now available separately as torndb.
9"""
10
11import logging
12import psycopg2
13
# Module-wide logger; records are also handed on to ancestor loggers.
log = logging.getLogger("location.database")
log.propagate = True
16
class Connection(object):
    """
    A lightweight wrapper around psycopg2 DB-API connections.

    The main value we provide is wrapping rows in a dict/object so that
    columns can be accessed by name. Typical usage::

        db = Connection("localhost", "mydatabase")
        for article in db.query("SELECT * FROM articles"):
            print(article.title)

    Cursors are hidden by the implementation, but other than that, the methods
    are very similar to the DB-API.

    Connections always require SSL, run in autocommit mode (use
    transaction() for explicit transactions) and set the session timezone
    to UTC to avoid time zone errors.
    """
    def __init__(self, host, database, user=None, password=None):
        self.host = host
        self.database = database

        self._db = None
        self._db_args = {
            "host" : host,
            "database" : database,
            "user" : user,
            "password" : password,
            # Never fall back to an unencrypted connection
            "sslmode" : "require",
        }

        # Connecting is best-effort here: a failure is only logged, and the
        # next query retries the connection through _ensure_connected().
        try:
            self.reconnect()
        except Exception:
            log.error("Cannot connect to database on %s", self.host, exc_info=True)

    def __del__(self):
        self.close()

    def close(self):
        """
        Closes this database connection.

        Safe to call repeatedly, and also on a partially initialised object
        (which is why getattr() is used).
        """
        if getattr(self, "_db", None) is not None:
            self._db.close()
            self._db = None

    def reconnect(self):
        """
        Closes the existing database connection and re-opens it.

        The new connection is put into autocommit mode and its session
        timezone is set to UTC.
        """
        self.close()

        self._db = psycopg2.connect(**self._db_args)
        self._db.autocommit = True

        # Initialize the timezone setting.
        self.execute("SET TIMEZONE TO 'UTC'")

    def query(self, query, *parameters, **kwparameters):
        """
        Returns a list of Row objects for the given query and parameters.

        Positional parameters fill %s placeholders; keyword parameters fill
        %(name)s placeholders and take precedence when both are given.
        """
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters, kwparameters)
            column_names = [d[0] for d in cursor.description]
            return [Row(zip(column_names, row)) for row in cursor]
        finally:
            cursor.close()

    def get(self, query, *parameters, **kwparameters):
        """
        Returns the single row returned for the given query.

        Returns None when the query yields no rows and raises Exception when
        it yields more than one.
        """
        rows = self.query(query, *parameters, **kwparameters)
        if not rows:
            return None
        elif len(rows) > 1:
            raise Exception("Multiple rows returned for Database.get() query")
        else:
            return rows[0]

    def execute(self, query, *parameters, **kwparameters):
        """
        Executes the given query, returning the lastrowid from the query.

        NOTE: this is psycopg2's cursor.lastrowid, which is an OID and is
        usually None for PostgreSQL tables.
        """
        return self.execute_lastrowid(query, *parameters, **kwparameters)

    def execute_lastrowid(self, query, *parameters, **kwparameters):
        """
        Executes the given query, returning the cursor's lastrowid.
        """
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters, kwparameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def execute_rowcount(self, query, *parameters, **kwparameters):
        """
        Executes the given query, returning the rowcount from the query.
        """
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters, kwparameters)
            return cursor.rowcount
        finally:
            cursor.close()

    def executemany(self, query, parameters):
        """
        Executes the given query against all the given param sequences.

        We return the lastrowid from the query.
        """
        return self.executemany_lastrowid(query, parameters)

    def executemany_lastrowid(self, query, parameters):
        """
        Executes the given query against all the given param sequences.

        We return the cursor's lastrowid afterwards.
        """
        cursor = self._cursor()
        try:
            cursor.executemany(query, parameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def executemany_rowcount(self, query, parameters):
        """
        Executes the given query against all the given param sequences.

        We return the rowcount from the query.
        """
        cursor = self._cursor()

        try:
            cursor.executemany(query, parameters)
            return cursor.rowcount
        finally:
            cursor.close()

    def _ensure_connected(self):
        # The connection is dropped by close(), e.g. after a failed query
        # or a failed initial connect; transparently open a new one.
        if self._db is None:
            log.warning("Database connection was lost...")

            self.reconnect()

    def _cursor(self):
        """Returns a new cursor, reconnecting first if necessary."""
        self._ensure_connected()
        return self._db.cursor()

    def _execute(self, cursor, query, parameters, kwparameters):
        """
        Runs query on cursor, binding kwparameters if given, else parameters.

        On OperationalError or ProgrammingError the connection is closed, so
        the next call reconnects, and the exception is re-raised.
        """
        args = kwparameters or parameters

        # Log lazily. Interpolating the query eagerly cost time even with
        # debug logging off. It also raised, aborting the actual query,
        # whenever the query text could not be %-formatted with its arguments
        # (e.g. a literal "%" in a query without parameters).
        log.debug("SQL Query: %s (parameters: %r)", query, args)

        try:
            return cursor.execute(query, args)
        except (OperationalError, psycopg2.ProgrammingError):
            log.error("Error executing query on database on %s", self.host)
            self.close()
            raise

    def transaction(self):
        """
        Returns a Transaction context manager bound to this connection.

        The transaction is started immediately, not when the with block is
        entered.
        """
        return Transaction(self)
184
185
class Row(dict):
    """
    Dictionary whose keys can also be read as attributes (``row.name``).
    """
    def __getattr__(self, attr):
        # Only consulted when normal attribute lookup fails; a missing key
        # is reported as AttributeError so getattr()/hasattr() behave.
        if attr not in self:
            raise AttributeError(attr)

        return self[attr]
193
194
class Transaction(object):
    """
    Context manager that wraps a group of statements in an explicit transaction.

    The transaction is opened as soon as the object is constructed. When the
    with block is left, it is committed, or rolled back if an exception
    escaped. Exceptions are never suppressed.
    """
    def __init__(self, db):
        self.db = db
        db.execute("START TRANSACTION")

    def __enter__(self):
        return self

    def __exit__(self, exctype, excvalue, traceback):
        failed = exctype is not None
        self.db.execute("ROLLBACK" if failed else "COMMIT")
209
210
# Re-export commonly caught psycopg2 exceptions under this module's namespace
# (OperationalError is also what Connection._execute() catches).
IntegrityError, OperationalError = psycopg2.IntegrityError, psycopg2.OperationalError