]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dialects can subclass Inspector
authorRandall Smith <randall@tnr.cc>
Mon, 2 Mar 2009 05:46:16 +0000 (05:46 +0000)
committerRandall Smith <randall@tnr.cc>
Mon, 2 Mar 2009 05:46:16 +0000 (05:46 +0000)
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/reflection.py
test/reflection.py

index b45e7cd5aa4d01e5386034a77a8dc0820cce1223..f6bb9ad85032c233d12ce22ecac946dce9d8f929 100644 (file)
@@ -1044,12 +1044,6 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
         return value
 
 
-class MSInfoCache(reflection.DefaultInfoCache):
-    
-    def __init__(self, *args, **kwargs):
-        reflection.DefaultInfoCache.__init__(self, *args, **kwargs)
-
-
 class MSDialect(default.DefaultDialect):
     name = 'mssql'
     supports_default_values = True
@@ -1071,7 +1065,6 @@ class MSDialect(default.DefaultDialect):
     ddl_compiler = MSDDLCompiler
     type_compiler = MSTypeCompiler
     preparer = MSIdentifierPreparer
-    info_cache = MSInfoCache
 
     def __init__(self,
                  auto_identity_insert=True, query_timeout=None,
@@ -1147,7 +1140,7 @@ class MSDialect(default.DefaultDialect):
         row  = c.fetchone()
         return row is not None
 
-    @reflection.cache
+    @reflection.cache
     def get_schema_names(self, connection, info_cache=None):
         import sqlalchemy.dialects.information_schema as ischema
         s = sql.select([self.uppercase_table(ischema.schemata).c.schema_name],
@@ -1156,7 +1149,7 @@ class MSDialect(default.DefaultDialect):
         schema_names = [r[0] for r in connection.execute(s)]
         return schema_names
 
-    @reflection.caches
+    @reflection.cache
     def get_table_names(self, connection, schemaname, info_cache=None):
         import sqlalchemy.dialects.information_schema as ischema
         current_schema = schemaname or self.get_default_schema_name(connection)
@@ -1171,7 +1164,7 @@ class MSDialect(default.DefaultDialect):
         table_names = [r[0] for r in connection.execute(s)]
         return table_names
 
-    @reflection.caches
+    @reflection.cache
     def get_view_names(self, connection, schemaname=None, info_cache=None):
         import sqlalchemy.dialects.information_schema as ischema
         current_schema = schemaname or self.get_default_schema_name(connection)
@@ -1186,7 +1179,7 @@ class MSDialect(default.DefaultDialect):
         view_names = [r[0] for r in connection.execute(s)]
         return view_names
 
-    @reflection.caches
+    @reflection.cache
     def get_indexes(self, connection, tablename, schemaname=None,
                                                             info_cache=None):
         current_schema = schemaname or self.get_default_schema_name(connection)
@@ -1203,7 +1196,7 @@ class MSDialect(default.DefaultDialect):
                 })
         return indexes
 
-    @reflection.caches
+    @reflection.cache
     def get_view_definition(self, connection, viewname, schemaname=None,
                             info_cache=None):
         import sqlalchemy.dialects.information_schema as ischema
@@ -1220,7 +1213,7 @@ class MSDialect(default.DefaultDialect):
             view_def = rp.scalar()
             return view_def
 
-    @reflection.caches
+    @reflection.cache
     def get_columns(self, connection, tablename, schemaname=None,
                                                             info_cache=None):
         # Get base columns
@@ -1281,7 +1274,7 @@ class MSDialect(default.DefaultDialect):
             cols.append(cdict)
         return cols
 
-    @reflection.caches
+    @reflection.cache
     def get_primary_keys(self, connection, tablename, schemaname=None,
                                                             info_cache=None):
         import sqlalchemy.dialects.information_schema as ischema
@@ -1305,7 +1298,7 @@ class MSDialect(default.DefaultDialect):
                 pkeys.append(row[0])
         return pkeys
 
-    @reflection.caches
+    @reflection.cache
     def get_foreign_keys(self, connection, tablename, schemaname=None,
                                                             info_cache=None):
         import sqlalchemy.dialects.information_schema as ischema
index 6fc87e1adb4b2c5887ee3d7a17c3b5517bcfd294..2c4f326e8ba9dfeab517db8942cbcb559f1ca086 100644 (file)
@@ -74,7 +74,8 @@ is not in use this flag should be left off.
 
 import datetime, random, re
 
-from sqlalchemy import util, sql, schema, log
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import util, sql, log
 from sqlalchemy.engine import default, base, reflection
 from sqlalchemy.sql import compiler, visitors, expression
 from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
@@ -447,9 +448,6 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer):
         name = re.sub(r'^_+', '', savepoint.ident)
         return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
         
-class OracleInfoCache(reflection.DefaultInfoCache):
-    pass
-
 class OracleDialect(default.DefaultDialect):
     name = 'oracle'
     supports_alter = True
@@ -474,7 +472,6 @@ class OracleDialect(default.DefaultDialect):
     type_compiler = OracleTypeCompiler
     preparer = OracleIdentifierPreparer
     defaultrunner = OracleDefaultRunner
-    info_cache = OracleInfoCache
     
     def __init__(self, 
                 use_ansi=True, 
@@ -568,50 +565,50 @@ class OracleDialect(default.DefaultDialect):
             else:
                 return None, None, None, None
 
-    def _prepare_reflection_args(self, connection, tablename, schemaname=None,
+    def _prepare_reflection_args(self, connection, table_name, schema=None,
                                  resolve_synonyms=False, dblink=''):
 
         if resolve_synonyms:
-            actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(schemaname), desired_synonym=self._denormalize_name(tablename))
+            actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(schema), desired_synonym=self._denormalize_name(table_name))
         else:
             actual_name, owner, dblink, synonym = None, None, None, None
         if not actual_name:
-            actual_name = self._denormalize_name(tablename)
+            actual_name = self._denormalize_name(table_name)
         if not dblink:
             dblink = ''
         if not owner:
-            owner = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
+            owner = self._denormalize_name(schema or self.get_default_schema_name(connection))
         return (actual_name, owner, dblink, synonym)
 
-    @reflection.caches
-    def get_schema_names(self, connection, info_cache=None):
+    @reflection.cache
+    def get_schema_names(self, connection, **kw):
         s = "SELECT username FROM all_users ORDER BY username"
         cursor = connection.execute(s,)
         return [self._normalize_name(row[0]) for row in cursor]
 
-    @reflection.caches
-    def get_table_names(self, connection, schemaname=None, info_cache=None):
-        schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
-        return self.table_names(connection, schemaname)
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        schema = self._denormalize_name(schema or self.get_default_schema_name(connection))
+        return self.table_names(connection, schema)
 
-    @reflection.caches
-    def get_view_names(self, connection, schemaname=None, info_cache=None):
-        schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        schema = self._denormalize_name(schema or self.get_default_schema_name(connection))
         s = "select view_name from all_views where OWNER = :owner"
         cursor = connection.execute(s,
-                {'owner':self._denormalize_name(schemaname)})
+                {'owner':self._denormalize_name(schema)})
         return [self._normalize_name(row[0]) for row in cursor]
 
-    @reflection.caches
-    def get_columns(self, connection, tablename, schemaname=None,
-                    info_cache=None, resolve_synonyms=False, dblink=''):
+    @reflection.cache
+    def get_columns(self, connection, table_name, schema=None,
+                    resolve_synonyms=False, dblink='', **kw):
 
         
-        (tablename, schemaname, dblink, synonym) = \
-            self._prepare_reflection_args(connection, tablename, schemaname,
+        (table_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, table_name, schema,
                                           resolve_synonyms, dblink)
         columns = []
-        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':tablename, 'owner':schemaname})
+        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':table_name, 'owner':schema})
 
         while True:
             row = c.fetchone()
@@ -645,7 +642,7 @@ class OracleDialect(default.DefaultDialect):
 
             colargs = []
             if default is not None:
-                colargs.append(schema.DefaultClause(sql.text(default)))
+                colargs.append(sa_schema.DefaultClause(sql.text(default)))
             cdict = {
                 'name': colname,
                 'type': coltype,
@@ -656,13 +653,13 @@ class OracleDialect(default.DefaultDialect):
             columns.append(cdict)
         return columns
 
-    @reflection.caches
-    def get_indexes(self, connection, tablename, schemaname=None,
-                    info_cache=None, resolve_synonyms=False, dblink=''):
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None,
+                    resolve_synonyms=False, dblink='', **kw):
 
         
-        (tablename, schemaname, dblink, synonym) = \
-            self._prepare_reflection_args(connection, tablename, schemaname,
+        (table_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, table_name, schema,
                                           resolve_synonyms, dblink)
         indexes = []
         q = """
@@ -672,17 +669,18 @@ class OracleDialect(default.DefaultDialect):
             ON a.INDEX_NAME = b.INDEX_NAME
             AND a.TABLE_OWNER = b.TABLE_OWNER
             AND a.TABLE_NAME = b.TABLE_NAME
-        WHERE a.TABLE_NAME = :tablename
-        AND a.TABLE_OWNER = :schemaname
+        WHERE a.TABLE_NAME = :table_name
+        AND a.TABLE_OWNER = :schema
         ORDER BY a.INDEX_NAME, a.COLUMN_POSITION
         """ % dict(dblink=dblink)
         rp = connection.execute(q,
-            dict(tablename=self._denormalize_name(tablename),
-                 schemaname=self._denormalize_name(schemaname)))
+            dict(table_name=self._denormalize_name(table_name),
+                 schema=self._denormalize_name(schema)))
         indexes = []
         last_index_name = None
-        pkeys = self.get_primary_keys(connection, tablename, schemaname,
-                                      info_cache, resolve_synonyms, dblink)
+        pkeys = self.get_primary_keys(connection, table_name, schema,
+                                      resolve_synonyms, dblink,
+                                      info_cache=info_cache)
         uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
         for rset in rp:
             # don't include the primary key columns
@@ -696,8 +694,9 @@ class OracleDialect(default.DefaultDialect):
             last_index_name = rset.index_name
         return indexes
 
-    def _get_constraint_data(self, connection, tablename, schemaname=None,
-                             info_cache=None, dblink=''):
+    @reflection.cache
+    def _get_constraint_data(self, connection, table_name, schema=None,
+                            dblink='', **kw):
 
         rp = connection.execute("""SELECT
              ac.constraint_name,
@@ -718,19 +717,20 @@ class OracleDialect(default.DefaultDialect):
            AND ac.r_constraint_name = rem.constraint_name(+)
            -- order multiple primary keys correctly
            ORDER BY ac.constraint_name, loc.position, rem.position"""
-         % {'dblink':dblink}, {'table_name' : tablename, 'owner' : schemaname})
+         % {'dblink':dblink}, {'table_name' : table_name, 'owner' : schema})
         constraint_data = rp.fetchall()
         return constraint_data
 
-    @reflection.caches
-    def get_primary_keys(self, connection, tablename, schemaname=None,
-                         info_cache=None, resolve_synonyms=False, dblink=''):
-        (tablename, schemaname, dblink, synonym) = \
-            self._prepare_reflection_args(connection, tablename, schemaname,
+    @reflection.cache
+    def get_primary_keys(self, connection, table_name, schema=None,
+                         resolve_synonyms=False, dblink='', **kw):
+        (table_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, table_name, schema,
                                           resolve_synonyms, dblink)
         pkeys = []
-        constraint_data = self._get_constraint_data(connection, tablename,
-                                        schemaname, info_cache, dblink)
+        constraint_data = self._get_constraint_data(connection, table_name,
+                                        schema, dblink,
+                                        info_cache=kw.get('info_cache'))
         for row in constraint_data:
             #print "ROW:" , row
             (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
@@ -738,15 +738,16 @@ class OracleDialect(default.DefaultDialect):
                 pkeys.append(local_column)
         return pkeys
 
-    @reflection.caches
-    def get_foreign_keys(self, connection, tablename, schemaname=None,
-                         info_cache=None, resolve_synonyms=False, dblink=''):
-        (tablename, schemaname, dblink, synonym) = \
-            self._prepare_reflection_args(connection, tablename, schemaname,
+    @reflection.cache
+    def get_foreign_keys(self, connection, table_name, schema=None,
+                         resolve_synonyms=False, dblink='', **kw):
+        (table_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, table_name, schema,
                                           resolve_synonyms, dblink)
 
-        constraint_data = self._get_constraint_data(connection, tablename,
-                                                schemaname, info_cache, dblink)
+        constraint_data = self._get_constraint_data(connection, table_name,
+                                                schema, dblink,
+                                                info_cache=kw.get('info_cache'))
         fkeys = []
         fks = {}
         for row in constraint_data:
@@ -786,26 +787,26 @@ class OracleDialect(default.DefaultDialect):
                 fkeys.append(fkey_d)
         return fkeys
 
-    @reflection.caches
-    def get_view_definition(self, connection, viewname, schemaname=None,
-                            info_cache=None, resolve_synonyms=False, dblink=''):
-        (viewname, schemaname, dblink, synonym) = \
-            self._prepare_reflection_args(connection, viewname, schemaname,
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None,
+                            resolve_synonyms=False, dblink='', **kw):
+        (view_name, schema, dblink, synonym) = \
+            self._prepare_reflection_args(connection, view_name, schema,
                                           resolve_synonyms, dblink)
         s = """
         SELECT text FROM all_views
-        WHERE owner = :schemaname
-        AND view_name = :viewname
+        WHERE owner = :schema
+        AND view_name = :view_name
         """
         rp = connection.execute(sql.text(s),
-                                viewname=viewname, schemaname=schemaname)
+                                view_name=view_name, schema=schema)
         if rp:
             view_def = rp.scalar().decode(self.encoding)
             return view_def
 
     def reflecttable(self, connection, table, include_columns):
         preparer = self.identifier_preparer
-        info_cache = OracleInfoCache()
+        info_cache = {}
 
         resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
 
@@ -814,8 +815,8 @@ class OracleDialect(default.DefaultDialect):
                                           resolve_synonyms)
 
         # columns
-        columns = self.get_columns(connection, actual_name, owner, info_cache,
-                                                                        dblink)
+        columns = self.get_columns(connection, actual_name, owner, dblink,
+                                   info_cache=info_cache)
         for cdict in columns:
             colname = cdict['name']
             coltype = cdict['type']
@@ -823,14 +824,14 @@ class OracleDialect(default.DefaultDialect):
             colargs = cdict['attrs']
             if include_columns and colname not in include_columns:
                 continue
-            table.append_column(schema.Column(colname, coltype,
+            table.append_column(sa_schema.Column(colname, coltype,
                                               nullable=nullable, *colargs))
         if not table.columns:
             raise AssertionError("Couldn't find any column information for table %s" % actual_name)
 
         # primary keys
         for pkcol in self.get_primary_keys(connection, actual_name, owner,
-                                                           info_cache, dblink):
+                                           dblink, info_cache=info_cache):
             if pkcol in table.c:
                 table.primary_key.add(table.c[pkcol])
 
@@ -838,7 +839,8 @@ class OracleDialect(default.DefaultDialect):
         fks = {}
         fkeys = []
         fkeys = self.get_foreign_keys(connection, actual_name, owner,
-                                      info_cache, resolve_synonyms, dblink)
+                                      resolve_synonyms, dblink,
+                                      info_cache=info_cache)
         refspecs = []
         for fkey_d in fkeys:
             conname = fkey_d['name']
@@ -848,17 +850,17 @@ class OracleDialect(default.DefaultDialect):
             referred_columns = fkey_d['referred_columns']
             for (i, ref_col) in enumerate(referred_columns):
                 if not table.schema and self._denormalize_name(referred_schema) == self._denormalize_name(owner):
-                    t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
+                    t = sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
 
                     refspec =  ".".join([referred_table, ref_col])
                 else:
                     refspec = '.'.join([x for x in [referred_schema,
                                     referred_table, ref_col] if x is not None])
 
-                    t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
+                    t = sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
                 refspecs.append(refspec)
             table.append_constraint(
-                schema.ForeignKeyConstraint(constrained_columns, refspecs,
+                sa_schema.ForeignKeyConstraint(constrained_columns, refspecs,
                                         name=conname, link_to_name=True))
 
 
index cd64a3c648d381fdc4c2b1295eaf571b0994b9fe..010122876274f89e46e6113bd67fb2a6d2726d15 100644 (file)
@@ -66,6 +66,7 @@ option to the Index constructor::
 
 import re
 
+from sqlalchemy import schema as sa_schema
 from sqlalchemy import sql, schema, exc, util
 from sqlalchemy.engine import base, default, reflection
 from sqlalchemy.sql import compiler, expression
@@ -397,6 +398,18 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
             value = value[1:-1].replace('""','"')
         return value
 
+class PGInspector(reflection.Inspector):
+
+    def __init__(self, conn):
+        reflection.Inspector.__init__(self, conn)
+
+    def get_table_oid(self, table_name, schema=None):
+        """Return the oid from `table_name` and `schema`."""
+
+        return self.dialect.get_table_oid(self.conn, table_name, schema,
+                                          info_cache=self.info_cache)
+    
+
 class PGDialect(default.DefaultDialect):
     name = 'postgres'
     supports_alter = True
@@ -417,6 +430,7 @@ class PGDialect(default.DefaultDialect):
     type_compiler = PGTypeCompiler
     preparer = PGIdentifierPreparer
     defaultrunner = PGDefaultRunner
+    inspector = PGInspector
 
 
     def do_begin_twophase(self, connection, xid):
@@ -500,8 +514,8 @@ class PGDialect(default.DefaultDialect):
         return tuple([int(x) for x in m.group(1, 2, 3)])
 
     @reflection.cache
-    def get_table_oid(self, connection, tablename, schemaname=None, **kw):
-        """Fetch the oid for schemaname.tablename.
+    def get_table_oid(self, connection, table_name, schema=None, **kw):
+        """Fetch the oid for schema.table_name.
 
         Several reflection methods require the table oid.  The idea for using
         this method is that it can be fetched one time and cached for
@@ -509,7 +523,7 @@ class PGDialect(default.DefaultDialect):
 
         """
         table_oid = None
-        if schemaname is not None:
+        if schema is not None:
             schema_where_clause = "n.nspname = :schema"
         else:
             schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
@@ -520,21 +534,21 @@ class PGDialect(default.DefaultDialect):
             WHERE (%s)
             AND c.relname = :table_name AND c.relkind in ('r','v')
         """ % schema_where_clause
-        # Since we're binding to unicode, tablename and schemaname must be
+        # Since we're binding to unicode, table_name and schema_name must be
         # unicode.
-        tablename = unicode(tablename)
-        if schemaname is not None:
-            schemaname = unicode(schemaname)
+        table_name = unicode(table_name)
+        if schema is not None:
+            schema = unicode(schema)
         s = sql.text(query, bindparams=[
             sql.bindparam('table_name', type_=sqltypes.Unicode),
             sql.bindparam('schema', type_=sqltypes.Unicode)
             ],
             typemap={'oid':sqltypes.Integer}
         )
-        c = connection.execute(s, table_name=tablename, schema=schemaname)
+        c = connection.execute(s, table_name=table_name, schema=schema)
         table_oid = c.scalar()
         if table_oid is None:
-            raise exc.NoSuchTableError(tablename)
+            raise exc.NoSuchTableError(table_name)
         return table_oid
 
     @reflection.cache
@@ -551,18 +565,18 @@ class PGDialect(default.DefaultDialect):
         return schema_names
 
     @reflection.cache
-    def get_table_names(self, connection, schemaname=None, **kw):
-        if schemaname is not None:
-            current_schema = schemaname
+    def get_table_names(self, connection, schema=None, **kw):
+        if schema is not None:
+            current_schema = schema
         else:
             current_schema = self.get_default_schema_name(connection)
         table_names = self.table_names(connection, current_schema)
         return table_names
 
     @reflection.cache
-    def get_view_names(self, connection, schemaname=None, **kw):
-        if schemaname is not None:
-            current_schema = schemaname
+    def get_view_names(self, connection, schema=None, **kw):
+        if schema is not None:
+            current_schema = schema
         else:
             current_schema = self.get_default_schema_name(connection)
         s = """
@@ -575,26 +589,26 @@ class PGDialect(default.DefaultDialect):
         return view_names
 
     @reflection.cache
-    def get_view_definition(self, connection, viewname, schemaname=None, **kw):
-        if schemaname is not None:
-            current_schema = schemaname
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+        if schema is not None:
+            current_schema = schema
         else:
             current_schema = self.get_default_schema_name(connection)
         s = """
         SELECT definition FROM pg_views
-        WHERE schemaname = :schemaname
-        AND viewname = :viewname
+        WHERE schemaname = :schema
+        AND viewname = :view_name
         """
         rp = connection.execute(sql.text(s),
-                                viewname=viewname, schemaname=current_schema)
+                                view_name=view_name, schema=current_schema)
         if rp:
             view_def = rp.scalar().decode(self.encoding)
             return view_def
 
     @reflection.cache
-    def get_columns(self, connection, tablename, schemaname=None, **kw):
+    def get_columns(self, connection, table_name, schema=None, **kw):
 
-        table_oid = self.get_table_oid(connection, tablename, schemaname,
+        table_oid = self.get_table_oid(connection, table_name, schema,
                                        info_cache=kw.get('info_cache'))
         SQL_COLS = """
             SELECT a.attname,
@@ -680,8 +694,8 @@ class PGDialect(default.DefaultDialect):
         return columns
 
     @reflection.cache
-    def get_primary_keys(self, connection, tablename, schemaname=None, **kw):
-        table_oid = self.get_table_oid(connection, tablename, schemaname,
+    def get_primary_keys(self, connection, table_name, schema=None, **kw):
+        table_oid = self.get_table_oid(connection, table_name, schema,
                                        info_cache=kw.get('info_cache'))
         PK_SQL = """
           SELECT attname FROM pg_attribute
@@ -697,9 +711,9 @@ class PGDialect(default.DefaultDialect):
         return primary_keys
 
     @reflection.cache
-    def get_foreign_keys(self, connection, tablename, schemaname=None, **kw):
+    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
         preparer = self.identifier_preparer
-        table_oid = self.get_table_oid(connection, tablename, schemaname,
+        table_oid = self.get_table_oid(connection, table_name, schema,
                                        info_cache=kw.get('info_cache'))
         FK_SQL = """
           SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
@@ -717,11 +731,11 @@ class PGDialect(default.DefaultDialect):
             constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
             if referred_schema:
                 referred_schema = preparer._unquote_identifier(referred_schema)
-            elif schemaname is not None and schemaname == self.get_default_schema_name(connection):
+            elif schema is not None and schema == self.get_default_schema_name(connection):
                 # no schema (i.e. its the default schema), and the table we're
                 # reflecting has the default schema explicit, then use that.
                 # i.e. try to use the user's conventions
-                referred_schema = schemaname
+                referred_schema = schema
             referred_table = preparer._unquote_identifier(referred_table)
             referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]
             fkey_d = {
@@ -735,8 +749,8 @@ class PGDialect(default.DefaultDialect):
         return fkeys
 
     @reflection.cache
-    def get_indexes(self, connection, tablename, schemaname, **kw):
-        table_oid = self.get_table_oid(connection, tablename, schemaname,
+    def get_indexes(self, connection, table_name, schema, **kw):
+        table_oid = self.get_table_oid(connection, table_name, schema,
                                        info_cache=kw.get('info_cache'))
         IDX_SQL = """
           SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
@@ -778,16 +792,16 @@ class PGDialect(default.DefaultDialect):
 
     def reflecttable(self, connection, table, include_columns):
         preparer = self.identifier_preparer
-        schemaname = table.schema
-        tablename = table.name
+        schema = table.schema
+        table_name = table.name
         info_cache = {}
         # Py2K
-        if isinstance(schemaname, str):
-            schemaname = schemaname.decode(self.encoding)
-        if isinstance(tablename, str):
-            tablename = tablename.decode(self.encoding)
+        if isinstance(schema, str):
+            schema = schema.decode(self.encoding)
+        if isinstance(table_name, str):
+            table_name = table_name.decode(self.encoding)
         # end Py2K
-        for col_d in self.get_columns(connection, tablename, schemaname,
+        for col_d in self.get_columns(connection, table_name, schema,
                                       info_cache=info_cache):
             name = col_d['name']
             coltype = col_d['type']
@@ -800,18 +814,18 @@ class PGDialect(default.DefaultDialect):
                 match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
                 if match is not None:
                     # the default is related to a Sequence
-                    sch = schemaname
+                    sch = schema
                     if '.' not in match.group(2) and sch is not None:
                         # unconditionally quote the schema name.  this could
                         # later be enhanced to obey quoting rules / "quote schema"
                         default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
-                colargs.append(schema.DefaultClause(sql.text(default)))
-            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
+                colargs.append(sa_schema.DefaultClause(sql.text(default)))
+            table.append_column(sa_schema.Column(name, coltype, nullable=nullable, *colargs))
         # Now we have the table oid cached.
-        table_oid = self.get_table_oid(connection, tablename, schemaname,
+        table_oid = self.get_table_oid(connection, table_name, schema,
                                        info_cache=info_cache)
         # Primary keys
-        for pk in self.get_primary_keys(connection, tablename, schemaname,
+        for pk in self.get_primary_keys(connection, table_name, schema,
                                         info_cache=info_cache):
             if pk in table.c:
                 col = table.c[pk]
@@ -819,7 +833,7 @@ class PGDialect(default.DefaultDialect):
                 if col.default is None:
                     col.autoincrement = False
         # Foreign keys
-        fkeys = self.get_foreign_keys(connection, tablename, schemaname,
+        fkeys = self.get_foreign_keys(connection, table_name, schema,
                                       info_cache=info_cache)
         for fkey_d in fkeys:
             conname = fkey_d['name']
@@ -829,25 +843,25 @@ class PGDialect(default.DefaultDialect):
             referred_columns = fkey_d['referred_columns']
             refspec = []
             if referred_schema is not None:
-                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
+                sa_schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
                             autoload_with=connection)
                 for column in referred_columns:
                     refspec.append(".".join([referred_schema, referred_table, column]))
             else:
-                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
+                sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
                 for column in referred_columns:
                     refspec.append(".".join([referred_table, column]))
 
-            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
+            table.append_constraint(sa_schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
 
         # Indexes 
-        indexes = self.get_indexes(connection, tablename, schemaname,
+        indexes = self.get_indexes(connection, table_name, schema,
                                    info_cache=info_cache)
         for index_d in indexes:
             name = index_d['name']
             columns = index_d['column_names']
             unique = index_d['unique']
-            schema.Index(name, *[table.columns[c] for c in columns], 
+            sa_schema.Index(name, *[table.columns[c] for c in columns], 
                          **dict(unique=unique))
 
     def _load_domains(self, connection):
index 8b13b4dccab2722295517711be5b68d77711fb8d..ed6160e83eb934ba537143f89341817869eb1f76 100644 (file)
@@ -236,9 +236,6 @@ class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
         'vacuum', 'values', 'view', 'virtual', 'when', 'where',
         ])
 
-class SQLiteInfoCache(reflection.DefaultInfoCache):
-    pass
-
 class SQLiteDialect(default.DefaultDialect):
     name = 'sqlite'
     supports_alter = False
@@ -254,7 +251,6 @@ class SQLiteDialect(default.DefaultDialect):
     preparer = SQLiteIdentifierPreparer
     ischema_names = ischema_names
     colspecs = colspecs
-    info_cache = SQLiteInfoCache
     
     def table_names(self, connection, schema):
         if schema is not None:
@@ -295,7 +291,7 @@ class SQLiteDialect(default.DefaultDialect):
 
         return (row is not None)
 
-    @reflection.caches
+    @reflection.cache
     def get_columns(self, connection, tablename, schemaname=None,
                                                         info_cache=None):
         quote = self.identifier_preparer.quote_identifier
@@ -342,7 +338,7 @@ class SQLiteDialect(default.DefaultDialect):
             })
         return columns
 
-    @reflection.caches
+    @reflection.cache
     def get_foreign_keys(self, connection, tablename, schemaname=None,
                                                         info_cache=None):
         quote = self.identifier_preparer.quote_identifier
index 6944a52624a2f96012d577911ef64e4ae0fec01a..5225b6d4b76157d995582d26af5fc4d9ef9bd924 100644 (file)
@@ -203,6 +203,11 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    def get_table_names(self, connection, schema=None):
+        """Return a list of table names for `schema`."""
+
+        raise NotImplementedError
+
     def do_begin(self, connection):
         """Provide an implementation of *connection.begin()*, given a DB-API connection."""
 
index f1766ec2a6690e499446b0bb5f19cf24f3683b89..bb22cc42c40709f70b1baee98187e640c2407ea0 100644 (file)
@@ -17,7 +17,6 @@ I'm still trying to decide upon conventions for both the Inspector interface as
 
 
 """
-import inspect
 import sqlalchemy
 from sqlalchemy import util
 from sqlalchemy.types import TypeEngine
@@ -35,307 +34,153 @@ def cache(fn, self, con, *args, **kw):
         info_cache[key] = ret
     return ret
 
-# keeping this around until all dialects are fixed
-@util.decorator
-def caches(fn, self, con, *args, **kw):
-    # what are we caching?
-    fn_name = fn.__name__
-    if not fn_name.startswith('get_'):
-        # don't recognize this.
-        return fn(self, con, *args, **kw)
-    else:
-        attr_to_cache = fn_name[4:]
-    # The first arguments will always be self and con.
-    # Assuming *args and *kw will be acceptable to info_cache method.
-    if 'info_cache' in kw:
-        kw_cp = kw.copy()
-        info_cache = kw_cp.pop('info_cache')
-        methodname = "%s_%s" % ('get', attr_to_cache)
-        # fixme.
-        for bad_kw in ('dblink', 'resolve_synonyms'):
-            if bad_kw in kw_cp:
-                del kw_cp[bad_kw]
-        information = getattr(info_cache, methodname)(*args, **kw_cp)
-        if information:
-            return information
-    information = fn(self, con, *args, **kw)
-    if 'info_cache' in locals():
-        methodname = "%s_%s" % ('set', attr_to_cache)
-        getattr(info_cache, methodname)(information, *args, **kw_cp)
-    return information 
-
-class DefaultInfoCache(object):
-    """Default implementation of InfoCache
-
-    InfoCache provides a means for dialects to cache information obtained for
-    reflection and a convenient interface for setting and retrieving cached
-    data.
-
-    """
-    
-    def __init__(self):
-        self._cache = dict(schemas={})
-        self.tables_are_complete = False
-        self.schemas_are_complete = False
-        self.views_are_complete = False
-
-    def clear(self):
-        """Clear the cache."""
-        self._cache = dict(schemas={})
-
-    # schemas
-
-    def get_schemas(self):
-        """Return the schemas dict."""
-        return self._cache.get('schemas')
 
+class Inspector(object):
+    """Performs database schema inspection.
 
-    def get_schema(self, schemaname, create=False):
-        """Return cached schema and optionally create it if it does not exist.
+    The Inspector acts as a proxy to the dialects' reflection methods and
+    provides higher level functions for accessing database schema information.
 
+    """
+    
+    def __init__(self, conn):
         """
-        schema = self._cache['schemas'].get(schemaname)
-        if schema is not None:
-            return schema
-        elif create:
-            return self.add_schema(schemaname)
-        return None
 
-    def add_schema(self, schemaname):
-        self._cache['schemas'][schemaname] = dict(tables={}, views={})
-        return self.get_schema(schemaname)
-
-    def get_schema_names(self, check_complete=True):
-        """Return cached schema names.
-
-        By default, only return them if they're complete.
+        conn
+          [sqlalchemy.engine.base.#Connectable]
 
         """
-        if check_complete and self.schemas_are_complete:
-            return self.get_schemas().keys()
-        elif not check_complete:
-            return self.get_schemas().keys()
+        self.conn = conn
+        # set the engine
+        if hasattr(conn, 'engine'):
+            self.engine = conn.engine
         else:
-            return None
-
-    def set_schema_names(self, schemanames):
-        for schemaname in schemanames:
-            self.add_schema(schemaname)
-        self.schemas_are_complete = True
+            self.engine = conn
+        self.dialect = self.engine.dialect
+        self.info_cache = {}
 
-    # tables
+    @classmethod
+    def from_engine(cls, engine):
+        if hasattr(engine.dialect, 'inspector'):
+            return engine.dialect.inspector(engine)
+        return Inspector(engine)
 
-    def get_table(self, tablename, schemaname=None, create=False,
-                                                        table_type='table'):
-        """Return cached table and optionally create it if it does not exist.
+    def default_schema_name(self):
+        return self.dialect.get_default_schema_name(self.conn)
+    default_schema_name = property(default_schema_name)
 
+    def get_schema_names(self):
+        """Return all schema names.
 
         """
-        cache = self._cache
-        schema = self.get_schema(schemaname, create=create)
-        if schema is None:
-            return None
-        if table_type == 'view':
-            table = schema['views'].get(tablename)
-        else:
-            table = schema['tables'].get(tablename)
-        if table is not None:
-            return table
-        elif create:
-            return self.add_table(tablename, schemaname, table_type=table_type)
-        return None
+        if hasattr(self.dialect, 'get_schema_names'):
+            return self.dialect.get_schema_names(self.conn,
+                                                    info_cache=self.info_cache)
+        return []
 
-    def get_table_names(self, schemaname=None, check_complete=True,
-                                                        table_type='table'):
-        """Return cached table names.
+    def get_table_names(self, schema=None, order_by=None):
+        """Return all table names in `schema`.
+        schema:
+          Optional, retrieve names from a non-default schema.
 
-        By default, only return them if they're complete.
+        This should probably not return view names or maybe it should return
+        them with an indicator t or v.
 
         """
-        if table_type == 'view':
-            complete = self.views_are_complete
-        else:
-            complete = self.tables_are_complete
-        if check_complete and complete:
-            return self.get_tables(schemaname, table_type=table_type).keys()
-        elif not check_complete:
-            return self.get_tables(schemaname, table_type=table_type).keys()
+        if hasattr(self.dialect, 'get_table_names'):
+            tnames = self.dialect.get_table_names(self.conn,
+            schema,
+                                                    info_cache=self.info_cache)
         else:
-            return None
-
-    def add_table(self, tablename, schemaname=None, table_type='table'):
-        schema = self.get_schema(schemaname, create=True)
-        if table_type == 'table':
-            schema['tables'][tablename] = dict(columns={})
-        else:
-            schema['views'][tablename] = dict(columns={})
-        return self.get_table(tablename, schemaname, table_type=table_type)
-
-    def set_table_names(self, tablenames, schemaname=None, table_type='table'):
-        for tablename in tablenames:
-            self.add_table(tablename, schemaname, table_type)
-        if table_type == 'view':
-            self.views_are_complete = True
-        else:
-            self.tables_are_complete = True
-            
-    # views
-
-    def get_view(self, viewname, schemaname=None, create=False):
-        return self.get_table(viewname, schemaname, create, 'view')
-
-    def get_view_names(self, schemaname=None, check_complete=True):
-        return self.get_table_names(schemaname, check_complete, 'view')
-
-    def add_view(self, viewname, schemaname=None):
-        return self.add_table(viewname, schemaname, 'view')
-
-    def set_view_names(self, viewnames, schemaname=None):
-        return self.set_table_names(viewnames, schemaname, 'view')
-
-    def get_view_definition(self, viewname, schemaname=None):
-        view_cache = self.get_view(viewname, schemaname)
-        if view_cache and 'definition' in view_cache:
-            return view_cache['definition']
-
-    def set_view_definition(self, definition, viewname, schemaname=None):
-        view_cache = self.get_view(viewname, schemaname, create=True)
-        view_cache['definition'] = definition
-
-    # table data
-
-    def _get_table_data(self, key, tablename, schemaname=None):
-        table_cache = self.get_table(tablename, schemaname)
-        if table_cache is not None and key in table_cache.keys():
-            return table_cache[key]
-
-    def _set_table_data(self, key, data, tablename, schemaname=None):
-        """Cache data for schemaname.tablename using key.
-
-        It will create a schema and table entry in the cache if needed.
+            tnames = self.engine.table_names(schema)
+        if order_by == 'foreign_key':
+            ordered_tnames = tnames[:]
+            # Order based on foreign key dependencies.
+            for tname in tnames:
+                table_pos = tnames.index(tname)
+                fkeys = self.get_foreign_keys(tname, schema)
+                for fkey in fkeys:
+                    rtable = fkey['referred_table']
+                    if rtable in ordered_tnames:
+                        ref_pos = ordered_tnames.index(rtable)
+                        # Make sure it's lower in the list than anything it
+                        # references.
+                        if table_pos > ref_pos:
+                            ordered_tnames.pop(table_pos) # rtable moves up 1
+                            # insert just below rtable
+                            ordered_tnames.index(ref_pos, tname)
+            tnames = ordered_tnames
+        return tnames
+
+    def get_view_names(self, schema=None):
+        """Return all view names in `schema`.
+        schema:
+          Optional, retrieve names from a non-default schema.
 
         """
-        table_cache = self.get_table(tablename, schemaname, create=True)
-        table_cache[key] = data
-
-    # columns
-
-    def get_columns(self, tablename, schemaname=None):
-        """Return columns list or None."""
-        
-        return self._get_table_data('columns', tablename, schemaname)
-
-    def set_columns(self, columns, tablename, schemaname=None):
-        """Add list of columns to table cache."""
-
-        return self._set_table_data('columns', columns, tablename, schemaname)
-
-    # primary keys
+        return self.dialect.get_view_names(self.conn, schema,
+                                                  info_cache=self.info_cache)
 
-    def get_primary_keys(self, tablename, schemaname=None):
-        """Return primary key list or None."""
-        
-        return self._get_table_data('primary_keys', tablename, schemaname)
+    def get_view_definition(self, view_name, schema=None):
+        """Return definition for `view_name`.
+        schema:
+          Optional, retrieve names from a non-default schema.
 
-    def set_primary_keys(self, pkeys, tablename, schemaname=None):
-        """Add list of primary keys to table cache."""
+        """
+        return self.dialect.get_view_definition(
+            self.conn, view_name, schema, info_cache=self.info_cache)
 
-        return self._set_table_data('primary_keys', pkeys, tablename, schemaname)
+    def get_columns(self, table_name, schema=None):
+        """Return information about columns in `table_name`.
 
-    # foreign keys
+        Given a string `table_name` and an optional string `schema`, return
+        column information as a list of dicts with these keys:
 
-    def get_foreign_keys(self, tablename, schemaname=None):
-        """Return foreign key list or None."""
-        
-        return self._get_table_data('foreign_keys', tablename, schemaname)
+        name
+          the column's name
 
-    def set_foreign_keys(self, fkeys, tablename, schemaname=None):
-        """Add list of foreign keys to table cache."""
+        type
+          [sqlalchemy.types#TypeEngine]
 
-        return self._set_table_data('foreign_keys', fkeys, tablename, schemaname)
+        nullable
+          boolean
 
-    # indexes
+        default
+          the column's default value
 
-    def get_indexes(self, tablename, schemaname=None):
-        """Return indexes list or None."""
-        
-        return self._get_table_data('indexes', tablename, schemaname)
+        attrs
+          dict containing optional column attributes
 
-    def set_indexes(self, indexes, tablename, schemaname=None):
-        """Add list of indexes to table cache."""
+        """
 
-        return self._set_table_data('indexes', indexes, tablename, schemaname)
+        col_defs = self.dialect.get_columns(self.conn, table_name,
+                                                   schema,
+                                                   info_cache=self.info_cache)
+        for col_def in col_defs:
+            # make this easy and only return instances for coltype
+            coltype = col_def['type']
+            if not isinstance(coltype, TypeEngine):
+                col_def['type'] = coltype()
+        return col_defs
 
-class Inspector(object):
-    """Performs database introspection.
+    def get_primary_keys(self, table_name, schema=None):
+        """Return information about primary keys in `table_name`.
 
-    The Inspector acts as a proxy to the dialects' reflection methods and
-    provides higher level functions for accessing database schema information.
+        Given a string `table_name`, and an optional string `schema`, return 
+        primary key information as a list of column names.
 
-    """
-    
-    def __init__(self, conn):
         """
 
-        conn
-          [sqlalchemy.engine.base.#Connectable]
-
-        Upon initialization, new members are added corresponding to the
-        refection members of the current dialect.
+        pkeys = self.dialect.get_primary_keys(self.conn, table_name,
+                                                     schema,
+                                            info_cache=self.info_cache)
 
-        Dev Notes:
-        
-        I used attribute assignment rather than __getattr__ because 
-        I want the Inspector to be inspectable including providing proper
-        documentation strings for the methods is supports.
-        
-        The primary reason for this approach:
+        return pkeys
 
-        1. DRY.
-        2. Provides access to dialect specific reflection methods.
+    def get_foreign_keys(self, table_name, schema=None):
+        """Return information about foreign_keys in `table_name`.
 
-        """
-        self.conn = conn
-        # set the engine
-        if hasattr(conn, 'engine'):
-            self.engine = conn.engine
-        else:
-            self.engine = conn
-        # fixme. This is just until all dialects are converted
-        if hasattr(self, 'info_cache'):
-            self.info_cache = self.engine.dialect.info_cache()
-        else:
-            self.info_cache = {}
-        # add methods from dialect
-        def filter_reflect_members(m):
-            if inspect.ismethod(m) and m.__name__.startswith('get_'):
-                argspec = inspect.getargspec(m)
-                if isinstance(argspec, tuple) and 'connection' in argspec[0]:
-                    return True
-            return False
-        reflection_members = inspect.getmembers(self.engine.dialect,
-                                                filter_reflect_members)
-        def wrap_reflection_method(fn):
-            def decorated(*args, **kwargs):
-                args = (self.conn,) + args
-                kwargs['info_cache'] = self.info_cache
-                return fn(*args, **kwargs)
-            return decorated
-        for (member_name, member) in reflection_members:
-            if not hasattr(self, member_name):
-                doc = "This method mirrors the dialect method %s." % member_name
-                wrapped_member = wrap_reflection_method(member)
-                wrapped_member.__doc__ = "%s\n\n%s" % (doc, member.__doc__)
-                setattr(self, member_name, wrapped_member)
-
-    @property
-    def default_schema_name(self):
-        return self.engine.dialect.get_default_schema_name(self.conn)
-
-    def get_foreign_keys(self, tablename, schemaname=None):
-        """Return information about foreign_keys in `tablename`.
-
-        Given a string `tablename`, and an optional string `schemaname`, return 
+        Given a string `table_name`, and an optional string `schema`, return 
         foreign key information as a list of dicts with these keys:
 
         constrained_columns
@@ -353,24 +198,36 @@ class Inspector(object):
 
         """
 
-        fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename,
-                                                       schemaname,
+        fk_defs = self.dialect.get_foreign_keys(self.conn, table_name,
+                                                       schema,
                                                 info_cache=self.info_cache)
         for fk_def in fk_defs:
             referred_schema = fk_def['referred_schema']
             # always set the referred_schema.
-            if referred_schema is None and schemaname is None:
-                referred_schema = self.engine.dialect.get_default_schema_name(
+            if referred_schema is None and schema is None:
+                referred_schema = self.dialect.get_default_schema_name(
                                                                     self.conn)
                 fk_def['referred_schema'] = referred_schema
         return fk_defs
 
-    def get_relation_map(self, schemaname=None):
-        """Provide a mapping of the relations between all tables in schemaname.
+    def get_indexes(self, table_name, schema=None):
+        """Return information about indexes in `table_name`.
 
-        This is an example of a higher level function where Inspector can be
-        very useful.
+        Given a string `table_name` and an optional string `schema`, return
+        index information as a list of dicts with these keys:
+
+        name
+          the index's name
+
+        column_names
+          list of column names in order
+
+        unique
+          boolean
 
         """
-        #todo
-        pass
+
+        indexes = self.dialect.get_indexes(self.conn, table_name,
+                                                  schema,
+                                            info_cache=self.info_cache)
+        return indexes
index 23e5befd31390286c98b8435302dffcfa86b0ab9..999f0e2c223accf4df9ce0c9a75fa06e879dc385 100644 (file)
@@ -9,6 +9,8 @@ from sqlalchemy.engine.reflection import Inspector
 from testlib.sa import MetaData, Table, Column
 from testlib import TestBase, testing, engines
 
+create_inspector = Inspector.from_engine
+
 if 'set' not in dir(__builtins__):
     from sets import Set as set
 
@@ -65,20 +67,20 @@ def createIndexes(con, schema=None):
     con.execute(sa.sql.text(query))
 
 def createViews(con, schema=None):
-    for tablename in ('users', 'email_addresses'):
-        fullname = tablename
+    for table_name in ('users', 'email_addresses'):
+        fullname = table_name
         if schema:
-            fullname = "%s.%s" % (schema, tablename)
+            fullname = "%s.%s" % (schema, table_name)
         view_name = fullname + '_v'
         query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name,
                                                                    fullname)
         con.execute(sa.sql.text(query))
 
 def dropViews(con, schema=None):
-    for tablename in ('email_addresses', 'users'):
-        fullname = tablename
+    for table_name in ('email_addresses', 'users'):
+        fullname = table_name
         if schema:
-            fullname = "%s.%s" % (schema, tablename)
+            fullname = "%s.%s" % (schema, table_name)
         view_name = fullname + '_v'
         query = "DROP VIEW %s" % view_name
         con.execute(sa.sql.text(query))
@@ -91,20 +93,20 @@ class ReflectionTest(TestBase):
         insp = Inspector(meta.bind)
         self.assert_(getSchema() in insp.get_schema_names())
 
-    def _test_get_table_names(self, schemaname=None, table_type='table',
+    def _test_get_table_names(self, schema=None, table_type='table',
                               order_by=None):
         meta = MetaData(testing.db)
-        (users, addresses) = createTables(meta, schemaname)
+        (users, addresses) = createTables(meta, schema)
         meta.create_all()
-        createViews(meta.bind, schemaname)
+        createViews(meta.bind, schema)
         try:
             insp = Inspector(meta.bind)
             if table_type == 'view':
-                table_names = insp.get_view_names(schemaname)
+                table_names = insp.get_view_names(schema)
                 table_names.sort()
                 answer = ['email_addresses_v', 'users_v']
             else:
-                table_names = insp.get_table_names(schemaname,
+                table_names = insp.get_table_names(schema,
                                                    order_by=order_by)
                 table_names.sort()
                 if order_by == 'foreign_key':
@@ -113,7 +115,7 @@ class ReflectionTest(TestBase):
                     answer = ['email_addresses', 'users']
             self.assertEqual(table_names, answer)
         finally:
-            dropViews(meta.bind, schemaname)
+            dropViews(meta.bind, schema)
             addresses.drop()
             users.drop()
 
@@ -132,21 +134,21 @@ class ReflectionTest(TestBase):
     def test_get_view_names_with_schema(self):
         self._test_get_table_names(getSchema(), table_type='view')
 
-    def _test_get_columns(self, schemaname=None, table_type='table'):
+    def _test_get_columns(self, schema=None, table_type='table'):
         meta = MetaData(testing.db)
-        (users, addresses) = createTables(meta, schemaname)
+        (users, addresses) = createTables(meta, schema)
         table_names = ['users', 'email_addresses']
         meta.create_all()
         if table_type == 'view':
-            createViews(meta.bind, schemaname)
+            createViews(meta.bind, schema)
             table_names = ['users_v', 'email_addresses_v']
         try:
             insp = Inspector(meta.bind)
-            for (tablename, table) in zip(table_names, (users, addresses)):
-                schema_name = schemaname
-                if schemaname and testing.against('oracle'):
-                    schema_name = schemaname.upper()
-                cols = insp.get_columns(tablename, schemaname=schema_name)
+            for (table_name, table) in zip(table_names, (users, addresses)):
+                schema_name = schema
+                if schema and testing.against('oracle'):
+                    schema_name = schema.upper()
+                cols = insp.get_columns(table_name, schema=schema_name)
                 self.assert_(len(cols) > 0, len(cols))
                 # should be in order
                 for (i, col) in enumerate(table.columns):
@@ -172,7 +174,7 @@ class ReflectionTest(TestBase):
                                           ctype)))
         finally:
             if table_type == 'view':
-                dropViews(meta.bind, schemaname)
+                dropViews(meta.bind, schema)
             addresses.drop()
             users.drop()
 
@@ -180,25 +182,25 @@ class ReflectionTest(TestBase):
         self._test_get_columns()
 
     def test_get_columns_with_schema(self):
-        self._test_get_columns(schemaname=getSchema())
+        self._test_get_columns(schema=getSchema())
 
     def test_get_view_columns(self):
         self._test_get_columns(table_type='view')
 
     def test_get_view_columns_with_schema(self):
-        self._test_get_columns(schemaname=getSchema(), table_type='view')
+        self._test_get_columns(schema=getSchema(), table_type='view')
 
-    def _test_get_primary_keys(self, schemaname=None):
+    def _test_get_primary_keys(self, schema=None):
         meta = MetaData(testing.db)
-        (users, addresses) = createTables(meta, schemaname)
+        (users, addresses) = createTables(meta, schema)
         meta.create_all()
         insp = Inspector(meta.bind)
         try:
             users_pkeys = insp.get_primary_keys(users.name,
-                                                schemaname=schemaname)
+                                                schema=schema)
             self.assertEqual(users_pkeys,  ['user_id'])
             addr_pkeys = insp.get_primary_keys(addresses.name,
-                                               schemaname=schemaname)
+                                               schema=schema)
             self.assertEqual(addr_pkeys,  ['address_id'])
 
         finally:
@@ -209,21 +211,21 @@ class ReflectionTest(TestBase):
         self._test_get_primary_keys()
 
     def test_get_primary_keys_with_schema(self):
-        self._test_get_primary_keys(schemaname=getSchema())
+        self._test_get_primary_keys(schema=getSchema())
 
-    def _test_get_foreign_keys(self, schemaname=None):
+    def _test_get_foreign_keys(self, schema=None):
         meta = MetaData(testing.db)
-        (users, addresses) = createTables(meta, schemaname)
+        (users, addresses) = createTables(meta, schema)
         meta.create_all()
         insp = Inspector(meta.bind)
         try:
-            expected_schema = schemaname
-            if schemaname is None:
+            expected_schema = schema
+            if schema is None:
                 expected_schema = meta.bind.dialect.get_default_schema_name(
                                     meta.bind)
             # users
             users_fkeys = insp.get_foreign_keys(users.name,
-                                                schemaname=schemaname)
+                                                schema=schema)
             fkey1 = users_fkeys[0]
             self.assert_(fkey1['name'] is not None)
             self.assertEqual(fkey1['referred_schema'], expected_schema)
@@ -232,7 +234,7 @@ class ReflectionTest(TestBase):
             self.assertEqual(fkey1['constrained_columns'], ['parent_user_id'])
             #addresses
             addr_fkeys = insp.get_foreign_keys(addresses.name,
-                                               schemaname=schemaname)
+                                               schema=schema)
             fkey1 = addr_fkeys[0]
             self.assert_(fkey1['name'] is not None)
             self.assertEqual(fkey1['referred_schema'], expected_schema)
@@ -247,16 +249,16 @@ class ReflectionTest(TestBase):
         self._test_get_foreign_keys()
 
     def test_get_foreign_keys_with_schema(self):
-        self._test_get_foreign_keys(schemaname=getSchema())
+        self._test_get_foreign_keys(schema=getSchema())
 
-    def _test_get_indexes(self, schemaname=None):
+    def _test_get_indexes(self, schema=None):
         meta = MetaData(testing.db)
-        (users, addresses) = createTables(meta, schemaname)
+        (users, addresses) = createTables(meta, schema)
         meta.create_all()
-        createIndexes(meta.bind, schemaname)
+        createIndexes(meta.bind, schema)
         try:
             insp = Inspector(meta.bind)
-            indexes = insp.get_indexes('users', schemaname=schemaname)
+            indexes = insp.get_indexes('users', schema=schema)
             indexes.sort()
             if testing.against('oracle'):
                 expected_indexes = [
@@ -277,23 +279,23 @@ class ReflectionTest(TestBase):
         self._test_get_indexes()
 
     def test_get_indexes_with_schema(self):
-        self._test_get_indexes(schemaname=getSchema())
+        self._test_get_indexes(schema=getSchema())
 
-    def _test_get_view_definition(self, schemaname=None):
+    def _test_get_view_definition(self, schema=None):
         meta = MetaData(testing.db)
-        (users, addresses) = createTables(meta, schemaname)
+        (users, addresses) = createTables(meta, schema)
         meta.create_all()
-        createViews(meta.bind, schemaname)
+        createViews(meta.bind, schema)
         view_name1 = 'users_v'
         view_name2 = 'email_addresses_v'
         try:
             insp = Inspector(meta.bind)
-            v1 = insp.get_view_definition(view_name1, schemaname=schemaname)
+            v1 = insp.get_view_definition(view_name1, schema=schema)
             self.assert_(v1)
-            v2 = insp.get_view_definition(view_name2, schemaname=schemaname)
+            v2 = insp.get_view_definition(view_name2, schema=schema)
             self.assert_(v2)
         finally:
-            dropViews(meta.bind, schemaname)
+            dropViews(meta.bind, schema)
             addresses.drop()
             users.drop()
 
@@ -301,7 +303,26 @@ class ReflectionTest(TestBase):
         self._test_get_view_definition()
 
     def test_get_view_definition_with_schema(self):
-        self._test_get_view_definition(schemaname=getSchema())
+        self._test_get_view_definition(schema=getSchema())
+
+    def _test_get_table_oid(self, table_name, schema=None):
+        if testing.against('postgres'):
+            meta = MetaData(testing.db)
+            (users, addresses) = createTables(meta, schema)
+            meta.create_all()
+            try:
+                insp = create_inspector(meta.bind)
+                oid = insp.get_table_oid(table_name, schema)
+                self.assert_(isinstance(oid, int))
+            finally:
+                addresses.drop()
+                users.drop()
+
+    def test_get_table_oid(self):
+        self._test_get_table_oid('users')
+
+    def test_get_table_oid_with_schema(self):
+        self._test_get_table_oid('users', schema=getSchema())
 
 if __name__ == "__main__":
     testenv.main()