From 232281a563c500e4c8e16729be9b602050bdedb8 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Sat, 28 Feb 2009 06:06:44 +0000 Subject: [PATCH] moving to simpler cache technique --- lib/sqlalchemy/dialects/postgres/base.py | 90 +++++++----------------- lib/sqlalchemy/engine/reflection.py | 35 ++++++--- 2 files changed, 53 insertions(+), 72 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index d031e30ae8..e9a1f09e33 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -397,21 +397,6 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): value = value[1:-1].replace('""','"') return value -class PGInfoCache(reflection.DefaultInfoCache): - - def __init__(self): - - reflection.DefaultInfoCache.__init__(self) - - def get_table_oid(self, tablename, schemaname=None): - table = self.get_table(tablename, schemaname) - if table: - return table.get('oid') - - def set_table_oid(self, oid, tablename, schemaname=None): - table = self.get_table(tablename, schemaname, create=True) - table['oid'] = oid - class PGDialect(default.DefaultDialect): name = 'postgres' supports_alter = True @@ -432,7 +417,6 @@ class PGDialect(default.DefaultDialect): type_compiler = PGTypeCompiler preparer = PGIdentifierPreparer defaultrunner = PGDefaultRunner - info_cache = PGInfoCache def do_begin_twophase(self, connection, xid): @@ -515,8 +499,7 @@ class PGDialect(default.DefaultDialect): raise AssertionError("Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3)]) - def _get_table_oid(self, connection, tablename, schemaname=None, - info_cache=None): + def _get_table_oid(self, connection, tablename, schemaname=None): """Fetch the oid for schemaname.tablename. Several reflection methods require the table oid. The idea for using @@ -525,10 +508,6 @@ class PGDialect(default.DefaultDialect): """ table_oid = None - if info_cache: - table_oid = info_cache.get_table_oid(tablename, schemaname) - if table_oid: - return table_oid if schemaname is not None: schema_where_clause = "n.nspname = :schema" else: @@ -555,13 +534,10 @@ class PGDialect(default.DefaultDialect): table_oid = c.scalar() if table_oid is None: raise exc.NoSuchTableError(tablename) - # cache it - if info_cache: - info_cache.set_table_oid(table_oid, tablename, schemaname) return table_oid - @reflection.caches - def get_schema_names(self, connection, info_cache=None): + @reflection.cache + def get_schema_names(self, connection): s = """ SELECT nspname FROM pg_namespace @@ -573,8 +549,8 @@ class PGDialect(default.DefaultDialect): if not row[0].startswith('pg_')] return schema_names - @reflection.caches - def get_table_names(self, connection, schemaname=None, info_cache=None): + @reflection.cache + def get_table_names(self, connection, schemaname=None): if schemaname is not None: current_schema = schemaname else: @@ -582,8 +558,8 @@ class PGDialect(default.DefaultDialect): table_names = self.table_names(connection, current_schema) return table_names - @reflection.caches - def get_view_names(self, connection, schemaname=None, info_cache=None): + @reflection.cache + def get_view_names(self, connection, schemaname=None): if schemaname is not None: current_schema = schemaname else: @@ -597,9 +573,8 @@ class PGDialect(default.DefaultDialect): view_names = [row[0].decode(self.encoding) for row in connection.execute(s)] return view_names - @reflection.caches - def get_view_definition(self, connection, viewname, schemaname=None, - info_cache=None): + @reflection.cache + def get_view_definition(self, connection, viewname, schemaname=None): if schemaname is not None: current_schema = schemaname else: @@ -615,12 +590,10 @@ class PGDialect(default.DefaultDialect): view_def = rp.scalar().decode(self.encoding) return view_def - @reflection.caches - def get_columns(self, connection, tablename, schemaname=None, - info_cache=None): + @reflection.cache + def get_columns(self, connection, tablename, schemaname=None): - table_oid = self._get_table_oid(connection, tablename, schemaname, - info_cache) + table_oid = self._get_table_oid(connection, tablename, schemaname) SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), @@ -704,11 +677,9 @@ class PGDialect(default.DefaultDialect): columns.append(column_info) return columns - @reflection.caches - def get_primary_keys(self, connection, tablename, schemaname=None, - info_cache=None): - table_oid = self._get_table_oid(connection, tablename, schemaname, - info_cache) + @reflection.cache + def get_primary_keys(self, connection, tablename, schemaname=None): + table_oid = self._get_table_oid(connection, tablename, schemaname) PK_SQL = """ SELECT attname FROM pg_attribute WHERE attrelid = ( @@ -722,12 +693,10 @@ class PGDialect(default.DefaultDialect): primary_keys = [r[0] for r in c.fetchall()] return primary_keys - @reflection.caches - def get_foreign_keys(self, connection, tablename, schemaname=None, - info_cache=None): + @reflection.cache + def get_foreign_keys(self, connection, tablename, schemaname=None): preparer = self.identifier_preparer - table_oid = self._get_table_oid(connection, tablename, schemaname, - info_cache) + table_oid = self._get_table_oid(connection, tablename, schemaname) FK_SQL = """ SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef FROM pg_catalog.pg_constraint r @@ -761,10 +730,9 @@ class PGDialect(default.DefaultDialect): fkeys.append(fkey_d) return fkeys - @reflection.caches - def get_indexes(self, connection, tablename, schemaname, info_cache=None): - table_oid = self._get_table_oid(connection, tablename, schemaname, - info_cache) + @reflection.cache + def get_indexes(self, connection, tablename, schemaname): + table_oid = self._get_table_oid(connection, tablename, schemaname) IDX_SQL = """ SELECT c.relname, i.indisunique, i.indexprs, i.indpred, a.attname @@ -813,9 +781,7 @@ class PGDialect(default.DefaultDialect): if isinstance(tablename, str): tablename = tablename.decode(self.encoding) # end Py2K - info_cache = PGInfoCache() - for col_d in self.get_columns(connection, tablename, schemaname, - info_cache): + for col_d in self.get_columns(connection, tablename, schemaname): name = col_d['name'] coltype = col_d['type'] nullable = col_d['nullable'] @@ -835,19 +801,16 @@ class PGDialect(default.DefaultDialect): colargs.append(schema.DefaultClause(sql.text(default))) table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) # Now we have the table oid cached. - table_oid = self._get_table_oid(connection, tablename, schemaname, - info_cache) + table_oid = self._get_table_oid(connection, tablename, schemaname) # Primary keys - for pk in self.get_primary_keys(connection, tablename, schemaname, - info_cache): + for pk in self.get_primary_keys(connection, tablename, schemaname): if pk in table.c: col = table.c[pk] table.primary_key.add(col) if col.default is None: col.autoincrement = False # Foreign keys - fkeys = self.get_foreign_keys(connection, tablename, schemaname, - info_cache) + fkeys = self.get_foreign_keys(connection, tablename, schemaname) for fkey_d in fkeys: conname = fkey_d['name'] constrained_columns = fkey_d['constrained_columns'] @@ -868,8 +831,7 @@ class PGDialect(default.DefaultDialect): table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True)) # Indexes - indexes = self.get_indexes(connection, tablename, schemaname, - info_cache) + indexes = self.get_indexes(connection, tablename, schemaname) for index_d in indexes: name = index_d['name'] columns = index_d['column_names'] diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 7f8143d600..2f7d3021d6 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -22,6 +22,21 @@ from sqlalchemy import util from sqlalchemy.types import TypeEngine +##@util.decorator +def cache(fn): + def decorated(self, con, *args, **kw): + info_cache = kw.pop('info_cache', None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = (fn.__name__, args, str(kw)) + ret = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + return decorated + +# keeping this around until all dialects are fixed @util.decorator def caches(fn, self, con, *args, **kw): # what are we caching? @@ -270,7 +285,11 @@ class Inspector(object): self.engine = conn.engine else: self.engine = conn - self.info_cache = self.engine.dialect.info_cache() + # fixme. This is just until all dialects are converted + if hasattr(self, 'info_cache'): + self.info_cache = self.engine.dialect.info_cache() + else: + self.info_cache = {} def default_schema_name(self): return self.engine.dialect.get_default_schema_name(self.conn) @@ -282,7 +301,7 @@ class Inspector(object): """ if hasattr(self.engine.dialect, 'get_schema_names'): return self.engine.dialect.get_schema_names(self.conn, - self.info_cache) + info_cache=self.info_cache) return [] def get_table_names(self, schemaname=None, order_by=None): @@ -296,7 +315,7 @@ class Inspector(object): """ if hasattr(self.engine.dialect, 'get_table_names'): tnames = self.engine.dialect.get_table_names(self.conn, schemaname, - self.info_cache) + info_cache=self.info_cache) else: tnames = self.engine.table_names(schemaname) if order_by == 'foreign_key': @@ -325,7 +344,7 @@ class Inspector(object): """ return self.engine.dialect.get_view_names(self.conn, schemaname, - self.info_cache) + info_cache=self.info_cache) def get_view_definition(self, view_name, schemaname=None): """Return definition for `view_name`. @@ -334,7 +353,7 @@ class Inspector(object): """ return self.engine.dialect.get_view_definition( - self.conn, view_name, schemaname, self.info_cache) + self.conn, view_name, schemaname, info_cache=self.info_cache) def get_columns(self, tablename, schemaname=None): """Return information about columns in `tablename`. @@ -379,7 +398,7 @@ class Inspector(object): pkeys = self.engine.dialect.get_primary_keys(self.conn, tablename, schemaname, - self.info_cache) + info_cache=self.info_cache) return pkeys @@ -406,7 +425,7 @@ class Inspector(object): fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename, schemaname, - self.info_cache) + info_cache=self.info_cache) for fk_def in fk_defs: referred_schema = fk_def['referred_schema'] # always set the referred_schema. @@ -435,5 +454,5 @@ class Inspector(object): indexes = self.engine.dialect.get_indexes(self.conn, tablename, schemaname, - self.info_cache) + info_cache=self.info_cache) return indexes -- 2.47.3