From 5a3769f68575b2b075fd7c01e0d4405b1893d52c Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 18 Feb 2009 06:38:53 +0000 Subject: [PATCH] reflection methods not use decorator for caching --- lib/sqlalchemy/dialects/information_schema.py | 8 + lib/sqlalchemy/dialects/mssql/base.py | 89 ++----- lib/sqlalchemy/dialects/oracle/base.py | 52 +--- lib/sqlalchemy/dialects/postgres/base.py | 85 ++----- lib/sqlalchemy/engine/default.py | 176 ------------- lib/sqlalchemy/engine/reflection.py | 234 +++++++++++++++++- test/reflection.py | 2 +- 7 files changed, 293 insertions(+), 353 deletions(-) diff --git a/lib/sqlalchemy/dialects/information_schema.py b/lib/sqlalchemy/dialects/information_schema.py index b15082ac2e..9a65cca4cd 100644 --- a/lib/sqlalchemy/dialects/information_schema.py +++ b/lib/sqlalchemy/dialects/information_schema.py @@ -78,6 +78,14 @@ ref_constraints = Table("referential_constraints", ischema, Column("delete_rule", String), schema="information_schema") +views = Table("views", ischema, + Column("table_catalog", String), + Column("table_schema", String), + Column("table_name", String), + Column("view_definition", String), + Column("check_option", String), + Column("is_updatable", String), + schema="information_schema") def table_names(connection, schema): s = select([tables.c.table_name], tables.c.table_schema==schema) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 43fe4b5d57..b45e7cd5aa 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -231,7 +231,7 @@ import datetime, decimal, inspect, operator, sys, re from sqlalchemy import sql, schema, exc, util from sqlalchemy.sql import compiler, expression, operators as sql_operators, functions as sql_functions -from sqlalchemy.engine import default, base +from sqlalchemy.engine import default, base, reflection from sqlalchemy import types as sqltypes from decimal import Decimal as _python_Decimal @@ -1044,10 +1044,10 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): return value -class MSInfoCache(default.DefaultInfoCache): +class MSInfoCache(reflection.DefaultInfoCache): def __init__(self, *args, **kwargs): - default.DefaultInfoCache.__init__(self, *args, **kwargs) + reflection.DefaultInfoCache.__init__(self, *args, **kwargs) class MSDialect(default.DefaultDialect): @@ -1147,27 +1147,19 @@ class MSDialect(default.DefaultDialect): row = c.fetchone() return row is not None + @reflection.caches def get_schema_names(self, connection, info_cache=None): - if info_cache: - schema_names = info_cache.getSchemaNames() - if schema_names is not None: - return schema_names - import sqlalchemy.databases.information_schema as ischema + import sqlalchemy.dialects.information_schema as ischema s = sql.select([self.uppercase_table(ischema.schemata).c.schema_name], order_by=[ischema.schemata.c.schema_name] ) schema_names = [r[0] for r in connection.execute(s)] - if info_cache: - info_cache.addAllSchemas(schema_names) return schema_names + @reflection.caches def get_table_names(self, connection, schemaname, info_cache=None): - import sqlalchemy.databases.information_schema as ischema + import sqlalchemy.dialects.information_schema as ischema current_schema = schemaname or self.get_default_schema_name(connection) - if info_cache: - table_names = info_cache.getTableNames(current_schema) - if table_names is not None: - return table_names tables = self.uppercase_table(ischema.tables) s = sql.select([tables.c.table_name], sql.and_( @@ -1177,17 +1169,12 @@ class MSDialect(default.DefaultDialect): order_by=[tables.c.table_name] ) table_names = [r[0] for r in connection.execute(s)] - if info_cache: - info_cache.addAllTables(table_names, current_schema) return table_names + @reflection.caches def get_view_names(self, connection, schemaname=None, info_cache=None): - import sqlalchemy.databases.information_schema as ischema + import sqlalchemy.dialects.information_schema as ischema current_schema = schemaname or self.get_default_schema_name(connection) - if info_cache: - view_names = info_cache.getViewNames(current_schema) - if view_names is not None: - return view_names tables = self.uppercase_table(ischema.tables) s = sql.select([tables.c.table_name], sql.and_( @@ -1197,17 +1184,12 @@ class MSDialect(default.DefaultDialect): order_by=[tables.c.table_name] ) view_names = [r[0] for r in connection.execute(s)] - if info_cache: - info_cache.addAllViews(view_names, schemaname) return view_names + @reflection.caches def get_indexes(self, connection, tablename, schemaname=None, info_cache=None): current_schema = schemaname or self.get_default_schema_name(connection) - if info_cache: - table_cache = info_cache.getTable(tablename, current_schema) - if table_cache and 'indexes' in table_cache: - return table_cache.get('indexes') full_tname = "%s.%s" % (current_schema, tablename) indexes = [] s = sql.text("exec sp_helpindex '%s'" % full_tname) @@ -1219,20 +1201,13 @@ class MSDialect(default.DefaultDialect): 'column_names' : row['index_keys'].split(','), 'unique': 'unique' in row['index_description'] }) - if info_cache: - table_cache = info_cache.getTable(tablename, current_schema, - create=True) - table_cache['indexes'] = indexes return indexes + @reflection.caches def get_view_definition(self, connection, viewname, schemaname=None, info_cache=None): - import sqlalchemy.databases.information_schema as ischema + import sqlalchemy.dialects.information_schema as ischema current_schema = schemaname or self.get_default_schema_name(connection) - if info_cache: - view_cache = info_cache.getView(viewname, current_schema) - if view_cache and 'definition' in view_cache.keys(): - return view_cache.get('definition') views = self.uppercase_table(ischema.views) s = sql.select([views.c.view_definition], sql.and_( @@ -1243,20 +1218,13 @@ class MSDialect(default.DefaultDialect): rp = connection.execute(s) if rp: view_def = rp.scalar() - if info_cache: - view_cache = info_cache.getView(viewname, current_schema, - create=True) - view_cache['definition'] = view_def return view_def + @reflection.caches def get_columns(self, connection, tablename, schemaname=None, info_cache=None): # Get base columns current_schema = schemaname or self.get_default_schema_name(connection) - if info_cache: - table_cache = info_cache.getTable(tablename, current_schema) - if table_cache and 'columns' in table_cache.keys(): - return table_cache.get('columns') import sqlalchemy.dialects.information_schema as ischema columns = self.uppercase_table(ischema.columns) s = sql.select([columns], @@ -1311,20 +1279,13 @@ class MSDialect(default.DefaultDialect): 'attrs' : colargs } cols.append(cdict) - if info_cache: - table_cache = info_cache.getTable(tablename, current_schema, - create=True) - table_cache['columns'] = cols return cols + @reflection.caches def get_primary_keys(self, connection, tablename, schemaname=None, info_cache=None): import sqlalchemy.dialects.information_schema as ischema current_schema = schemaname or self.get_default_schema_name(connection) - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname) - if table_cache and 'primary_keys' in table_cache.keys(): - return table_cache.get('primary_keys') pkeys = [] # Add constraints RR = self.uppercase_table(ischema.ref_constraints) #information_schema.referential_constraints @@ -1342,20 +1303,13 @@ class MSDialect(default.DefaultDialect): for row in c: if 'PRIMARY' in row[TC.c.constraint_type.name]: pkeys.append(row[0]) - if info_cache: - table_cache = info_cache.getTable(tablename, current_schema, - create=True) - table_cache['primary_keys'] = pkeys return pkeys + @reflection.caches def get_foreign_keys(self, connection, tablename, schemaname=None, info_cache=None): import sqlalchemy.dialects.information_schema as ischema current_schema = schemaname or self.get_default_schema_name(connection) - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname) - if table_cache and 'foreign_keys' in table_cache.keys(): - return table_cache.get('foreign_keys') # Add constraints RR = self.uppercase_table(ischema.ref_constraints) #information_schema.referential_constraints TC = self.uppercase_table(ischema.constraints) #information_schema.table_constraints @@ -1392,8 +1346,8 @@ class MSDialect(default.DefaultDialect): fknm, scols, rcols = (rfknm, [], []) if not scol in scols: scols.append(scol) - if not (rschema, rtbl, rcol) in rcols: - rcols.append((rschema, rtbl, rcol)) + if not rcol in rcols: + rcols.append(rcol) if fknm and scols: fkeys.append({ 'name' : fknm, @@ -1402,9 +1356,6 @@ class MSDialect(default.DefaultDialect): 'referred_table' : rtbl, 'referred_columns' : rcols }) - if info_cache: - table_cache = info_cache.getTable(tablename, current_schema) - table_cache['foreign_keys'] = fkeys return fkeys def reflecttable(self, connection, table, include_columns): @@ -1489,4 +1440,8 @@ class MSDialect(default.DefaultDialect): else: schema.Table(rtbl, table.metadata, schema=rschema, autoload=True, autoload_with=connection) - table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True)) + ##table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True)) + table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, rschema, rtbl, c) for c in rcols], fknm, link_to_name=True)) + +# fixme. I added this for the tests to run. -Randall +MSSQLDialect = MSDialect diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index de9d14265b..6fc87e1adb 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -75,7 +75,7 @@ is not in use this flag should be left off. import datetime, random, re from sqlalchemy import util, sql, schema, log -from sqlalchemy.engine import default, base +from sqlalchemy.engine import default, base, reflection from sqlalchemy.sql import compiler, visitors, expression from sqlalchemy.sql import operators as sql_operators, functions as sql_functions from sqlalchemy import types as sqltypes @@ -447,7 +447,7 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer): name = re.sub(r'^_+', '', savepoint.ident) return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) -class OracleInfoCache(default.DefaultInfoCache): +class OracleInfoCache(reflection.DefaultInfoCache): pass class OracleDialect(default.DefaultDialect): @@ -583,15 +583,18 @@ class OracleDialect(default.DefaultDialect): owner = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) return (actual_name, owner, dblink, synonym) + @reflection.caches def get_schema_names(self, connection, info_cache=None): s = "SELECT username FROM all_users ORDER BY username" cursor = connection.execute(s,) return [self._normalize_name(row[0]) for row in cursor] + @reflection.caches def get_table_names(self, connection, schemaname=None, info_cache=None): schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) return self.table_names(connection, schemaname) + @reflection.caches def get_view_names(self, connection, schemaname=None, info_cache=None): schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) s = "select view_name from all_views where OWNER = :owner" @@ -599,6 +602,7 @@ class OracleDialect(default.DefaultDialect): {'owner':self._denormalize_name(schemaname)}) return [self._normalize_name(row[0]) for row in cursor] + @reflection.caches def get_columns(self, connection, tablename, schemaname=None, info_cache=None, resolve_synonyms=False, dblink=''): @@ -606,10 +610,6 @@ class OracleDialect(default.DefaultDialect): (tablename, schemaname, dblink, synonym) = \ self._prepare_reflection_args(connection, tablename, schemaname, resolve_synonyms, dblink) - if info_cache: - columns = info_cache.getColumns(tablename, schemaname) - if columns: - return columns columns = [] c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':tablename, 'owner':schemaname}) @@ -654,10 +654,9 @@ class OracleDialect(default.DefaultDialect): 'attrs': colargs } columns.append(cdict) - if info_cache: - info_cache.setColumns(columns, tablename, schemaname) return columns + @reflection.caches def get_indexes(self, connection, tablename, schemaname=None, info_cache=None, resolve_synonyms=False, dblink=''): @@ -665,10 +664,6 @@ class OracleDialect(default.DefaultDialect): (tablename, schemaname, dblink, synonym) = \ self._prepare_reflection_args(connection, tablename, schemaname, resolve_synonyms, dblink) - if info_cache: - indexes = info_cache.getIndexes(tablename, schemaname) - if indexes: - return indexes indexes = [] q = """ SELECT a.INDEX_NAME, a.COLUMN_NAME, b.UNIQUENESS @@ -699,17 +694,11 @@ class OracleDialect(default.DefaultDialect): index['unique'] = uniqueness.get(rset.uniqueness, False) index['column_names'].append(rset.column_name) last_index_name = rset.index_name - if info_cache: - info_cache.setIndexes(indexes, tablename, schemaname) return indexes def _get_constraint_data(self, connection, tablename, schemaname=None, info_cache=None, dblink=''): - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname) - if table_cache and ['constraints'] in table_cache.keys(): - return table_cache['constraints'] rp = connection.execute("""SELECT ac.constraint_name, ac.constraint_type, @@ -731,21 +720,14 @@ class OracleDialect(default.DefaultDialect): ORDER BY ac.constraint_name, loc.position, rem.position""" % {'dblink':dblink}, {'table_name' : tablename, 'owner' : schemaname}) constraint_data = rp.fetchall() - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname, - create=True) - table_cache['constraints'] = constraint_data return constraint_data + @reflection.caches def get_primary_keys(self, connection, tablename, schemaname=None, info_cache=None, resolve_synonyms=False, dblink=''): (tablename, schemaname, dblink, synonym) = \ self._prepare_reflection_args(connection, tablename, schemaname, resolve_synonyms, dblink) - if info_cache: - pkeys = info_cache.getPrimaryKeys(tablename, schemaname) - if pkeys is not None: - return pkeys pkeys = [] constraint_data = self._get_constraint_data(connection, tablename, schemaname, info_cache, dblink) @@ -754,19 +736,14 @@ class OracleDialect(default.DefaultDialect): (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) if cons_type == 'P': pkeys.append(local_column) - if info_cache: - info_cache.setPrimaryKeys(pkeys, tablename, schemaname) return pkeys + @reflection.caches def get_foreign_keys(self, connection, tablename, schemaname=None, info_cache=None, resolve_synonyms=False, dblink=''): (tablename, schemaname, dblink, synonym) = \ self._prepare_reflection_args(connection, tablename, schemaname, resolve_synonyms, dblink) - if info_cache: - fkeys = info_cache.getForeignKeys(tablename, schemaname) - if fkeys is not None: - return fkeys constraint_data = self._get_constraint_data(connection, tablename, schemaname, info_cache, dblink) @@ -807,19 +784,14 @@ class OracleDialect(default.DefaultDialect): 'referred_columns' : value[1] } fkeys.append(fkey_d) - if info_cache: - info_cache.setForeignKeys(fkeys, tablename, schemaname) return fkeys + @reflection.caches def get_view_definition(self, connection, viewname, schemaname=None, info_cache=None, resolve_synonyms=False, dblink=''): (viewname, schemaname, dblink, synonym) = \ self._prepare_reflection_args(connection, viewname, schemaname, resolve_synonyms, dblink) - if info_cache: - view_cache = info_cache.getView(viewname, schemaname) - if view_cache and 'definition' in view_cache: - return view_cache['definition'] s = """ SELECT text FROM all_views WHERE owner = :schemaname @@ -829,10 +801,6 @@ class OracleDialect(default.DefaultDialect): viewname=viewname, schemaname=schemaname) if rp: view_def = rp.scalar().decode(self.encoding) - if info_cache: - view = info_cache.getView(viewname, schemaname, - create=True) - view['definition'] = view_def return view_def def reflecttable(self, connection, table, include_columns): diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 705778cc5b..d031e30ae8 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -67,7 +67,7 @@ option to the Index constructor:: import re from sqlalchemy import sql, schema, exc, util -from sqlalchemy.engine import base, default +from sqlalchemy.engine import base, default, reflection from sqlalchemy.sql import compiler, expression from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes @@ -397,19 +397,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): value = value[1:-1].replace('""','"') return value -class PGInfoCache(default.DefaultInfoCache): +class PGInfoCache(reflection.DefaultInfoCache): def __init__(self): - default.DefaultInfoCache.__init__(self) + reflection.DefaultInfoCache.__init__(self) - def getTableOID(self, tablename, schemaname=None): - table = self.getTable(tablename, schemaname) + def get_table_oid(self, tablename, schemaname=None): + table = self.get_table(tablename, schemaname) if table: return table.get('oid') - def setTableOID(self, oid, tablename, schemaname=None): - table = self.getTable(tablename, schemaname, create=True) + def set_table_oid(self, oid, tablename, schemaname=None): + table = self.get_table(tablename, schemaname, create=True) table['oid'] = oid class PGDialect(default.DefaultDialect): @@ -526,7 +526,7 @@ class PGDialect(default.DefaultDialect): """ table_oid = None if info_cache: - table_oid = info_cache.getTableOID(tablename, schemaname) + table_oid = info_cache.get_table_oid(tablename, schemaname) if table_oid: return table_oid if schemaname is not None: @@ -554,17 +554,14 @@ class PGDialect(default.DefaultDialect): c = connection.execute(s, table_name=tablename, schema=schemaname) table_oid = c.scalar() if table_oid is None: - raise exc.NoSuchTableError(table_name) + raise exc.NoSuchTableError(tablename) # cache it if info_cache: - info_cache.setTableOID(table_oid, tablename, schemaname) + info_cache.set_table_oid(table_oid, tablename, schemaname) return table_oid + @reflection.caches def get_schema_names(self, connection, info_cache=None): - if info_cache: - schema_names = info_cache.getSchemaNames() - if schema_names is not None: - return schema_names s = """ SELECT nspname FROM pg_namespace @@ -574,33 +571,23 @@ class PGDialect(default.DefaultDialect): # what about system tables? schema_names = [row[0].decode(self.encoding) for row in rp \ if not row[0].startswith('pg_')] - if info_cache: - info_cache.addAllSchemas(schema_names) return schema_names + @reflection.caches def get_table_names(self, connection, schemaname=None, info_cache=None): if schemaname is not None: current_schema = schemaname else: current_schema = self.get_default_schema_name(connection) - if info_cache: - table_names = info_cache.getTableNames(current_schema) - if table_names is not None: - return table_names table_names = self.table_names(connection, current_schema) - if info_cache: - info_cache.addAllTables(table_names, current_schema) return table_names + @reflection.caches def get_view_names(self, connection, schemaname=None, info_cache=None): if schemaname is not None: current_schema = schemaname else: current_schema = self.get_default_schema_name(connection) - if info_cache: - view_names = info_cache.getViewNames(current_schema) - if view_names is not None: - return view_names s = """ SELECT relname FROM pg_class c @@ -608,20 +595,15 @@ class PGDialect(default.DefaultDialect): AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) """ % dict(schema=current_schema) view_names = [row[0].decode(self.encoding) for row in connection.execute(s)] - if info_cache: - info_cache.addAllViews(view_names, schemaname) return view_names + @reflection.caches def get_view_definition(self, connection, viewname, schemaname=None, info_cache=None): if schemaname is not None: current_schema = schemaname else: current_schema = self.get_default_schema_name(connection) - if info_cache: - view_cache = info_cache.getView(viewname, current_schema) - if view_cache and 'definition' in view_cache: - return view_cache['definition'] s = """ SELECT definition FROM pg_views WHERE schemaname = :schemaname @@ -631,18 +613,12 @@ class PGDialect(default.DefaultDialect): viewname=viewname, schemaname=current_schema) if rp: view_def = rp.scalar().decode(self.encoding) - if info_cache: - view = info_cache.getView(viewname, current_schema, - create=True) - view['definition'] = view_def return view_def + @reflection.caches def get_columns(self, connection, tablename, schemaname=None, info_cache=None): - if info_cache: - columns = info_cache.getColumns(tablename, schemaname) - if columns is not None: - return columns + table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) SQL_COLS = """ @@ -726,16 +702,11 @@ class PGDialect(default.DefaultDialect): column_info = dict(name=name, type=coltype, nullable=nullable, default=default, colargs=colargs) columns.append(column_info) - if info_cache: - info_cache.setColumns(columns, tablename, schemaname) return columns + @reflection.caches def get_primary_keys(self, connection, tablename, schemaname=None, info_cache=None): - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname) - if table_cache and 'primary_keys' in table_cache.keys(): - return table_cache.get('primary_keys') table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) PK_SQL = """ @@ -749,18 +720,11 @@ class PGDialect(default.DefaultDialect): t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) primary_keys = [r[0] for r in c.fetchall()] - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname, - create=True) - table_cache['primary_keys'] = primary_keys return primary_keys + @reflection.caches def get_foreign_keys(self, connection, tablename, schemaname=None, info_cache=None): - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname) - if table_cache and 'foreign_keys' in table_cache.keys(): - return table_cache.get('foreign_keys') preparer = self.identifier_preparer table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) @@ -795,17 +759,10 @@ class PGDialect(default.DefaultDialect): 'referred_columns' : referred_columns } fkeys.append(fkey_d) - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname, - create=True) - table_cache['foreign_keys'] = fkeys return fkeys + @reflection.caches def get_indexes(self, connection, tablename, schemaname, info_cache=None): - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname) - if table_cache and 'indexes' in table_cache.keys(): - return table_cache.get('indexes') table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) IDX_SQL = """ @@ -844,10 +801,6 @@ class PGDialect(default.DefaultDialect): index_d['name'] = idx_name index_d['column_names'].append(col) index_d['unique'] = unique - if info_cache: - table_cache = info_cache.getTable(tablename, schemaname, - create=True) - table_cache['indexes'] = indexes return indexes def reflecttable(self, connection, table, include_columns): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index b50411c0cf..b719219a5d 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -20,182 +20,6 @@ from sqlalchemy import exc, types as sqltypes AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) -class DefaultInfoCache(object): - """Default implementation of InfoCache - - InfoCache provides a means for dialects to cache information obtained for - reflection and a convenient interface for setting and retrieving cached - data. - - """ - - def __init__(self): - self._cache = dict(schemas={}) - self.tables_are_complete = False - self.schemas_are_complete = False - self.views_are_complete = False - - def clear(self): - """Clear the cache.""" - self._cache = dict(schemas={}) - - def getSchemas(self): - """Return the schemas dict.""" - return self._cache.get('schemas') - - def getSchemaNames(self, check_complete=True): - """Return cached schema names. - - By default, only return them if they're complete. - - """ - if check_complete and self.schemas_are_complete: - return self.getSchemas().keys() - elif not check_complete: - return self.getSchemas().keys() - else: - return None - - def getSchema(self, schemaname, create=False): - """Return cached schema and optionally create it if it does not exist. - - """ - schema = self._cache['schemas'].get(schemaname) - if schema is not None: - return schema - elif create: - return self.addSchema(schemaname) - return None - - def addSchema(self, schemaname): - self._cache['schemas'][schemaname] = dict(tables={}, views={}) - return self.getSchema(schemaname) - - def addAllSchemas(self, schemanames): - for schemaname in schemanames: - self.addSchema(schemaname) - self.schemas_are_complete = True - - def getTable(self, tablename, schemaname=None, create=False, - table_type='table'): - """Return cached table and optionally create it if it does not exist. - - - """ - cache = self._cache - schema = self.getSchema(schemaname, create=create) - if schema is None: - return None - if table_type == 'view': - table = schema['views'].get(tablename) - else: - table = schema['tables'].get(tablename) - if table is not None: - return table - elif create: - return self.addTable(tablename, schemaname, table_type=table_type) - return None - - def getTableNames(self, schemaname=None, check_complete=True, - table_type='table'): - """Return cached table names. - - By default, only return them if they're complete. - - """ - if table_type == 'view': - complete = self.views_are_complete - else: - complete = self.tables_are_complete - if check_complete and complete: - return self.getTables(schemaname, table_type=table_type).keys() - elif not check_complete: - return self.getTables(schemaname, table_type=table_type).keys() - else: - return None - - def addTable(self, tablename, schemaname=None, table_type='table'): - schema = self.getSchema(schemaname, create=True) - if table_type == 'table': - schema['tables'][tablename] = dict(columns={}) - else: - schema['views'][tablename] = dict(columns={}) - return self.getTable(tablename, schemaname, table_type=table_type) - - def addAllTables(self, tablenames, schemaname=None, table_type='table'): - for tablename in tablenames: - self.addTable(tablename, schemaname, table_type) - if table_type == 'view': - self.views_are_complete = True - else: - self.tables_are_complete = True - - def getView(self, viewname, schemaname=None, create=False): - return self.getTable(viewname, schemaname, create, 'view') - - def getViewNames(self, schemaname=None, check_complete=True): - return self.getTableNames(schemaname, check_complete, 'view') - - def addView(self, viewname, schemaname=None): - return self.addTable(viewname, schemaname, 'view') - - def addAllViews(self, viewnames, schemaname=None): - return self.addAllTables(viewnames, schemaname, 'view') - - def _getTableData(self, key, tablename, schemaname=None): - table_cache = self.getTable(tablename, schemaname) - if table_cache is not None and key in table_cache.keys(): - return table_cache[key] - - def _setTableData(self, key, data, tablename, schemaname=None): - """Cache data for schemaname.tablename using key. - - It will create a schema and table entry in the cache if needed. - - """ - table_cache = self.getTable(tablename, schemaname, create=True) - table_cache[key] = data - - def getColumns(self, tablename, schemaname=None): - """Return columns list or None.""" - - return self._getTableData('columns', tablename, schemaname) - - def setColumns(self, columns, tablename, schemaname=None): - """Add list of columns to table cache.""" - - return self._setTableData('columns', columns, tablename, schemaname) - - def getPrimaryKeys(self, tablename, schemaname=None): - """Return primary key list or None.""" - - return self._getTableData('primary_keys', tablename, schemaname) - - def setPrimaryKeys(self, pkeys, tablename, schemaname=None): - """Add list of primary keys to table cache.""" - - return self._setTableData('primary_keys', pkeys, tablename, schemaname) - - def getForeignKeys(self, tablename, schemaname=None): - """Return foreign key list or None.""" - - return self._getTableData('foreign_keys', tablename, schemaname) - - def setForeignKeys(self, fkeys, tablename, schemaname=None): - """Add list of foreign keys to table cache.""" - - return self._setTableData('foreign_keys', fkeys, tablename, schemaname) - - def getIndexes(self, tablename, schemaname=None): - """Return indexes list or None.""" - - return self._getTableData('indexes', tablename, schemaname) - - def setIndexes(self, indexes, tablename, schemaname=None): - """Add list of indexes to table cache.""" - - return self._setTableData('indexes', indexes, tablename, schemaname) - class DefaultDialect(base.Dialect): """Default implementation of Dialect""" diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 677cafee95..7f8143d600 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -18,8 +18,240 @@ I'm still trying to decide upon conventions for both the Inspector interface as """ import sqlalchemy +from sqlalchemy import util from sqlalchemy.types import TypeEngine + +@util.decorator +def caches(fn, self, con, *args, **kw): + # what are we caching? + fn_name = fn.__name__ + if not fn_name.startswith('get_'): + # don't recognize this. + return fn(self, con, *args, **kw) + else: + attr_to_cache = fn_name[4:] + # The first arguments will always be self and con. + # Assuming *args and *kw will be acceptable to info_cache method. + if 'info_cache' in kw: + kw_cp = kw.copy() + info_cache = kw_cp.pop('info_cache') + methodname = "%s_%s" % ('get', attr_to_cache) + # fixme. + for bad_kw in ('dblink', 'resolve_synonyms'): + if bad_kw in kw_cp: + del kw_cp[bad_kw] + information = getattr(info_cache, methodname)(*args, **kw_cp) + if information: + return information + information = fn(self, con, *args, **kw) + if 'info_cache' in locals(): + methodname = "%s_%s" % ('set', attr_to_cache) + getattr(info_cache, methodname)(information, *args, **kw_cp) + return information + +class DefaultInfoCache(object): + """Default implementation of InfoCache + + InfoCache provides a means for dialects to cache information obtained for + reflection and a convenient interface for setting and retrieving cached + data. + + """ + + def __init__(self): + self._cache = dict(schemas={}) + self.tables_are_complete = False + self.schemas_are_complete = False + self.views_are_complete = False + + def clear(self): + """Clear the cache.""" + self._cache = dict(schemas={}) + + # schemas + + def get_schemas(self): + """Return the schemas dict.""" + return self._cache.get('schemas') + + + def get_schema(self, schemaname, create=False): + """Return cached schema and optionally create it if it does not exist. + + """ + schema = self._cache['schemas'].get(schemaname) + if schema is not None: + return schema + elif create: + return self.add_schema(schemaname) + return None + + def add_schema(self, schemaname): + self._cache['schemas'][schemaname] = dict(tables={}, views={}) + return self.get_schema(schemaname) + + def get_schema_names(self, check_complete=True): + """Return cached schema names. + + By default, only return them if they're complete. + + """ + if check_complete and self.schemas_are_complete: + return self.get_schemas().keys() + elif not check_complete: + return self.get_schemas().keys() + else: + return None + + def set_schema_names(self, schemanames): + for schemaname in schemanames: + self.add_schema(schemaname) + self.schemas_are_complete = True + + # tables + + def get_table(self, tablename, schemaname=None, create=False, + table_type='table'): + """Return cached table and optionally create it if it does not exist. + + + """ + cache = self._cache + schema = self.get_schema(schemaname, create=create) + if schema is None: + return None + if table_type == 'view': + table = schema['views'].get(tablename) + else: + table = schema['tables'].get(tablename) + if table is not None: + return table + elif create: + return self.add_table(tablename, schemaname, table_type=table_type) + return None + + def get_table_names(self, schemaname=None, check_complete=True, + table_type='table'): + """Return cached table names. + + By default, only return them if they're complete. + + """ + if table_type == 'view': + complete = self.views_are_complete + else: + complete = self.tables_are_complete + if check_complete and complete: + return self.get_tables(schemaname, table_type=table_type).keys() + elif not check_complete: + return self.get_tables(schemaname, table_type=table_type).keys() + else: + return None + + def add_table(self, tablename, schemaname=None, table_type='table'): + schema = self.get_schema(schemaname, create=True) + if table_type == 'table': + schema['tables'][tablename] = dict(columns={}) + else: + schema['views'][tablename] = dict(columns={}) + return self.get_table(tablename, schemaname, table_type=table_type) + + def set_table_names(self, tablenames, schemaname=None, table_type='table'): + for tablename in tablenames: + self.add_table(tablename, schemaname, table_type) + if table_type == 'view': + self.views_are_complete = True + else: + self.tables_are_complete = True + + # views + + def get_view(self, viewname, schemaname=None, create=False): + return self.get_table(viewname, schemaname, create, 'view') + + def get_view_names(self, schemaname=None, check_complete=True): + return self.get_table_names(schemaname, check_complete, 'view') + + def add_view(self, viewname, schemaname=None): + return self.add_table(viewname, schemaname, 'view') + + def set_view_names(self, viewnames, schemaname=None): + return self.set_table_names(viewnames, schemaname, 'view') + + def get_view_definition(self, viewname, schemaname=None): + view_cache = self.get_view(viewname, schemaname) + if view_cache and 'definition' in view_cache: + return view_cache['definition'] + + def set_view_definition(self, definition, viewname, schemaname=None): + view_cache = self.get_view(viewname, schemaname, create=True) + view_cache['definition'] = definition + + # table data + + def _get_table_data(self, key, tablename, schemaname=None): + table_cache = self.get_table(tablename, schemaname) + if table_cache is not None and key in table_cache.keys(): + return table_cache[key] + + def _set_table_data(self, key, data, tablename, schemaname=None): + """Cache data for schemaname.tablename using key. + + It will create a schema and table entry in the cache if needed. + + """ + table_cache = self.get_table(tablename, schemaname, create=True) + table_cache[key] = data + + # columns + + def get_columns(self, tablename, schemaname=None): + """Return columns list or None.""" + + return self._get_table_data('columns', tablename, schemaname) + + def set_columns(self, columns, tablename, schemaname=None): + """Add list of columns to table cache.""" + + return self._set_table_data('columns', columns, tablename, schemaname) + + # primary keys + + def get_primary_keys(self, tablename, schemaname=None): + """Return primary key list or None.""" + + return self._get_table_data('primary_keys', tablename, schemaname) + + def set_primary_keys(self, pkeys, tablename, schemaname=None): + """Add list of primary keys to table cache.""" + + return self._set_table_data('primary_keys', pkeys, tablename, schemaname) + + # foreign keys + + def get_foreign_keys(self, tablename, schemaname=None): + """Return foreign key list or None.""" + + return self._get_table_data('foreign_keys', tablename, schemaname) + + def set_foreign_keys(self, fkeys, tablename, schemaname=None): + """Add list of foreign keys to table cache.""" + + return self._set_table_data('foreign_keys', fkeys, tablename, schemaname) + + # indexes + + def get_indexes(self, tablename, schemaname=None): + """Return indexes list or None.""" + + return self._get_table_data('indexes', tablename, schemaname) + + def set_indexes(self, indexes, tablename, schemaname=None): + """Add list of indexes to table cache.""" + + return self._set_table_data('indexes', indexes, tablename, schemaname) + class Inspector(object): """performs database introspection @@ -129,7 +361,7 @@ class Inspector(object): col_defs = self.engine.dialect.get_columns(self.conn, tablename, schemaname, - self.info_cache) + info_cache=self.info_cache) for col_def in col_defs: # make this easy and only return instances for coltype coltype = col_def['type'] diff --git a/test/reflection.py b/test/reflection.py index 39240487c5..23e5befd31 100644 --- a/test/reflection.py +++ b/test/reflection.py @@ -70,7 +70,7 @@ def createViews(con, schema=None): if schema: fullname = "%s.%s" % (schema, tablename) view_name = fullname + '_v' - query = "CREATE OR REPLACE VIEW %s AS SELECT * FROM %s" % (view_name, + query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name, fullname) con.execute(sa.sql.text(query)) -- 2.47.3