]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
reflection methods not use decorator for caching
authorRandall Smith <randall@tnr.cc>
Wed, 18 Feb 2009 06:38:53 +0000 (06:38 +0000)
committerRandall Smith <randall@tnr.cc>
Wed, 18 Feb 2009 06:38:53 +0000 (06:38 +0000)
lib/sqlalchemy/dialects/information_schema.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/reflection.py
test/reflection.py

index b15082ac2ebb6afa0b4ee97553730f19dceb3d02..9a65cca4cd05cd0be92a47dfd4e5f83bd431ea01 100644 (file)
@@ -78,6 +78,14 @@ ref_constraints = Table("referential_constraints", ischema,
     Column("delete_rule", String),
     schema="information_schema")
 
+views = Table("views", ischema,
+    Column("table_catalog", String),
+    Column("table_schema", String),
+    Column("table_name", String),
+    Column("view_definition", String),
+    Column("check_option", String),
+    Column("is_updatable", String),
+    schema="information_schema")
 
 def table_names(connection, schema):
     s = select([tables.c.table_name], tables.c.table_schema==schema)
index 43fe4b5d57c0ff609f7cab8dd9bcd9ea3928d929..b45e7cd5aa4d01e5386034a77a8dc0820cce1223 100644 (file)
@@ -231,7 +231,7 @@ import datetime, decimal, inspect, operator, sys, re
 
 from sqlalchemy import sql, schema, exc, util
 from sqlalchemy.sql import compiler, expression, operators as sql_operators, functions as sql_functions
-from sqlalchemy.engine import default, base
+from sqlalchemy.engine import default, base, reflection
 from sqlalchemy import types as sqltypes
 from decimal import Decimal as _python_Decimal
 
@@ -1044,10 +1044,10 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
         return value
 
 
-class MSInfoCache(default.DefaultInfoCache):
+class MSInfoCache(reflection.DefaultInfoCache):
     
     def __init__(self, *args, **kwargs):
-        default.DefaultInfoCache.__init__(self, *args, **kwargs)
+        reflection.DefaultInfoCache.__init__(self, *args, **kwargs)
 
 
 class MSDialect(default.DefaultDialect):
@@ -1147,27 +1147,19 @@ class MSDialect(default.DefaultDialect):
         row  = c.fetchone()
         return row is not None
 
+    @reflection.caches 
     def get_schema_names(self, connection, info_cache=None):
-        if info_cache:
-            schema_names = info_cache.getSchemaNames()
-            if schema_names is not None:
-                return schema_names
-        import sqlalchemy.databases.information_schema as ischema
+        import sqlalchemy.dialects.information_schema as ischema
         s = sql.select([self.uppercase_table(ischema.schemata).c.schema_name],
             order_by=[ischema.schemata.c.schema_name]
         )
         schema_names = [r[0] for r in connection.execute(s)]
-        if info_cache:
-            info_cache.addAllSchemas(schema_names)
         return schema_names
 
+    @reflection.caches
     def get_table_names(self, connection, schemaname, info_cache=None):
-        import sqlalchemy.databases.information_schema as ischema
+        import sqlalchemy.dialects.information_schema as ischema
         current_schema = schemaname or self.get_default_schema_name(connection)
-        if info_cache:
-            table_names = info_cache.getTableNames(current_schema)
-            if table_names is not None:
-                return table_names
         tables = self.uppercase_table(ischema.tables)
         s = sql.select([tables.c.table_name],
             sql.and_(
@@ -1177,17 +1169,12 @@ class MSDialect(default.DefaultDialect):
             order_by=[tables.c.table_name]
         )
         table_names = [r[0] for r in connection.execute(s)]
-        if info_cache:
-            info_cache.addAllTables(table_names, current_schema)
         return table_names
 
+    @reflection.caches
     def get_view_names(self, connection, schemaname=None, info_cache=None):
-        import sqlalchemy.databases.information_schema as ischema
+        import sqlalchemy.dialects.information_schema as ischema
         current_schema = schemaname or self.get_default_schema_name(connection)
-        if info_cache:
-            view_names = info_cache.getViewNames(current_schema)
-            if view_names is not None:
-                return view_names
         tables = self.uppercase_table(ischema.tables)
         s = sql.select([tables.c.table_name],
             sql.and_(
@@ -1197,17 +1184,12 @@ class MSDialect(default.DefaultDialect):
             order_by=[tables.c.table_name]
         )
         view_names = [r[0] for r in connection.execute(s)]
-        if info_cache:
-            info_cache.addAllViews(view_names, schemaname)
         return view_names
 
+    @reflection.caches
     def get_indexes(self, connection, tablename, schemaname=None,
                                                             info_cache=None):
         current_schema = schemaname or self.get_default_schema_name(connection)
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, current_schema)
-            if table_cache and 'indexes' in table_cache:
-                return table_cache.get('indexes')
         full_tname = "%s.%s" % (current_schema, tablename)
         indexes = []
         s = sql.text("exec sp_helpindex '%s'" % full_tname)
@@ -1219,20 +1201,13 @@ class MSDialect(default.DefaultDialect):
                     'column_names' : row['index_keys'].split(','),
                     'unique': 'unique' in row['index_description']
                 })
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, current_schema,
-                                              create=True)
-            table_cache['indexes'] = indexes
         return indexes
 
+    @reflection.caches
     def get_view_definition(self, connection, viewname, schemaname=None,
                             info_cache=None):
-        import sqlalchemy.databases.information_schema as ischema
+        import sqlalchemy.dialects.information_schema as ischema
         current_schema = schemaname or self.get_default_schema_name(connection)
-        if info_cache:
-            view_cache = info_cache.getView(viewname, current_schema)
-            if view_cache and 'definition' in view_cache.keys():
-                return view_cache.get('definition')
         views = self.uppercase_table(ischema.views)
         s = sql.select([views.c.view_definition],
             sql.and_(
@@ -1243,20 +1218,13 @@ class MSDialect(default.DefaultDialect):
         rp = connection.execute(s)
         if rp:
             view_def = rp.scalar()
-            if info_cache:
-                view_cache = info_cache.getView(viewname, current_schema,
-                                                create=True)
-                view_cache['definition'] = view_def
             return view_def
 
+    @reflection.caches
     def get_columns(self, connection, tablename, schemaname=None,
                                                             info_cache=None):
         # Get base columns
         current_schema = schemaname or self.get_default_schema_name(connection)
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, current_schema)
-            if table_cache and 'columns' in table_cache.keys():
-                return table_cache.get('columns')
         import sqlalchemy.dialects.information_schema as ischema
         columns = self.uppercase_table(ischema.columns)
         s = sql.select([columns],
@@ -1311,20 +1279,13 @@ class MSDialect(default.DefaultDialect):
                 'attrs' : colargs
             }
             cols.append(cdict)
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, current_schema,
-                                              create=True)
-            table_cache['columns'] = cols
         return cols
 
+    @reflection.caches
     def get_primary_keys(self, connection, tablename, schemaname=None,
                                                             info_cache=None):
         import sqlalchemy.dialects.information_schema as ischema
         current_schema = schemaname or self.get_default_schema_name(connection)
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname)
-            if table_cache and 'primary_keys' in table_cache.keys():
-                return table_cache.get('primary_keys')
         pkeys = []
         # Add constraints
         RR = self.uppercase_table(ischema.ref_constraints)    #information_schema.referential_constraints
@@ -1342,20 +1303,13 @@ class MSDialect(default.DefaultDialect):
         for row in c:
             if 'PRIMARY' in row[TC.c.constraint_type.name]:
                 pkeys.append(row[0])
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, current_schema,
-                                              create=True)
-            table_cache['primary_keys'] = pkeys
         return pkeys
 
+    @reflection.caches
     def get_foreign_keys(self, connection, tablename, schemaname=None,
                                                             info_cache=None):
         import sqlalchemy.dialects.information_schema as ischema
         current_schema = schemaname or self.get_default_schema_name(connection)
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname)
-            if table_cache and 'foreign_keys' in table_cache.keys():
-                return table_cache.get('foreign_keys')
         # Add constraints
         RR = self.uppercase_table(ischema.ref_constraints)    #information_schema.referential_constraints
         TC = self.uppercase_table(ischema.constraints)        #information_schema.table_constraints
@@ -1392,8 +1346,8 @@ class MSDialect(default.DefaultDialect):
                 fknm, scols, rcols = (rfknm, [], [])
             if not scol in scols:
                 scols.append(scol)
-            if not (rschema, rtbl, rcol) in rcols:
-                rcols.append((rschema, rtbl, rcol))
+            if not rcol in rcols:
+                rcols.append(rcol)
         if fknm and scols:
             fkeys.append({
                 'name' : fknm,
@@ -1402,9 +1356,6 @@ class MSDialect(default.DefaultDialect):
                 'referred_table' : rtbl,
                 'referred_columns' : rcols
             })
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, current_schema)
-            table_cache['foreign_keys'] = fkeys
         return fkeys
 
     def reflecttable(self, connection, table, include_columns):
@@ -1489,4 +1440,8 @@ class MSDialect(default.DefaultDialect):
             else:
                 schema.Table(rtbl, table.metadata, schema=rschema,
                              autoload=True, autoload_with=connection)
-            table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
+            ##table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
+            table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, rschema, rtbl, c) for c in rcols], fknm, link_to_name=True))
+
+# fixme.  I added this for the tests to run. -Randall
+MSSQLDialect = MSDialect
index de9d14265ba28f140e539b2d94cb7ebb15096443..6fc87e1adb4b2c5887ee3d7a17c3b5517bcfd294 100644 (file)
@@ -75,7 +75,7 @@ is not in use this flag should be left off.
 import datetime, random, re
 
 from sqlalchemy import util, sql, schema, log
-from sqlalchemy.engine import default, base
+from sqlalchemy.engine import default, base, reflection
 from sqlalchemy.sql import compiler, visitors, expression
 from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
 from sqlalchemy import types as sqltypes
@@ -447,7 +447,7 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer):
         name = re.sub(r'^_+', '', savepoint.ident)
         return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
         
-class OracleInfoCache(default.DefaultInfoCache):
+class OracleInfoCache(reflection.DefaultInfoCache):
     pass
 
 class OracleDialect(default.DefaultDialect):
@@ -583,15 +583,18 @@ class OracleDialect(default.DefaultDialect):
             owner = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
         return (actual_name, owner, dblink, synonym)
 
+    @reflection.caches
     def get_schema_names(self, connection, info_cache=None):
         s = "SELECT username FROM all_users ORDER BY username"
         cursor = connection.execute(s,)
         return [self._normalize_name(row[0]) for row in cursor]
 
+    @reflection.caches
     def get_table_names(self, connection, schemaname=None, info_cache=None):
         schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
         return self.table_names(connection, schemaname)
 
+    @reflection.caches
     def get_view_names(self, connection, schemaname=None, info_cache=None):
         schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
         s = "select view_name from all_views where OWNER = :owner"
@@ -599,6 +602,7 @@ class OracleDialect(default.DefaultDialect):
                 {'owner':self._denormalize_name(schemaname)})
         return [self._normalize_name(row[0]) for row in cursor]
 
+    @reflection.caches
     def get_columns(self, connection, tablename, schemaname=None,
                     info_cache=None, resolve_synonyms=False, dblink=''):
 
@@ -606,10 +610,6 @@ class OracleDialect(default.DefaultDialect):
         (tablename, schemaname, dblink, synonym) = \
             self._prepare_reflection_args(connection, tablename, schemaname,
                                           resolve_synonyms, dblink)
-        if info_cache:
-            columns = info_cache.getColumns(tablename, schemaname)
-            if columns:
-                return columns
         columns = []
         c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':tablename, 'owner':schemaname})
 
@@ -654,10 +654,9 @@ class OracleDialect(default.DefaultDialect):
                 'attrs': colargs
             }
             columns.append(cdict)
-        if info_cache:
-            info_cache.setColumns(columns, tablename, schemaname)
         return columns
 
+    @reflection.caches
     def get_indexes(self, connection, tablename, schemaname=None,
                     info_cache=None, resolve_synonyms=False, dblink=''):
 
@@ -665,10 +664,6 @@ class OracleDialect(default.DefaultDialect):
         (tablename, schemaname, dblink, synonym) = \
             self._prepare_reflection_args(connection, tablename, schemaname,
                                           resolve_synonyms, dblink)
-        if info_cache:
-            indexes = info_cache.getIndexes(tablename, schemaname)
-            if indexes:
-                return indexes
         indexes = []
         q = """
         SELECT a.INDEX_NAME, a.COLUMN_NAME, b.UNIQUENESS
@@ -699,17 +694,11 @@ class OracleDialect(default.DefaultDialect):
             index['unique'] = uniqueness.get(rset.uniqueness, False)
             index['column_names'].append(rset.column_name)
             last_index_name = rset.index_name
-        if info_cache:
-            info_cache.setIndexes(indexes, tablename, schemaname)
         return indexes
 
     def _get_constraint_data(self, connection, tablename, schemaname=None,
                              info_cache=None, dblink=''):
 
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname)
-            if table_cache and ['constraints'] in table_cache.keys():
-                return table_cache['constraints']
         rp = connection.execute("""SELECT
              ac.constraint_name,
              ac.constraint_type,
@@ -731,21 +720,14 @@ class OracleDialect(default.DefaultDialect):
            ORDER BY ac.constraint_name, loc.position, rem.position"""
          % {'dblink':dblink}, {'table_name' : tablename, 'owner' : schemaname})
         constraint_data = rp.fetchall()
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname,
-                                              create=True)
-            table_cache['constraints'] = constraint_data
         return constraint_data
 
+    @reflection.caches
     def get_primary_keys(self, connection, tablename, schemaname=None,
                          info_cache=None, resolve_synonyms=False, dblink=''):
         (tablename, schemaname, dblink, synonym) = \
             self._prepare_reflection_args(connection, tablename, schemaname,
                                           resolve_synonyms, dblink)
-        if info_cache:
-            pkeys = info_cache.getPrimaryKeys(tablename, schemaname)
-            if pkeys is not None:
-                return pkeys
         pkeys = []
         constraint_data = self._get_constraint_data(connection, tablename,
                                         schemaname, info_cache, dblink)
@@ -754,19 +736,14 @@ class OracleDialect(default.DefaultDialect):
             (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
             if cons_type == 'P':
                 pkeys.append(local_column)
-        if info_cache:
-            info_cache.setPrimaryKeys(pkeys, tablename, schemaname)
         return pkeys
 
+    @reflection.caches
     def get_foreign_keys(self, connection, tablename, schemaname=None,
                          info_cache=None, resolve_synonyms=False, dblink=''):
         (tablename, schemaname, dblink, synonym) = \
             self._prepare_reflection_args(connection, tablename, schemaname,
                                           resolve_synonyms, dblink)
-        if info_cache:
-            fkeys = info_cache.getForeignKeys(tablename, schemaname)
-            if fkeys is not None:
-                return fkeys
 
         constraint_data = self._get_constraint_data(connection, tablename,
                                                 schemaname, info_cache, dblink)
@@ -807,19 +784,14 @@ class OracleDialect(default.DefaultDialect):
                     'referred_columns' : value[1]
                 }
                 fkeys.append(fkey_d)
-        if info_cache:
-            info_cache.setForeignKeys(fkeys, tablename, schemaname)
         return fkeys
 
+    @reflection.caches
     def get_view_definition(self, connection, viewname, schemaname=None,
                             info_cache=None, resolve_synonyms=False, dblink=''):
         (viewname, schemaname, dblink, synonym) = \
             self._prepare_reflection_args(connection, viewname, schemaname,
                                           resolve_synonyms, dblink)
-        if info_cache:
-            view_cache = info_cache.getView(viewname, schemaname)
-            if view_cache and 'definition' in view_cache:
-                return view_cache['definition']
         s = """
         SELECT text FROM all_views
         WHERE owner = :schemaname
@@ -829,10 +801,6 @@ class OracleDialect(default.DefaultDialect):
                                 viewname=viewname, schemaname=schemaname)
         if rp:
             view_def = rp.scalar().decode(self.encoding)
-            if info_cache:
-                view = info_cache.getView(viewname, schemaname,
-                                          create=True)
-                view['definition'] = view_def
             return view_def
 
     def reflecttable(self, connection, table, include_columns):
index 705778cc5b16576cdba08b6f3b20dc60543d1493..d031e30ae8ff02afdbb96f45a4460bc157f49c21 100644 (file)
@@ -67,7 +67,7 @@ option to the Index constructor::
 import re
 
 from sqlalchemy import sql, schema, exc, util
-from sqlalchemy.engine import base, default
+from sqlalchemy.engine import base, default, reflection
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import types as sqltypes
@@ -397,19 +397,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
             value = value[1:-1].replace('""','"')
         return value
 
-class PGInfoCache(default.DefaultInfoCache):
+class PGInfoCache(reflection.DefaultInfoCache):
     
     def __init__(self):
 
-        default.DefaultInfoCache.__init__(self)
+        reflection.DefaultInfoCache.__init__(self)
 
-    def getTableOID(self, tablename, schemaname=None):
-        table = self.getTable(tablename, schemaname)
+    def get_table_oid(self, tablename, schemaname=None):
+        table = self.get_table(tablename, schemaname)
         if table:
             return table.get('oid')
 
-    def setTableOID(self, oid, tablename, schemaname=None):
-        table = self.getTable(tablename, schemaname, create=True)
+    def set_table_oid(self, oid, tablename, schemaname=None):
+        table = self.get_table(tablename, schemaname, create=True)
         table['oid'] = oid
 
 class PGDialect(default.DefaultDialect):
@@ -526,7 +526,7 @@ class PGDialect(default.DefaultDialect):
         """
         table_oid = None
         if info_cache:
-            table_oid = info_cache.getTableOID(tablename, schemaname)
+            table_oid = info_cache.get_table_oid(tablename, schemaname)
         if table_oid:
             return table_oid
         if schemaname is not None:
@@ -554,17 +554,14 @@ class PGDialect(default.DefaultDialect):
         c = connection.execute(s, table_name=tablename, schema=schemaname)
         table_oid = c.scalar()
         if table_oid is None:
-            raise exc.NoSuchTableError(table_name)
+            raise exc.NoSuchTableError(tablename)
         # cache it
         if info_cache:
-            info_cache.setTableOID(table_oid, tablename, schemaname)
+            info_cache.set_table_oid(table_oid, tablename, schemaname)
         return table_oid
 
+    @reflection.caches
     def get_schema_names(self, connection, info_cache=None):
-        if info_cache:
-            schema_names = info_cache.getSchemaNames()
-            if schema_names is not None:
-                return schema_names
         s = """
         SELECT nspname
         FROM pg_namespace
@@ -574,33 +571,23 @@ class PGDialect(default.DefaultDialect):
         # what about system tables?
         schema_names = [row[0].decode(self.encoding) for row in rp \
                         if not row[0].startswith('pg_')]
-        if info_cache:
-            info_cache.addAllSchemas(schema_names)
         return schema_names
 
+    @reflection.caches
     def get_table_names(self, connection, schemaname=None, info_cache=None):
         if schemaname is not None:
             current_schema = schemaname
         else:
             current_schema = self.get_default_schema_name(connection)
-        if info_cache:
-            table_names = info_cache.getTableNames(current_schema)
-            if table_names is not None:
-                return table_names
         table_names = self.table_names(connection, current_schema)
-        if info_cache:
-            info_cache.addAllTables(table_names, current_schema)
         return table_names
 
+    @reflection.caches
     def get_view_names(self, connection, schemaname=None, info_cache=None):
         if schemaname is not None:
             current_schema = schemaname
         else:
             current_schema = self.get_default_schema_name(connection)
-        if info_cache:
-            view_names = info_cache.getViewNames(current_schema)
-            if view_names is not None:
-                return view_names
         s = """
         SELECT relname
         FROM pg_class c
@@ -608,20 +595,15 @@ class PGDialect(default.DefaultDialect):
           AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
         """ % dict(schema=current_schema)
         view_names = [row[0].decode(self.encoding) for row in connection.execute(s)]
-        if info_cache:
-            info_cache.addAllViews(view_names, schemaname)
         return view_names
 
+    @reflection.caches
     def get_view_definition(self, connection, viewname, schemaname=None,
                                                             info_cache=None):
         if schemaname is not None:
             current_schema = schemaname
         else:
             current_schema = self.get_default_schema_name(connection)
-        if info_cache:
-            view_cache = info_cache.getView(viewname, current_schema)
-            if view_cache and 'definition' in view_cache:
-                return view_cache['definition']
         s = """
         SELECT definition FROM pg_views
         WHERE schemaname = :schemaname
@@ -631,18 +613,12 @@ class PGDialect(default.DefaultDialect):
                                 viewname=viewname, schemaname=current_schema)
         if rp:
             view_def = rp.scalar().decode(self.encoding)
-            if info_cache:
-                view = info_cache.getView(viewname, current_schema,
-                                          create=True)
-                view['definition'] = view_def
             return view_def
 
+    @reflection.caches
     def get_columns(self, connection, tablename, schemaname=None,
                     info_cache=None):
-        if info_cache:
-            columns = info_cache.getColumns(tablename, schemaname)
-            if columns is not None:
-                return columns
+
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
         SQL_COLS = """
@@ -726,16 +702,11 @@ class PGDialect(default.DefaultDialect):
             column_info = dict(name=name, type=coltype, nullable=nullable,
                                default=default, colargs=colargs)
             columns.append(column_info)
-        if info_cache:
-            info_cache.setColumns(columns, tablename, schemaname)
         return columns
 
+    @reflection.caches
     def get_primary_keys(self, connection, tablename, schemaname=None,
                          info_cache=None):
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname)
-            if table_cache and 'primary_keys' in table_cache.keys():
-                return table_cache.get('primary_keys')
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
         PK_SQL = """
@@ -749,18 +720,11 @@ class PGDialect(default.DefaultDialect):
         t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
         c = connection.execute(t, table_oid=table_oid)
         primary_keys = [r[0] for r in c.fetchall()]
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname,
-                                              create=True)
-            table_cache['primary_keys'] = primary_keys
         return primary_keys
 
+    @reflection.caches
     def get_foreign_keys(self, connection, tablename, schemaname=None,
                          info_cache=None):
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname)
-            if table_cache and 'foreign_keys' in table_cache.keys():
-                return table_cache.get('foreign_keys')
         preparer = self.identifier_preparer
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
@@ -795,17 +759,10 @@ class PGDialect(default.DefaultDialect):
                 'referred_columns' : referred_columns
             }
             fkeys.append(fkey_d)
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname,
-                                              create=True)
-            table_cache['foreign_keys'] = fkeys
         return fkeys
 
+    @reflection.caches
     def get_indexes(self, connection, tablename, schemaname, info_cache=None):
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname)
-            if table_cache and 'indexes' in table_cache.keys():
-                return table_cache.get('indexes')
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
         IDX_SQL = """
@@ -844,10 +801,6 @@ class PGDialect(default.DefaultDialect):
             index_d['name'] = idx_name
             index_d['column_names'].append(col)
             index_d['unique'] = unique
-        if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname,
-                                              create=True)
-            table_cache['indexes'] = indexes
         return indexes
 
     def reflecttable(self, connection, table, include_columns):
index b50411c0cff2c369faec02fde4f925a409d3d17e..b719219a5df794c2c20b8af18d4c0bb5a047d7b7 100644 (file)
@@ -20,182 +20,6 @@ from sqlalchemy import exc, types as sqltypes
 AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
                                re.I | re.UNICODE)
 
-class DefaultInfoCache(object):
-    """Default implementation of InfoCache
-
-    InfoCache provides a means for dialects to cache information obtained for
-    reflection and a convenient interface for setting and retrieving cached
-    data.
-
-    """
-    
-    def __init__(self):
-        self._cache = dict(schemas={})
-        self.tables_are_complete = False
-        self.schemas_are_complete = False
-        self.views_are_complete = False
-
-    def clear(self):
-        """Clear the cache."""
-        self._cache = dict(schemas={})
-
-    def getSchemas(self):
-        """Return the schemas dict."""
-        return self._cache.get('schemas')
-
-    def getSchemaNames(self, check_complete=True):
-        """Return cached schema names.
-
-        By default, only return them if they're complete.
-
-        """
-        if check_complete and self.schemas_are_complete:
-            return self.getSchemas().keys()
-        elif not check_complete:
-            return self.getSchemas().keys()
-        else:
-            return None
-
-    def getSchema(self, schemaname, create=False):
-        """Return cached schema and optionally create it if it does not exist.
-
-        """
-        schema = self._cache['schemas'].get(schemaname)
-        if schema is not None:
-            return schema
-        elif create:
-            return self.addSchema(schemaname)
-        return None
-
-    def addSchema(self, schemaname):
-        self._cache['schemas'][schemaname] = dict(tables={}, views={})
-        return self.getSchema(schemaname)
-
-    def addAllSchemas(self, schemanames):
-        for schemaname in schemanames:
-            self.addSchema(schemaname)
-        self.schemas_are_complete = True
-
-    def getTable(self, tablename, schemaname=None, create=False,
-                                                        table_type='table'):
-        """Return cached table and optionally create it if it does not exist.
-
-
-        """
-        cache = self._cache
-        schema = self.getSchema(schemaname, create=create)
-        if schema is None:
-            return None
-        if table_type == 'view':
-            table = schema['views'].get(tablename)
-        else:
-            table = schema['tables'].get(tablename)
-        if table is not None:
-            return table
-        elif create:
-            return self.addTable(tablename, schemaname, table_type=table_type)
-        return None
-
-    def getTableNames(self, schemaname=None, check_complete=True,
-                                                        table_type='table'):
-        """Return cached table names.
-
-        By default, only return them if they're complete.
-
-        """
-        if table_type == 'view':
-            complete = self.views_are_complete
-        else:
-            complete = self.tables_are_complete
-        if check_complete and complete:
-            return self.getTables(schemaname, table_type=table_type).keys()
-        elif not check_complete:
-            return self.getTables(schemaname, table_type=table_type).keys()
-        else:
-            return None
-
-    def addTable(self, tablename, schemaname=None, table_type='table'):
-        schema = self.getSchema(schemaname, create=True)
-        if table_type == 'table':
-            schema['tables'][tablename] = dict(columns={})
-        else:
-            schema['views'][tablename] = dict(columns={})
-        return self.getTable(tablename, schemaname, table_type=table_type)
-
-    def addAllTables(self, tablenames, schemaname=None, table_type='table'):
-        for tablename in tablenames:
-            self.addTable(tablename, schemaname, table_type)
-        if table_type == 'view':
-            self.views_are_complete = True
-        else:
-            self.tables_are_complete = True
-            
-    def getView(self, viewname, schemaname=None, create=False):
-        return self.getTable(viewname, schemaname, create, 'view')
-
-    def getViewNames(self, schemaname=None, check_complete=True):
-        return self.getTableNames(schemaname, check_complete, 'view')
-
-    def addView(self, viewname, schemaname=None):
-        return self.addTable(viewname, schemaname, 'view')
-
-    def addAllViews(self, viewnames, schemaname=None):
-        return self.addAllTables(viewnames, schemaname, 'view')
-
-    def _getTableData(self, key, tablename, schemaname=None):
-        table_cache = self.getTable(tablename, schemaname)
-        if table_cache is not None and key in table_cache.keys():
-            return table_cache[key]
-
-    def _setTableData(self, key, data, tablename, schemaname=None):
-        """Cache data for schemaname.tablename using key.
-
-        It will create a schema and table entry in the cache if needed.
-
-        """
-        table_cache = self.getTable(tablename, schemaname, create=True)
-        table_cache[key] = data
-
-    def getColumns(self, tablename, schemaname=None):
-        """Return columns list or None."""
-        
-        return self._getTableData('columns', tablename, schemaname)
-
-    def setColumns(self, columns, tablename, schemaname=None):
-        """Add list of columns to table cache."""
-
-        return self._setTableData('columns', columns, tablename, schemaname)
-
-    def getPrimaryKeys(self, tablename, schemaname=None):
-        """Return primary key list or None."""
-        
-        return self._getTableData('primary_keys', tablename, schemaname)
-
-    def setPrimaryKeys(self, pkeys, tablename, schemaname=None):
-        """Add list of primary keys to table cache."""
-
-        return self._setTableData('primary_keys', pkeys, tablename, schemaname)
-
-    def getForeignKeys(self, tablename, schemaname=None):
-        """Return foreign key list or None."""
-        
-        return self._getTableData('foreign_keys', tablename, schemaname)
-
-    def setForeignKeys(self, fkeys, tablename, schemaname=None):
-        """Add list of foreign keys to table cache."""
-
-        return self._setTableData('foreign_keys', fkeys, tablename, schemaname)
-
-    def getIndexes(self, tablename, schemaname=None):
-        """Return indexes list or None."""
-        
-        return self._getTableData('indexes', tablename, schemaname)
-
-    def setIndexes(self, indexes, tablename, schemaname=None):
-        """Add list of indexes to table cache."""
-
-        return self._setTableData('indexes', indexes, tablename, schemaname)
-
 class DefaultDialect(base.Dialect):
     """Default implementation of Dialect"""
 
index 677cafee9585aea17b1a49659c0e45223a7350db..7f8143d6006443aaeb044a959d39de2fc4604bdf 100644 (file)
@@ -18,8 +18,240 @@ I'm still trying to decide upon conventions for both the Inspector interface as
 
 """
 import sqlalchemy
+from sqlalchemy import util
 from sqlalchemy.types import TypeEngine
 
+
+@util.decorator
+def caches(fn, self, con, *args, **kw):
+    # what are we caching?
+    fn_name = fn.__name__
+    if not fn_name.startswith('get_'):
+        # don't recognize this.
+        return fn(self, con, *args, **kw)
+    else:
+        attr_to_cache = fn_name[4:]
+    # The first arguments will always be self and con.
+    # Assuming *args and *kw will be acceptable to info_cache method.
+    if 'info_cache' in kw:
+        kw_cp = kw.copy()
+        info_cache = kw_cp.pop('info_cache')
+        methodname = "%s_%s" % ('get', attr_to_cache)
+        # fixme.
+        for bad_kw in ('dblink', 'resolve_synonyms'):
+            if bad_kw in kw_cp:
+                del kw_cp[bad_kw]
+        information = getattr(info_cache, methodname)(*args, **kw_cp)
+        if information:
+            return information
+    information = fn(self, con, *args, **kw)
+    if 'info_cache' in locals():
+        methodname = "%s_%s" % ('set', attr_to_cache)
+        getattr(info_cache, methodname)(information, *args, **kw_cp)
+    return information 
+
+class DefaultInfoCache(object):
+    """Default implementation of InfoCache
+
+    InfoCache provides a means for dialects to cache information obtained for
+    reflection and a convenient interface for setting and retrieving cached
+    data.
+
+    """
+    
+    def __init__(self):
+        self._cache = dict(schemas={})
+        self.tables_are_complete = False
+        self.schemas_are_complete = False
+        self.views_are_complete = False
+
+    def clear(self):
+        """Clear the cache."""
+        self._cache = dict(schemas={})
+
+    # schemas
+
+    def get_schemas(self):
+        """Return the schemas dict."""
+        return self._cache.get('schemas')
+
+
+    def get_schema(self, schemaname, create=False):
+        """Return cached schema and optionally create it if it does not exist.
+
+        """
+        schema = self._cache['schemas'].get(schemaname)
+        if schema is not None:
+            return schema
+        elif create:
+            return self.add_schema(schemaname)
+        return None
+
+    def add_schema(self, schemaname):
+        self._cache['schemas'][schemaname] = dict(tables={}, views={})
+        return self.get_schema(schemaname)
+
+    def get_schema_names(self, check_complete=True):
+        """Return cached schema names.
+
+        By default, only return them if they're complete.
+
+        """
+        if check_complete and self.schemas_are_complete:
+            return self.get_schemas().keys()
+        elif not check_complete:
+            return self.get_schemas().keys()
+        else:
+            return None
+
+    def set_schema_names(self, schemanames):
+        for schemaname in schemanames:
+            self.add_schema(schemaname)
+        self.schemas_are_complete = True
+
+    # tables
+
+    def get_table(self, tablename, schemaname=None, create=False,
+                                                        table_type='table'):
+        """Return cached table and optionally create it if it does not exist.
+
+
+        """
+        cache = self._cache
+        schema = self.get_schema(schemaname, create=create)
+        if schema is None:
+            return None
+        if table_type == 'view':
+            table = schema['views'].get(tablename)
+        else:
+            table = schema['tables'].get(tablename)
+        if table is not None:
+            return table
+        elif create:
+            return self.add_table(tablename, schemaname, table_type=table_type)
+        return None
+
+    def get_table_names(self, schemaname=None, check_complete=True,
+                                                        table_type='table'):
+        """Return cached table names.
+
+        By default, only return them if they're complete.
+
+        """
+        if table_type == 'view':
+            complete = self.views_are_complete
+        else:
+            complete = self.tables_are_complete
+        if check_complete and complete:
+            return self.get_tables(schemaname, table_type=table_type).keys()
+        elif not check_complete:
+            return self.get_tables(schemaname, table_type=table_type).keys()
+        else:
+            return None
+
+    def add_table(self, tablename, schemaname=None, table_type='table'):
+        schema = self.get_schema(schemaname, create=True)
+        if table_type == 'table':
+            schema['tables'][tablename] = dict(columns={})
+        else:
+            schema['views'][tablename] = dict(columns={})
+        return self.get_table(tablename, schemaname, table_type=table_type)
+
+    def set_table_names(self, tablenames, schemaname=None, table_type='table'):
+        for tablename in tablenames:
+            self.add_table(tablename, schemaname, table_type)
+        if table_type == 'view':
+            self.views_are_complete = True
+        else:
+            self.tables_are_complete = True
+            
+    # views
+
+    def get_view(self, viewname, schemaname=None, create=False):
+        return self.get_table(viewname, schemaname, create, 'view')
+
+    def get_view_names(self, schemaname=None, check_complete=True):
+        return self.get_table_names(schemaname, check_complete, 'view')
+
+    def add_view(self, viewname, schemaname=None):
+        return self.add_table(viewname, schemaname, 'view')
+
+    def set_view_names(self, viewnames, schemaname=None):
+        return self.set_table_names(viewnames, schemaname, 'view')
+
+    def get_view_definition(self, viewname, schemaname=None):
+        view_cache = self.get_view(viewname, schemaname)
+        if view_cache and 'definition' in view_cache:
+            return view_cache['definition']
+
+    def set_view_definition(self, definition, viewname, schemaname=None):
+        view_cache = self.get_view(viewname, schemaname, create=True)
+        view_cache['definition'] = definition
+
+    # table data
+
+    def _get_table_data(self, key, tablename, schemaname=None):
+        table_cache = self.get_table(tablename, schemaname)
+        if table_cache is not None and key in table_cache.keys():
+            return table_cache[key]
+
+    def _set_table_data(self, key, data, tablename, schemaname=None):
+        """Cache data for schemaname.tablename using key.
+
+        It will create a schema and table entry in the cache if needed.
+
+        """
+        table_cache = self.get_table(tablename, schemaname, create=True)
+        table_cache[key] = data
+
+    # columns
+
+    def get_columns(self, tablename, schemaname=None):
+        """Return columns list or None."""
+        
+        return self._get_table_data('columns', tablename, schemaname)
+
+    def set_columns(self, columns, tablename, schemaname=None):
+        """Add list of columns to table cache."""
+
+        return self._set_table_data('columns', columns, tablename, schemaname)
+
+    # primary keys
+
+    def get_primary_keys(self, tablename, schemaname=None):
+        """Return primary key list or None."""
+        
+        return self._get_table_data('primary_keys', tablename, schemaname)
+
+    def set_primary_keys(self, pkeys, tablename, schemaname=None):
+        """Add list of primary keys to table cache."""
+
+        return self._set_table_data('primary_keys', pkeys, tablename, schemaname)
+
+    # foreign keys
+
+    def get_foreign_keys(self, tablename, schemaname=None):
+        """Return foreign key list or None."""
+        
+        return self._get_table_data('foreign_keys', tablename, schemaname)
+
+    def set_foreign_keys(self, fkeys, tablename, schemaname=None):
+        """Add list of foreign keys to table cache."""
+
+        return self._set_table_data('foreign_keys', fkeys, tablename, schemaname)
+
+    # indexes
+
+    def get_indexes(self, tablename, schemaname=None):
+        """Return indexes list or None."""
+        
+        return self._get_table_data('indexes', tablename, schemaname)
+
+    def set_indexes(self, indexes, tablename, schemaname=None):
+        """Add list of indexes to table cache."""
+
+        return self._set_table_data('indexes', indexes, tablename, schemaname)
+
 class Inspector(object):
     """performs database introspection
 
@@ -129,7 +361,7 @@ class Inspector(object):
 
         col_defs = self.engine.dialect.get_columns(self.conn, tablename,
                                                    schemaname,
-                                                   self.info_cache)
+                                                   info_cache=self.info_cache)
         for col_def in col_defs:
             # make this easy and only return instances for coltype
             coltype = col_def['type']
index 39240487c552f0698fb69111f8e58e344bece861..23e5befd31390286c98b8435302dffcfa86b0ab9 100644 (file)
@@ -70,7 +70,7 @@ def createViews(con, schema=None):
         if schema:
             fullname = "%s.%s" % (schema, tablename)
         view_name = fullname + '_v'
-        query = "CREATE OR REPLACE VIEW %s AS SELECT * FROM %s" % (view_name,
+        query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name,
                                                                    fullname)
         con.execute(sa.sql.text(query))