value = value[1:-1].replace('""','"')
return value
-class PGInfoCache(reflection.DefaultInfoCache):
-
- def __init__(self):
-
- reflection.DefaultInfoCache.__init__(self)
-
- def get_table_oid(self, tablename, schemaname=None):
- table = self.get_table(tablename, schemaname)
- if table:
- return table.get('oid')
-
- def set_table_oid(self, oid, tablename, schemaname=None):
- table = self.get_table(tablename, schemaname, create=True)
- table['oid'] = oid
-
class PGDialect(default.DefaultDialect):
name = 'postgres'
supports_alter = True
type_compiler = PGTypeCompiler
preparer = PGIdentifierPreparer
defaultrunner = PGDefaultRunner
- info_cache = PGInfoCache
def do_begin_twophase(self, connection, xid):
raise AssertionError("Could not determine version from string '%s'" % v)
return tuple([int(x) for x in m.group(1, 2, 3)])
- def _get_table_oid(self, connection, tablename, schemaname=None,
- info_cache=None):
+ def _get_table_oid(self, connection, tablename, schemaname=None):
"""Fetch the oid for schemaname.tablename.
Several reflection methods require the table oid. The idea for using this method is that the oid can be fetched once and reused for subsequent calls.
"""
table_oid = None
- if info_cache:
- table_oid = info_cache.get_table_oid(tablename, schemaname)
- if table_oid:
- return table_oid
if schemaname is not None:
schema_where_clause = "n.nspname = :schema"
else:
table_oid = c.scalar()
if table_oid is None:
raise exc.NoSuchTableError(tablename)
- # cache it
- if info_cache:
- info_cache.set_table_oid(table_oid, tablename, schemaname)
return table_oid
- @reflection.caches
- def get_schema_names(self, connection, info_cache=None):
+ @reflection.cache
+ def get_schema_names(self, connection):
s = """
SELECT nspname
FROM pg_namespace
if not row[0].startswith('pg_')]
return schema_names
- @reflection.caches
- def get_table_names(self, connection, schemaname=None, info_cache=None):
+ @reflection.cache
+ def get_table_names(self, connection, schemaname=None):
if schemaname is not None:
current_schema = schemaname
else:
table_names = self.table_names(connection, current_schema)
return table_names
- @reflection.caches
- def get_view_names(self, connection, schemaname=None, info_cache=None):
+ @reflection.cache
+ def get_view_names(self, connection, schemaname=None):
if schemaname is not None:
current_schema = schemaname
else:
view_names = [row[0].decode(self.encoding) for row in connection.execute(s)]
return view_names
- @reflection.caches
- def get_view_definition(self, connection, viewname, schemaname=None,
- info_cache=None):
+ @reflection.cache
+ def get_view_definition(self, connection, viewname, schemaname=None):
if schemaname is not None:
current_schema = schemaname
else:
view_def = rp.scalar().decode(self.encoding)
return view_def
- @reflection.caches
- def get_columns(self, connection, tablename, schemaname=None,
- info_cache=None):
+ @reflection.cache
+ def get_columns(self, connection, tablename, schemaname=None):
- table_oid = self._get_table_oid(connection, tablename, schemaname,
- info_cache)
+ table_oid = self._get_table_oid(connection, tablename, schemaname)
SQL_COLS = """
SELECT a.attname,
pg_catalog.format_type(a.atttypid, a.atttypmod),
columns.append(column_info)
return columns
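For orientation, each element that get_columns() appends is a plain dictionary rather than a schema.Column. The name, type and nullable keys are read directly in reflecttable() below; the default key and all concrete values in this sketch are assumptions:

    from sqlalchemy import types as sqltypes

    # a sketch of one element of get_columns()'s return value
    column_info = {
        'name': 'id',
        'type': sqltypes.Integer(),   # a TypeEngine instance
        'nullable': False,
        'default': "nextval('users_id_seq'::regclass)",
    }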
- @reflection.caches
- def get_primary_keys(self, connection, tablename, schemaname=None,
- info_cache=None):
- table_oid = self._get_table_oid(connection, tablename, schemaname,
- info_cache)
+ @reflection.cache
+ def get_primary_keys(self, connection, tablename, schemaname=None):
+ table_oid = self._get_table_oid(connection, tablename, schemaname)
PK_SQL = """
SELECT attname FROM pg_attribute
WHERE attrelid = (
primary_keys = [r[0] for r in c.fetchall()]
return primary_keys
- @reflection.caches
- def get_foreign_keys(self, connection, tablename, schemaname=None,
- info_cache=None):
+ @reflection.cache
+ def get_foreign_keys(self, connection, tablename, schemaname=None):
preparer = self.identifier_preparer
- table_oid = self._get_table_oid(connection, tablename, schemaname,
- info_cache)
+ table_oid = self._get_table_oid(connection, tablename, schemaname)
FK_SQL = """
SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
FROM pg_catalog.pg_constraint r
fkeys.append(fkey_d)
return fkeys
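Likewise, each fkey_d is a plain dict. The name, constrained_columns and referred_schema keys are the ones used by reflecttable() and the Inspector below; the referred table/column keys (from which refspec is presumably built) are assumptions:

    # hypothetical shape of one foreign key record
    fkey_d = {
        'name': 'addresses_user_id_fkey',
        'constrained_columns': ['user_id'],
        'referred_schema': None,          # or an explicit schema name
        'referred_table': 'users',        # assumed key
        'referred_columns': ['id'],       # assumed key
    }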
- @reflection.caches
- def get_indexes(self, connection, tablename, schemaname, info_cache=None):
- table_oid = self._get_table_oid(connection, tablename, schemaname,
- info_cache)
+ @reflection.cache
+ def get_indexes(self, connection, tablename, schemaname=None):
+ table_oid = self._get_table_oid(connection, tablename, schemaname)
IDX_SQL = """
SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
a.attname
if isinstance(tablename, str):
tablename = tablename.decode(self.encoding)
# end Py2K
- info_cache = PGInfoCache()
- for col_d in self.get_columns(connection, tablename, schemaname,
- info_cache):
+ for col_d in self.get_columns(connection, tablename, schemaname):
name = col_d['name']
coltype = col_d['type']
nullable = col_d['nullable']
colargs.append(schema.DefaultClause(sql.text(default)))
table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
# Fetch the table oid; _get_table_oid raises NoSuchTableError if the table is missing.
- table_oid = self._get_table_oid(connection, tablename, schemaname,
- info_cache)
+ table_oid = self._get_table_oid(connection, tablename, schemaname)
# Primary keys
- for pk in self.get_primary_keys(connection, tablename, schemaname,
- info_cache):
+ for pk in self.get_primary_keys(connection, tablename, schemaname):
if pk in table.c:
col = table.c[pk]
table.primary_key.add(col)
if col.default is None:
col.autoincrement = False
# Foreign keys
- fkeys = self.get_foreign_keys(connection, tablename, schemaname,
- info_cache)
+ fkeys = self.get_foreign_keys(connection, tablename, schemaname)
for fkey_d in fkeys:
conname = fkey_d['name']
constrained_columns = fkey_d['constrained_columns']
table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
# Indexes
- indexes = self.get_indexes(connection, tablename, schemaname,
- info_cache)
+ indexes = self.get_indexes(connection, tablename, schemaname)
for index_d in indexes:
name = index_d['name']
columns = index_d['column_names']
from sqlalchemy.types import TypeEngine
+##@util.decorator
+def cache(fn):
+ def decorated(self, con, *args, **kw):
+ info_cache = kw.pop('info_cache', None)
+ if info_cache is None:
+ return fn(self, con, *args, **kw)
+ key = (fn.__name__, args, str(kw))
+ ret = info_cache.get(key)
+ if ret is None:
+ ret = fn(self, con, *args, **kw)
+ info_cache[key] = ret
+ return ret
+ return decorated
+
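To make the new contract concrete, here is a minimal sketch of how a converted dialect method behaves under @reflection.cache; dialect and conn are stand-ins for a PGDialect instance and a live connection:

    # the same dict can be shared across any number of calls
    info_cache = {}
    names = dialect.get_table_names(conn, 'public', info_cache=info_cache)
    # a repeat call with the same arguments is answered from the dict
    names = dialect.get_table_names(conn, 'public', info_cache=info_cache)
    # with no info_cache keyword the call goes straight to the database
    names = dialect.get_table_names(conn, 'public')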
+# keeping this around until all dialects are fixed
@util.decorator
def caches(fn, self, con, *args, **kw):
# what are we caching?
self.engine = conn.engine
else:
self.engine = conn
- self.info_cache = self.engine.dialect.info_cache()
+ # FIXME: interim fallback until all dialects are converted
+ if hasattr(self.engine.dialect, 'info_cache'):
+ self.info_cache = self.engine.dialect.info_cache()
+ else:
+ self.info_cache = {}
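As an illustration of what ends up in this dict (table names made up): entries follow the (fn.__name__, args, str(kw)) key built by the cache decorator above, so a get_table_names() call for the 'public' schema leaves something like:

    info_cache = {
        ('get_table_names', ('public',), '{}'): ['users', 'addresses'],
    }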
def default_schema_name(self):
return self.engine.dialect.get_default_schema_name(self.conn)
"""
if hasattr(self.engine.dialect, 'get_schema_names'):
return self.engine.dialect.get_schema_names(self.conn,
- self.info_cache)
+ info_cache=self.info_cache)
return []
def get_table_names(self, schemaname=None, order_by=None):
"""
if hasattr(self.engine.dialect, 'get_table_names'):
tnames = self.engine.dialect.get_table_names(self.conn, schemaname,
- self.info_cache)
+ info_cache=self.info_cache)
else:
tnames = self.engine.table_names(schemaname)
if order_by == 'foreign_key':
"""
return self.engine.dialect.get_view_names(self.conn, schemaname,
- self.info_cache)
+ info_cache=self.info_cache)
def get_view_definition(self, view_name, schemaname=None):
"""Return definition for `view_name`.
"""
return self.engine.dialect.get_view_definition(
- self.conn, view_name, schemaname, self.info_cache)
+ self.conn, view_name, schemaname, info_cache=self.info_cache)
def get_columns(self, tablename, schemaname=None):
"""Return information about columns in `tablename`.
pkeys = self.engine.dialect.get_primary_keys(self.conn, tablename,
schemaname,
- self.info_cache)
+ info_cache=self.info_cache)
return pkeys
fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename,
schemaname,
- self.info_cache)
+ info_cache=self.info_cache)
for fk_def in fk_defs:
referred_schema = fk_def['referred_schema']
# always set the referred_schema.
indexes = self.engine.dialect.get_indexes(self.conn, tablename,
schemaname,
- self.info_cache)
+ info_cache=self.info_cache)
return indexes
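Putting the pieces together, a rough usage sketch; the Inspector class name, its constructor and the database URL are assumptions based on the surrounding context rather than anything shown in this patch:

    from sqlalchemy import create_engine
    from sqlalchemy.engine import reflection

    engine = create_engine('postgres://scott:tiger@localhost/test')
    insp = reflection.Inspector(engine)      # assumed constructor
    pks = insp.get_primary_keys('users')
    # a second identical call is served from insp.info_cache rather
    # than re-running PK_SQL against the database
    pks = insp.get_primary_keys('users')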