From d9158b5e2719b699b8d528195f710698380cb749 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 11 Feb 2009 04:45:35 +0000 Subject: [PATCH] revised to use PGInfoCache --- lib/sqlalchemy/dialects/postgres/base.py | 130 ++++++++++++++++------- 1 file changed, 90 insertions(+), 40 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index d3dc9c0e3f..b4cf435058 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -397,6 +397,21 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): value = value[1:-1].replace('""','"') return value +class PGInfoCache(default.DefaultInfoCache): + + def __init__(self): + + default.DefaultInfoCache.__init__(self) + + def getTableOID(self, tablename, schemaname=None): + table = self.getTable(tablename, schemaname) + if table: + return table.get('oid') + + def setTableOID(self, oid, tablename, schemaname=None): + table = self.getTable(tablename, schemaname, create=True) + table['oid'] = oid + class PGDialect(default.DefaultDialect): name = 'postgres' supports_alter = True @@ -499,23 +514,6 @@ class PGDialect(default.DefaultDialect): raise AssertionError("Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3)]) - def _prepare_info_cache(self, info_cache, tablename, schemaname): - """Add index for schemaname.table_name if it does not exist. - - This is done so that certain keys can be assumed to be present. - - """ - # First, make sure it has the keys we expect. - if info_cache is None: - info_cache = dict(tables={}) - elif 'tables' not in info_cache: - info_cache['tables'] = {} - # Add the table index if needed. - table_index = "%s.%s" % (schemaname, tablename) - if table_index not in info_cache['tables']: - info_cache['tables'][table_index] = {} - return info_cache - def _get_table_oid(self, connection, tablename, schemaname=None, info_cache=None): """Fetch the oid for schemaname.tablename. @@ -525,10 +523,9 @@ class PGDialect(default.DefaultDialect): subsequent calls. """ - info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) - # If it's in info_cache, juse use that. - table_index = "%s.%s" % (schemaname, tablename) - table_oid = info_cache['tables'][table_index].get('table_oid') + table_oid = None + if info_cache: + table_oid = info_cache.getTableOID(tablename, schemaname) if table_oid: return table_oid if schemaname is not None: @@ -553,10 +550,15 @@ class PGDialect(default.DefaultDialect): if table_oid is None: raise exc.NoSuchTableError(table_name) # cache it - info_cache['tables'][table_index]['table_oid'] = table_oid + if info_cache: + info_cache.setTableOID(table_oid, tablename, schemaname) return table_oid def get_schema_names(self, connection, info_cache=None): + if info_cache: + schema_names = info_cache.getSchemaNames() + if schema_names is not None: + return schema_names s = """ SELECT nspname FROM pg_namespace @@ -564,28 +566,45 @@ class PGDialect(default.DefaultDialect): """ rp = connection.execute(s) # what about system tables? - return [row[0].decode(self.encoding) for row in rp \ - if not row[0].startswith('pg_')] + schema_names = [row[0].decode(self.encoding) for row in rp \ + if not row[0].startswith('pg_')] + if info_cache: + info_cache.addAllSchemas(schema_names) + return schema_names def get_table_names(self, connection, schemaname=None, info_cache=None): if schemaname is not None: current_schema = schemaname else: current_schema = self.get_default_schema_name(connection) - return self.table_names(connection, current_schema) + if info_cache: + table_names = info_cache.getTableNames(current_schema) + if table_names is not None: + return table_names + table_names = self.table_names(connection, current_schema) + if info_cache: + info_cache.addAllTables(table_names, current_schema) + return table_names def get_view_names(self, connection, schemaname=None, info_cache=None): if schemaname is not None: current_schema = schemaname else: current_schema = self.get_default_schema_name(connection) + if info_cache: + view_names = info_cache.getViewNames(current_schema) + if view_names is not None: + return view_names s = """ SELECT relname FROM pg_class c WHERE relkind = 'v' AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) """ % dict(schema=current_schema) - return [row[0].decode(self.encoding) for row in connection.execute(s)] + view_names = [row[0].decode(self.encoding) for row in connection.execute(s)] + if info_cache: + info_cache.addAllViews(view_names, schemaname) + return view_names def get_view_definition(self, connection, viewname, schemaname=None, info_cache=None): @@ -593,6 +612,10 @@ class PGDialect(default.DefaultDialect): current_schema = schemaname else: current_schema = self.get_default_schema_name(connection) + if info_cache: + view = info_cache.getView(viewname, current_schema) + if view.get('definition'): + return view['definition'] s = """ SELECT definition FROM pg_views WHERE schemaname = :schemaname @@ -601,11 +624,19 @@ class PGDialect(default.DefaultDialect): rp = connection.execute(sql.text(s), viewname=viewname, schemaname=current_schema) if rp: - return rp.scalar().decode(self.encoding) + view_def = rp.scalar().decode(self.encoding) + if info_cache: + view = info_cache.getView(viewname, current_schema, + create=True) + view['definition'] = view_def + return view_def def get_columns(self, connection, tablename, schemaname=None, info_cache=None): - info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname) + if table_cache and 'columns' in table_cache.keys(): + return table_cache.get('columns') table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) SQL_COLS = """ @@ -689,11 +720,18 @@ class PGDialect(default.DefaultDialect): column_info = dict(name=name, type=coltype, nullable=nullable, default=default, colargs=colargs) columns.append(column_info) + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname, + create=True) + table_cache['columns'] = columns return columns def get_primary_keys(self, connection, tablename, schemaname=None, info_cache=None): - info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname) + if table_cache and 'primary_keys' in table_cache.keys(): + return table_cache.get('primary_keys') table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) PK_SQL = """ @@ -706,19 +744,20 @@ class PGDialect(default.DefaultDialect): """ t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) - return [r[0] for r in c.fetchall()] - for row in c.fetchall(): - pk = row[0] - if pk in table.c: - col = table.c[pk] - table.primary_key.add(col) - if col.default is None: - col.autoincrement = False + primary_keys = [r[0] for r in c.fetchall()] + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname, + create=True) + table_cache['primary_keys'] = primary_keys + return primary_keys def get_foreign_keys(self, connection, tablename, schemaname=None, info_cache=None): + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname) + if table_cache and 'foreign_keys' in table_cache.keys(): + return table_cache.get('foreign_keys') preparer = self.identifier_preparer - info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) FK_SQL = """ @@ -752,10 +791,17 @@ class PGDialect(default.DefaultDialect): 'referred_columns' : referred_columns } fkeys.append(fkey_d) + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname, + create=True) + table_cache['foreign_keys'] = fkeys return fkeys def get_indexes(self, connection, tablename, schemaname, info_cache=None): - info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname) + if table_cache and 'indexes' in table_cache.keys(): + return table_cache.get('indexes') table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) IDX_SQL = """ @@ -793,6 +839,10 @@ class PGDialect(default.DefaultDialect): index_d['name'] = idx_name index_d['column_names'].append(col) index_d['unique'] = unique + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname, + create=True) + table_cache['indexes'] = indexes return indexes def reflecttable(self, connection, table, include_columns): @@ -805,7 +855,7 @@ class PGDialect(default.DefaultDialect): if isinstance(tablename, str): tablename = tablename.decode(self.encoding) # end Py2K - info_cache = {} + info_cache = PGInfoCache() for col_d in self.get_columns(connection, tablename, schemaname, info_cache): name = col_d['name'] -- 2.47.3