From 0aef2c952c906fd533cbb5973e90ca10edef9320 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Mon, 2 Mar 2009 05:46:16 +0000 Subject: [PATCH] dialects can subclass Inspector --- lib/sqlalchemy/dialects/mssql/base.py | 23 +- lib/sqlalchemy/dialects/oracle/base.py | 144 ++++---- lib/sqlalchemy/dialects/postgres/base.py | 112 ++++--- lib/sqlalchemy/dialects/sqlite/base.py | 8 +- lib/sqlalchemy/engine/base.py | 5 + lib/sqlalchemy/engine/reflection.py | 403 ++++++++--------------- test/reflection.py | 115 ++++--- 7 files changed, 349 insertions(+), 461 deletions(-) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index b45e7cd5aa..f6bb9ad850 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1044,12 +1044,6 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): return value -class MSInfoCache(reflection.DefaultInfoCache): - - def __init__(self, *args, **kwargs): - reflection.DefaultInfoCache.__init__(self, *args, **kwargs) - - class MSDialect(default.DefaultDialect): name = 'mssql' supports_default_values = True @@ -1071,7 +1065,6 @@ class MSDialect(default.DefaultDialect): ddl_compiler = MSDDLCompiler type_compiler = MSTypeCompiler preparer = MSIdentifierPreparer - info_cache = MSInfoCache def __init__(self, auto_identity_insert=True, query_timeout=None, @@ -1147,7 +1140,7 @@ class MSDialect(default.DefaultDialect): row = c.fetchone() return row is not None - @reflection.caches + @reflection.cache def get_schema_names(self, connection, info_cache=None): import sqlalchemy.dialects.information_schema as ischema s = sql.select([self.uppercase_table(ischema.schemata).c.schema_name], @@ -1156,7 +1149,7 @@ class MSDialect(default.DefaultDialect): schema_names = [r[0] for r in connection.execute(s)] return schema_names - @reflection.caches + @reflection.cache def get_table_names(self, connection, schemaname, info_cache=None): import sqlalchemy.dialects.information_schema as ischema current_schema = schemaname or self.get_default_schema_name(connection) @@ -1171,7 +1164,7 @@ class MSDialect(default.DefaultDialect): table_names = [r[0] for r in connection.execute(s)] return table_names - @reflection.caches + @reflection.cache def get_view_names(self, connection, schemaname=None, info_cache=None): import sqlalchemy.dialects.information_schema as ischema current_schema = schemaname or self.get_default_schema_name(connection) @@ -1186,7 +1179,7 @@ class MSDialect(default.DefaultDialect): view_names = [r[0] for r in connection.execute(s)] return view_names - @reflection.caches + @reflection.cache def get_indexes(self, connection, tablename, schemaname=None, info_cache=None): current_schema = schemaname or self.get_default_schema_name(connection) @@ -1203,7 +1196,7 @@ class MSDialect(default.DefaultDialect): }) return indexes - @reflection.caches + @reflection.cache def get_view_definition(self, connection, viewname, schemaname=None, info_cache=None): import sqlalchemy.dialects.information_schema as ischema @@ -1220,7 +1213,7 @@ class MSDialect(default.DefaultDialect): view_def = rp.scalar() return view_def - @reflection.caches + @reflection.cache def get_columns(self, connection, tablename, schemaname=None, info_cache=None): # Get base columns @@ -1281,7 +1274,7 @@ class MSDialect(default.DefaultDialect): cols.append(cdict) return cols - @reflection.caches + @reflection.cache def get_primary_keys(self, connection, tablename, schemaname=None, info_cache=None): import sqlalchemy.dialects.information_schema as 
ischema @@ -1305,7 +1298,7 @@ class MSDialect(default.DefaultDialect): pkeys.append(row[0]) return pkeys - @reflection.caches + @reflection.cache def get_foreign_keys(self, connection, tablename, schemaname=None, info_cache=None): import sqlalchemy.dialects.information_schema as ischema diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 6fc87e1adb..2c4f326e8b 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -74,7 +74,8 @@ is not in use this flag should be left off. import datetime, random, re -from sqlalchemy import util, sql, schema, log +from sqlalchemy import schema as sa_schema +from sqlalchemy import util, sql, log from sqlalchemy.engine import default, base, reflection from sqlalchemy.sql import compiler, visitors, expression from sqlalchemy.sql import operators as sql_operators, functions as sql_functions @@ -447,9 +448,6 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer): name = re.sub(r'^_+', '', savepoint.ident) return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) -class OracleInfoCache(reflection.DefaultInfoCache): - pass - class OracleDialect(default.DefaultDialect): name = 'oracle' supports_alter = True @@ -474,7 +472,6 @@ class OracleDialect(default.DefaultDialect): type_compiler = OracleTypeCompiler preparer = OracleIdentifierPreparer defaultrunner = OracleDefaultRunner - info_cache = OracleInfoCache def __init__(self, use_ansi=True, @@ -568,50 +565,50 @@ class OracleDialect(default.DefaultDialect): else: return None, None, None, None - def _prepare_reflection_args(self, connection, tablename, schemaname=None, + def _prepare_reflection_args(self, connection, table_name, schema=None, resolve_synonyms=False, dblink=''): if resolve_synonyms: - actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(schemaname), desired_synonym=self._denormalize_name(tablename)) + actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(schema), desired_synonym=self._denormalize_name(table_name)) else: actual_name, owner, dblink, synonym = None, None, None, None if not actual_name: - actual_name = self._denormalize_name(tablename) + actual_name = self._denormalize_name(table_name) if not dblink: dblink = '' if not owner: - owner = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) + owner = self._denormalize_name(schema or self.get_default_schema_name(connection)) return (actual_name, owner, dblink, synonym) - @reflection.caches - def get_schema_names(self, connection, info_cache=None): + @reflection.cache + def get_schema_names(self, connection, **kw): s = "SELECT username FROM all_users ORDER BY username" cursor = connection.execute(s,) return [self._normalize_name(row[0]) for row in cursor] - @reflection.caches - def get_table_names(self, connection, schemaname=None, info_cache=None): - schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) - return self.table_names(connection, schemaname) + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + schema = self._denormalize_name(schema or self.get_default_schema_name(connection)) + return self.table_names(connection, schema) - @reflection.caches - def get_view_names(self, connection, schemaname=None, info_cache=None): - schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) + 
@reflection.cache + def get_view_names(self, connection, schema=None, **kw): + schema = self._denormalize_name(schema or self.get_default_schema_name(connection)) s = "select view_name from all_views where OWNER = :owner" cursor = connection.execute(s, - {'owner':self._denormalize_name(schemaname)}) + {'owner':self._denormalize_name(schema)}) return [self._normalize_name(row[0]) for row in cursor] - @reflection.caches - def get_columns(self, connection, tablename, schemaname=None, - info_cache=None, resolve_synonyms=False, dblink=''): + @reflection.cache + def get_columns(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): - (tablename, schemaname, dblink, synonym) = \ - self._prepare_reflection_args(connection, tablename, schemaname, + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, resolve_synonyms, dblink) columns = [] - c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':tablename, 'owner':schemaname}) + c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':table_name, 'owner':schema}) while True: row = c.fetchone() @@ -645,7 +642,7 @@ class OracleDialect(default.DefaultDialect): colargs = [] if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) + colargs.append(sa_schema.DefaultClause(sql.text(default))) cdict = { 'name': colname, 'type': coltype, @@ -656,13 +653,13 @@ class OracleDialect(default.DefaultDialect): columns.append(cdict) return columns - @reflection.caches - def get_indexes(self, connection, tablename, schemaname=None, - info_cache=None, resolve_synonyms=False, dblink=''): + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): - (tablename, schemaname, dblink, synonym) = \ - self._prepare_reflection_args(connection, tablename, schemaname, + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, resolve_synonyms, dblink) indexes = [] q = """ @@ -672,17 +669,18 @@ class OracleDialect(default.DefaultDialect): ON a.INDEX_NAME = b.INDEX_NAME AND a.TABLE_OWNER = b.TABLE_OWNER AND a.TABLE_NAME = b.TABLE_NAME - WHERE a.TABLE_NAME = :tablename - AND a.TABLE_OWNER = :schemaname + WHERE a.TABLE_NAME = :table_name + AND a.TABLE_OWNER = :schema ORDER BY a.INDEX_NAME, a.COLUMN_POSITION """ % dict(dblink=dblink) rp = connection.execute(q, - dict(tablename=self._denormalize_name(tablename), - schemaname=self._denormalize_name(schemaname))) + dict(table_name=self._denormalize_name(table_name), + schema=self._denormalize_name(schema))) indexes = [] last_index_name = None - pkeys = self.get_primary_keys(connection, tablename, schemaname, - info_cache, resolve_synonyms, dblink) + pkeys = self.get_primary_keys(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) uniqueness = dict(NONUNIQUE=False, UNIQUE=True) for rset in rp: # don't include the primary key columns @@ -696,8 +694,9 @@ class OracleDialect(default.DefaultDialect): last_index_name = rset.index_name return indexes - def _get_constraint_data(self, connection, tablename, schemaname=None, - 
info_cache=None, dblink=''): + @reflection.cache + def _get_constraint_data(self, connection, table_name, schema=None, + dblink='', **kw): rp = connection.execute("""SELECT ac.constraint_name, @@ -718,19 +717,20 @@ class OracleDialect(default.DefaultDialect): AND ac.r_constraint_name = rem.constraint_name(+) -- order multiple primary keys correctly ORDER BY ac.constraint_name, loc.position, rem.position""" - % {'dblink':dblink}, {'table_name' : tablename, 'owner' : schemaname}) + % {'dblink':dblink}, {'table_name' : table_name, 'owner' : schema}) constraint_data = rp.fetchall() return constraint_data - @reflection.caches - def get_primary_keys(self, connection, tablename, schemaname=None, - info_cache=None, resolve_synonyms=False, dblink=''): - (tablename, schemaname, dblink, synonym) = \ - self._prepare_reflection_args(connection, tablename, schemaname, + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, resolve_synonyms, dblink) pkeys = [] - constraint_data = self._get_constraint_data(connection, tablename, - schemaname, info_cache, dblink) + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) for row in constraint_data: #print "ROW:" , row (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) @@ -738,15 +738,16 @@ class OracleDialect(default.DefaultDialect): pkeys.append(local_column) return pkeys - @reflection.caches - def get_foreign_keys(self, connection, tablename, schemaname=None, - info_cache=None, resolve_synonyms=False, dblink=''): - (tablename, schemaname, dblink, synonym) = \ - self._prepare_reflection_args(connection, tablename, schemaname, + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, resolve_synonyms, dblink) - constraint_data = self._get_constraint_data(connection, tablename, - schemaname, info_cache, dblink) + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) fkeys = [] fks = {} for row in constraint_data: @@ -786,26 +787,26 @@ class OracleDialect(default.DefaultDialect): fkeys.append(fkey_d) return fkeys - @reflection.caches - def get_view_definition(self, connection, viewname, schemaname=None, - info_cache=None, resolve_synonyms=False, dblink=''): - (viewname, schemaname, dblink, synonym) = \ - self._prepare_reflection_args(connection, viewname, schemaname, + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + (view_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, view_name, schema, resolve_synonyms, dblink) s = """ SELECT text FROM all_views - WHERE owner = :schemaname - AND view_name = :viewname + WHERE owner = :schema + AND view_name = :view_name """ rp = connection.execute(sql.text(s), - viewname=viewname, schemaname=schemaname) + view_name=view_name, schema=schema) if rp: view_def = rp.scalar().decode(self.encoding) return view_def def reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer - info_cache = OracleInfoCache() + 
info_cache = {} resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False) @@ -814,8 +815,8 @@ class OracleDialect(default.DefaultDialect): resolve_synonyms) # columns - columns = self.get_columns(connection, actual_name, owner, info_cache, - dblink) + columns = self.get_columns(connection, actual_name, owner, dblink, + info_cache=info_cache) for cdict in columns: colname = cdict['name'] coltype = cdict['type'] @@ -823,14 +824,14 @@ class OracleDialect(default.DefaultDialect): colargs = cdict['attrs'] if include_columns and colname not in include_columns: continue - table.append_column(schema.Column(colname, coltype, + table.append_column(sa_schema.Column(colname, coltype, nullable=nullable, *colargs)) if not table.columns: raise AssertionError("Couldn't find any column information for table %s" % actual_name) # primary keys for pkcol in self.get_primary_keys(connection, actual_name, owner, - info_cache, dblink): + dblink, info_cache=info_cache): if pkcol in table.c: table.primary_key.add(table.c[pkcol]) @@ -838,7 +839,8 @@ class OracleDialect(default.DefaultDialect): fks = {} fkeys = [] fkeys = self.get_foreign_keys(connection, actual_name, owner, - info_cache, resolve_synonyms, dblink) + resolve_synonyms, dblink, + info_cache=info_cache) refspecs = [] for fkey_d in fkeys: conname = fkey_d['name'] @@ -848,17 +850,17 @@ class OracleDialect(default.DefaultDialect): referred_columns = fkey_d['referred_columns'] for (i, ref_col) in enumerate(referred_columns): if not table.schema and self._denormalize_name(referred_schema) == self._denormalize_name(owner): - t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) + t = sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) refspec = ".".join([referred_table, ref_col]) else: refspec = '.'.join([x for x in [referred_schema, referred_table, ref_col] if x is not None]) - t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) + t = sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) refspecs.append(refspec) table.append_constraint( - schema.ForeignKeyConstraint(constrained_columns, refspecs, + sa_schema.ForeignKeyConstraint(constrained_columns, refspecs, name=conname, link_to_name=True)) diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index cd64a3c648..0101228762 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -66,6 +66,7 @@ option to the Index constructor:: import re +from sqlalchemy import schema as sa_schema from sqlalchemy import sql, schema, exc, util from sqlalchemy.engine import base, default, reflection from sqlalchemy.sql import compiler, expression @@ -397,6 +398,18 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): value = value[1:-1].replace('""','"') return value +class PGInspector(reflection.Inspector): + + def __init__(self, conn): + reflection.Inspector.__init__(self, conn) + + def get_table_oid(self, table_name, schema=None): + """Return the oid from `table_name` and `schema`.""" + + return self.dialect.get_table_oid(self.conn, table_name, schema, + info_cache=self.info_cache) + + class 
PGDialect(default.DefaultDialect): name = 'postgres' supports_alter = True @@ -417,6 +430,7 @@ class PGDialect(default.DefaultDialect): type_compiler = PGTypeCompiler preparer = PGIdentifierPreparer defaultrunner = PGDefaultRunner + inspector = PGInspector def do_begin_twophase(self, connection, xid): @@ -500,8 +514,8 @@ class PGDialect(default.DefaultDialect): return tuple([int(x) for x in m.group(1, 2, 3)]) @reflection.cache - def get_table_oid(self, connection, tablename, schemaname=None, **kw): - """Fetch the oid for schemaname.tablename. + def get_table_oid(self, connection, table_name, schema=None, **kw): + """Fetch the oid for schema.table_name. Several reflection methods require the table oid. The idea for using this method is that it can be fetched one time and cached for @@ -509,7 +523,7 @@ class PGDialect(default.DefaultDialect): """ table_oid = None - if schemaname is not None: + if schema is not None: schema_where_clause = "n.nspname = :schema" else: schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" @@ -520,21 +534,21 @@ class PGDialect(default.DefaultDialect): WHERE (%s) AND c.relname = :table_name AND c.relkind in ('r','v') """ % schema_where_clause - # Since we're binding to unicode, tablename and schemaname must be + # Since we're binding to unicode, table_name and schema_name must be # unicode. - tablename = unicode(tablename) - if schemaname is not None: - schemaname = unicode(schemaname) + table_name = unicode(table_name) + if schema is not None: + schema = unicode(schema) s = sql.text(query, bindparams=[ sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode) ], typemap={'oid':sqltypes.Integer} ) - c = connection.execute(s, table_name=tablename, schema=schemaname) + c = connection.execute(s, table_name=table_name, schema=schema) table_oid = c.scalar() if table_oid is None: - raise exc.NoSuchTableError(tablename) + raise exc.NoSuchTableError(table_name) return table_oid @reflection.cache @@ -551,18 +565,18 @@ class PGDialect(default.DefaultDialect): return schema_names @reflection.cache - def get_table_names(self, connection, schemaname=None, **kw): - if schemaname is not None: - current_schema = schemaname + def get_table_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema else: current_schema = self.get_default_schema_name(connection) table_names = self.table_names(connection, current_schema) return table_names @reflection.cache - def get_view_names(self, connection, schemaname=None, **kw): - if schemaname is not None: - current_schema = schemaname + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema else: current_schema = self.get_default_schema_name(connection) s = """ @@ -575,26 +589,26 @@ class PGDialect(default.DefaultDialect): return view_names @reflection.cache - def get_view_definition(self, connection, viewname, schemaname=None, **kw): - if schemaname is not None: - current_schema = schemaname + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + current_schema = schema else: current_schema = self.get_default_schema_name(connection) s = """ SELECT definition FROM pg_views - WHERE schemaname = :schemaname - AND viewname = :viewname + WHERE schemaname = :schema + AND viewname = :view_name """ rp = connection.execute(sql.text(s), - viewname=viewname, schemaname=current_schema) + view_name=view_name, schema=current_schema) if rp: view_def = 
rp.scalar().decode(self.encoding) return view_def @reflection.cache - def get_columns(self, connection, tablename, schemaname=None, **kw): + def get_columns(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, tablename, schemaname, + table_oid = self.get_table_oid(connection, table_name, schema, info_cache=kw.get('info_cache')) SQL_COLS = """ SELECT a.attname, @@ -680,8 +694,8 @@ class PGDialect(default.DefaultDialect): return columns @reflection.cache - def get_primary_keys(self, connection, tablename, schemaname=None, **kw): - table_oid = self.get_table_oid(connection, tablename, schemaname, + def get_primary_keys(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, info_cache=kw.get('info_cache')) PK_SQL = """ SELECT attname FROM pg_attribute @@ -697,9 +711,9 @@ class PGDialect(default.DefaultDialect): return primary_keys @reflection.cache - def get_foreign_keys(self, connection, tablename, schemaname=None, **kw): + def get_foreign_keys(self, connection, table_name, schema=None, **kw): preparer = self.identifier_preparer - table_oid = self.get_table_oid(connection, tablename, schemaname, + table_oid = self.get_table_oid(connection, table_name, schema, info_cache=kw.get('info_cache')) FK_SQL = """ SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef @@ -717,11 +731,11 @@ class PGDialect(default.DefaultDialect): constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] if referred_schema: referred_schema = preparer._unquote_identifier(referred_schema) - elif schemaname is not None and schemaname == self.get_default_schema_name(connection): + elif schema is not None and schema == self.get_default_schema_name(connection): # no schema (i.e. its the default schema), and the table we're # reflecting has the default schema explicit, then use that. # i.e. 
try to use the user's conventions - referred_schema = schemaname + referred_schema = schema referred_table = preparer._unquote_identifier(referred_table) referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] fkey_d = { @@ -735,8 +749,8 @@ class PGDialect(default.DefaultDialect): return fkeys @reflection.cache - def get_indexes(self, connection, tablename, schemaname, **kw): - table_oid = self.get_table_oid(connection, tablename, schemaname, + def get_indexes(self, connection, table_name, schema, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, info_cache=kw.get('info_cache')) IDX_SQL = """ SELECT c.relname, i.indisunique, i.indexprs, i.indpred, @@ -778,16 +792,16 @@ class PGDialect(default.DefaultDialect): def reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer - schemaname = table.schema - tablename = table.name + schema = table.schema + table_name = table.name info_cache = {} # Py2K - if isinstance(schemaname, str): - schemaname = schemaname.decode(self.encoding) - if isinstance(tablename, str): - tablename = tablename.decode(self.encoding) + if isinstance(schema, str): + schema = schema.decode(self.encoding) + if isinstance(table_name, str): + table_name = table_name.decode(self.encoding) # end Py2K - for col_d in self.get_columns(connection, tablename, schemaname, + for col_d in self.get_columns(connection, table_name, schema, info_cache=info_cache): name = col_d['name'] coltype = col_d['type'] @@ -800,18 +814,18 @@ class PGDialect(default.DefaultDialect): match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) if match is not None: # the default is related to a Sequence - sch = schemaname + sch = schema if '.' not in match.group(2) and sch is not None: # unconditionally quote the schema name. this could # later be enhanced to obey quoting rules / "quote schema" default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) - colargs.append(schema.DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) + colargs.append(sa_schema.DefaultClause(sql.text(default))) + table.append_column(sa_schema.Column(name, coltype, nullable=nullable, *colargs)) # Now we have the table oid cached. 
- table_oid = self.get_table_oid(connection, tablename, schemaname, + table_oid = self.get_table_oid(connection, table_name, schema, info_cache=info_cache) # Primary keys - for pk in self.get_primary_keys(connection, tablename, schemaname, + for pk in self.get_primary_keys(connection, table_name, schema, info_cache=info_cache): if pk in table.c: col = table.c[pk] @@ -819,7 +833,7 @@ class PGDialect(default.DefaultDialect): if col.default is None: col.autoincrement = False # Foreign keys - fkeys = self.get_foreign_keys(connection, tablename, schemaname, + fkeys = self.get_foreign_keys(connection, table_name, schema, info_cache=info_cache) for fkey_d in fkeys: conname = fkey_d['name'] @@ -829,25 +843,25 @@ class PGDialect(default.DefaultDialect): referred_columns = fkey_d['referred_columns'] refspec = [] if referred_schema is not None: - schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, + sa_schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, autoload_with=connection) for column in referred_columns: refspec.append(".".join([referred_schema, referred_table, column])) else: - schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) + sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) for column in referred_columns: refspec.append(".".join([referred_table, column])) - table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True)) + table.append_constraint(sa_schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True)) # Indexes - indexes = self.get_indexes(connection, tablename, schemaname, + indexes = self.get_indexes(connection, table_name, schema, info_cache=info_cache) for index_d in indexes: name = index_d['name'] columns = index_d['column_names'] unique = index_d['unique'] - schema.Index(name, *[table.columns[c] for c in columns], + sa_schema.Index(name, *[table.columns[c] for c in columns], **dict(unique=unique)) def _load_domains(self, connection): diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 8b13b4dcca..ed6160e83e 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -236,9 +236,6 @@ class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): 'vacuum', 'values', 'view', 'virtual', 'when', 'where', ]) -class SQLiteInfoCache(reflection.DefaultInfoCache): - pass - class SQLiteDialect(default.DefaultDialect): name = 'sqlite' supports_alter = False @@ -254,7 +251,6 @@ class SQLiteDialect(default.DefaultDialect): preparer = SQLiteIdentifierPreparer ischema_names = ischema_names colspecs = colspecs - info_cache = SQLiteInfoCache def table_names(self, connection, schema): if schema is not None: @@ -295,7 +291,7 @@ class SQLiteDialect(default.DefaultDialect): return (row is not None) - @reflection.caches + @reflection.cache def get_columns(self, connection, tablename, schemaname=None, info_cache=None): quote = self.identifier_preparer.quote_identifier @@ -342,7 +338,7 @@ class SQLiteDialect(default.DefaultDialect): }) return columns - @reflection.caches + @reflection.cache def get_foreign_keys(self, connection, tablename, schemaname=None, info_cache=None): quote = self.identifier_preparer.quote_identifier diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 6944a52624..5225b6d4b7 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ 
-203,6 +203,11 @@ class Dialect(object): raise NotImplementedError() + def get_table_names(self, connection, schema=None): + """Return a list of table names for `schema`.""" + + raise NotImplementedError + def do_begin(self, connection): """Provide an implementation of *connection.begin()*, given a DB-API connection.""" diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index f1766ec2a6..bb22cc42c4 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -17,7 +17,6 @@ I'm still trying to decide upon conventions for both the Inspector interface as """ -import inspect import sqlalchemy from sqlalchemy import util from sqlalchemy.types import TypeEngine @@ -35,307 +34,153 @@ def cache(fn, self, con, *args, **kw): info_cache[key] = ret return ret -# keeping this around until all dialects are fixed -@util.decorator -def caches(fn, self, con, *args, **kw): - # what are we caching? - fn_name = fn.__name__ - if not fn_name.startswith('get_'): - # don't recognize this. - return fn(self, con, *args, **kw) - else: - attr_to_cache = fn_name[4:] - # The first arguments will always be self and con. - # Assuming *args and *kw will be acceptable to info_cache method. - if 'info_cache' in kw: - kw_cp = kw.copy() - info_cache = kw_cp.pop('info_cache') - methodname = "%s_%s" % ('get', attr_to_cache) - # fixme. - for bad_kw in ('dblink', 'resolve_synonyms'): - if bad_kw in kw_cp: - del kw_cp[bad_kw] - information = getattr(info_cache, methodname)(*args, **kw_cp) - if information: - return information - information = fn(self, con, *args, **kw) - if 'info_cache' in locals(): - methodname = "%s_%s" % ('set', attr_to_cache) - getattr(info_cache, methodname)(information, *args, **kw_cp) - return information - -class DefaultInfoCache(object): - """Default implementation of InfoCache - - InfoCache provides a means for dialects to cache information obtained for - reflection and a convenient interface for setting and retrieving cached - data. - - """ - - def __init__(self): - self._cache = dict(schemas={}) - self.tables_are_complete = False - self.schemas_are_complete = False - self.views_are_complete = False - - def clear(self): - """Clear the cache.""" - self._cache = dict(schemas={}) - - # schemas - - def get_schemas(self): - """Return the schemas dict.""" - return self._cache.get('schemas') +class Inspector(object): + """Performs database schema inspection. - def get_schema(self, schemaname, create=False): - """Return cached schema and optionally create it if it does not exist. + The Inspector acts as a proxy to the dialects' reflection methods and + provides higher level functions for accessing database schema information. + """ + + def __init__(self, conn): """ - schema = self._cache['schemas'].get(schemaname) - if schema is not None: - return schema - elif create: - return self.add_schema(schemaname) - return None - def add_schema(self, schemaname): - self._cache['schemas'][schemaname] = dict(tables={}, views={}) - return self.get_schema(schemaname) - - def get_schema_names(self, check_complete=True): - """Return cached schema names. - - By default, only return them if they're complete. 
+ conn + [sqlalchemy.engine.base.#Connectable] """ - if check_complete and self.schemas_are_complete: - return self.get_schemas().keys() - elif not check_complete: - return self.get_schemas().keys() + self.conn = conn + # set the engine + if hasattr(conn, 'engine'): + self.engine = conn.engine else: - return None - - def set_schema_names(self, schemanames): - for schemaname in schemanames: - self.add_schema(schemaname) - self.schemas_are_complete = True + self.engine = conn + self.dialect = self.engine.dialect + self.info_cache = {} - # tables + @classmethod + def from_engine(cls, engine): + if hasattr(engine.dialect, 'inspector'): + return engine.dialect.inspector(engine) + return Inspector(engine) - def get_table(self, tablename, schemaname=None, create=False, - table_type='table'): - """Return cached table and optionally create it if it does not exist. + def default_schema_name(self): + return self.dialect.get_default_schema_name(self.conn) + default_schema_name = property(default_schema_name) + def get_schema_names(self): + """Return all schema names. """ - cache = self._cache - schema = self.get_schema(schemaname, create=create) - if schema is None: - return None - if table_type == 'view': - table = schema['views'].get(tablename) - else: - table = schema['tables'].get(tablename) - if table is not None: - return table - elif create: - return self.add_table(tablename, schemaname, table_type=table_type) - return None + if hasattr(self.dialect, 'get_schema_names'): + return self.dialect.get_schema_names(self.conn, + info_cache=self.info_cache) + return [] - def get_table_names(self, schemaname=None, check_complete=True, - table_type='table'): - """Return cached table names. + def get_table_names(self, schema=None, order_by=None): + """Return all table names in `schema`. + schema: + Optional, retrieve names from a non-default schema. - By default, only return them if they're complete. + This should probably not return view names or maybe it should return + them with an indicator t or v. 
""" - if table_type == 'view': - complete = self.views_are_complete - else: - complete = self.tables_are_complete - if check_complete and complete: - return self.get_tables(schemaname, table_type=table_type).keys() - elif not check_complete: - return self.get_tables(schemaname, table_type=table_type).keys() + if hasattr(self.dialect, 'get_table_names'): + tnames = self.dialect.get_table_names(self.conn, + schema, + info_cache=self.info_cache) else: - return None - - def add_table(self, tablename, schemaname=None, table_type='table'): - schema = self.get_schema(schemaname, create=True) - if table_type == 'table': - schema['tables'][tablename] = dict(columns={}) - else: - schema['views'][tablename] = dict(columns={}) - return self.get_table(tablename, schemaname, table_type=table_type) - - def set_table_names(self, tablenames, schemaname=None, table_type='table'): - for tablename in tablenames: - self.add_table(tablename, schemaname, table_type) - if table_type == 'view': - self.views_are_complete = True - else: - self.tables_are_complete = True - - # views - - def get_view(self, viewname, schemaname=None, create=False): - return self.get_table(viewname, schemaname, create, 'view') - - def get_view_names(self, schemaname=None, check_complete=True): - return self.get_table_names(schemaname, check_complete, 'view') - - def add_view(self, viewname, schemaname=None): - return self.add_table(viewname, schemaname, 'view') - - def set_view_names(self, viewnames, schemaname=None): - return self.set_table_names(viewnames, schemaname, 'view') - - def get_view_definition(self, viewname, schemaname=None): - view_cache = self.get_view(viewname, schemaname) - if view_cache and 'definition' in view_cache: - return view_cache['definition'] - - def set_view_definition(self, definition, viewname, schemaname=None): - view_cache = self.get_view(viewname, schemaname, create=True) - view_cache['definition'] = definition - - # table data - - def _get_table_data(self, key, tablename, schemaname=None): - table_cache = self.get_table(tablename, schemaname) - if table_cache is not None and key in table_cache.keys(): - return table_cache[key] - - def _set_table_data(self, key, data, tablename, schemaname=None): - """Cache data for schemaname.tablename using key. - - It will create a schema and table entry in the cache if needed. + tnames = self.engine.table_names(schema) + if order_by == 'foreign_key': + ordered_tnames = tnames[:] + # Order based on foreign key dependencies. + for tname in tnames: + table_pos = tnames.index(tname) + fkeys = self.get_foreign_keys(tname, schema) + for fkey in fkeys: + rtable = fkey['referred_table'] + if rtable in ordered_tnames: + ref_pos = ordered_tnames.index(rtable) + # Make sure it's lower in the list than anything it + # references. + if table_pos > ref_pos: + ordered_tnames.pop(table_pos) # rtable moves up 1 + # insert just below rtable + ordered_tnames.index(ref_pos, tname) + tnames = ordered_tnames + return tnames + + def get_view_names(self, schema=None): + """Return all view names in `schema`. + schema: + Optional, retrieve names from a non-default schema. 
""" - table_cache = self.get_table(tablename, schemaname, create=True) - table_cache[key] = data - - # columns - - def get_columns(self, tablename, schemaname=None): - """Return columns list or None.""" - - return self._get_table_data('columns', tablename, schemaname) - - def set_columns(self, columns, tablename, schemaname=None): - """Add list of columns to table cache.""" - - return self._set_table_data('columns', columns, tablename, schemaname) - - # primary keys + return self.dialect.get_view_names(self.conn, schema, + info_cache=self.info_cache) - def get_primary_keys(self, tablename, schemaname=None): - """Return primary key list or None.""" - - return self._get_table_data('primary_keys', tablename, schemaname) + def get_view_definition(self, view_name, schema=None): + """Return definition for `view_name`. + schema: + Optional, retrieve names from a non-default schema. - def set_primary_keys(self, pkeys, tablename, schemaname=None): - """Add list of primary keys to table cache.""" + """ + return self.dialect.get_view_definition( + self.conn, view_name, schema, info_cache=self.info_cache) - return self._set_table_data('primary_keys', pkeys, tablename, schemaname) + def get_columns(self, table_name, schema=None): + """Return information about columns in `table_name`. - # foreign keys + Given a string `table_name` and an optional string `schema`, return + column information as a list of dicts with these keys: - def get_foreign_keys(self, tablename, schemaname=None): - """Return foreign key list or None.""" - - return self._get_table_data('foreign_keys', tablename, schemaname) + name + the column's name - def set_foreign_keys(self, fkeys, tablename, schemaname=None): - """Add list of foreign keys to table cache.""" + type + [sqlalchemy.types#TypeEngine] - return self._set_table_data('foreign_keys', fkeys, tablename, schemaname) + nullable + boolean - # indexes + default + the column's default value - def get_indexes(self, tablename, schemaname=None): - """Return indexes list or None.""" - - return self._get_table_data('indexes', tablename, schemaname) + attrs + dict containing optional column attributes - def set_indexes(self, indexes, tablename, schemaname=None): - """Add list of indexes to table cache.""" + """ - return self._set_table_data('indexes', indexes, tablename, schemaname) + col_defs = self.dialect.get_columns(self.conn, table_name, + schema, + info_cache=self.info_cache) + for col_def in col_defs: + # make this easy and only return instances for coltype + coltype = col_def['type'] + if not isinstance(coltype, TypeEngine): + col_def['type'] = coltype() + return col_defs -class Inspector(object): - """Performs database introspection. + def get_primary_keys(self, table_name, schema=None): + """Return information about primary keys in `table_name`. - The Inspector acts as a proxy to the dialects' reflection methods and - provides higher level functions for accessing database schema information. + Given a string `table_name`, and an optional string `schema`, return + primary key information as a list of column names. - """ - - def __init__(self, conn): """ - conn - [sqlalchemy.engine.base.#Connectable] - - Upon initialization, new members are added corresponding to the - refection members of the current dialect. 
+ pkeys = self.dialect.get_primary_keys(self.conn, table_name, + schema, + info_cache=self.info_cache) - Dev Notes: - - I used attribute assignment rather than __getattr__ because - I want the Inspector to be inspectable including providing proper - documentation strings for the methods is supports. - - The primary reason for this approach: + return pkeys - 1. DRY. - 2. Provides access to dialect specific reflection methods. + def get_foreign_keys(self, table_name, schema=None): + """Return information about foreign_keys in `table_name`. - """ - self.conn = conn - # set the engine - if hasattr(conn, 'engine'): - self.engine = conn.engine - else: - self.engine = conn - # fixme. This is just until all dialects are converted - if hasattr(self, 'info_cache'): - self.info_cache = self.engine.dialect.info_cache() - else: - self.info_cache = {} - # add methods from dialect - def filter_reflect_members(m): - if inspect.ismethod(m) and m.__name__.startswith('get_'): - argspec = inspect.getargspec(m) - if isinstance(argspec, tuple) and 'connection' in argspec[0]: - return True - return False - reflection_members = inspect.getmembers(self.engine.dialect, - filter_reflect_members) - def wrap_reflection_method(fn): - def decorated(*args, **kwargs): - args = (self.conn,) + args - kwargs['info_cache'] = self.info_cache - return fn(*args, **kwargs) - return decorated - for (member_name, member) in reflection_members: - if not hasattr(self, member_name): - doc = "This method mirrors the dialect method %s." % member_name - wrapped_member = wrap_reflection_method(member) - wrapped_member.__doc__ = "%s\n\n%s" % (doc, member.__doc__) - setattr(self, member_name, wrapped_member) - - @property - def default_schema_name(self): - return self.engine.dialect.get_default_schema_name(self.conn) - - def get_foreign_keys(self, tablename, schemaname=None): - """Return information about foreign_keys in `tablename`. - - Given a string `tablename`, and an optional string `schemaname`, return + Given a string `table_name`, and an optional string `schema`, return foreign key information as a list of dicts with these keys: constrained_columns @@ -353,24 +198,36 @@ class Inspector(object): """ - fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename, - schemaname, + fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, + schema, info_cache=self.info_cache) for fk_def in fk_defs: referred_schema = fk_def['referred_schema'] # always set the referred_schema. - if referred_schema is None and schemaname is None: - referred_schema = self.engine.dialect.get_default_schema_name( + if referred_schema is None and schema is None: + referred_schema = self.dialect.get_default_schema_name( self.conn) fk_def['referred_schema'] = referred_schema return fk_defs - def get_relation_map(self, schemaname=None): - """Provide a mapping of the relations between all tables in schemaname. + def get_indexes(self, table_name, schema=None): + """Return information about indexes in `table_name`. - This is an example of a higher level function where Inspector can be - very useful. 
+ Given a string `table_name` and an optional string `schema`, return + index information as a list of dicts with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean """ - #todo - pass + + indexes = self.dialect.get_indexes(self.conn, table_name, + schema, + info_cache=self.info_cache) + return indexes diff --git a/test/reflection.py b/test/reflection.py index 23e5befd31..999f0e2c22 100644 --- a/test/reflection.py +++ b/test/reflection.py @@ -9,6 +9,8 @@ from sqlalchemy.engine.reflection import Inspector from testlib.sa import MetaData, Table, Column from testlib import TestBase, testing, engines +create_inspector = Inspector.from_engine + if 'set' not in dir(__builtins__): from sets import Set as set @@ -65,20 +67,20 @@ def createIndexes(con, schema=None): con.execute(sa.sql.text(query)) def createViews(con, schema=None): - for tablename in ('users', 'email_addresses'): - fullname = tablename + for table_name in ('users', 'email_addresses'): + fullname = table_name if schema: - fullname = "%s.%s" % (schema, tablename) + fullname = "%s.%s" % (schema, table_name) view_name = fullname + '_v' query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name, fullname) con.execute(sa.sql.text(query)) def dropViews(con, schema=None): - for tablename in ('email_addresses', 'users'): - fullname = tablename + for table_name in ('email_addresses', 'users'): + fullname = table_name if schema: - fullname = "%s.%s" % (schema, tablename) + fullname = "%s.%s" % (schema, table_name) view_name = fullname + '_v' query = "DROP VIEW %s" % view_name con.execute(sa.sql.text(query)) @@ -91,20 +93,20 @@ class ReflectionTest(TestBase): insp = Inspector(meta.bind) self.assert_(getSchema() in insp.get_schema_names()) - def _test_get_table_names(self, schemaname=None, table_type='table', + def _test_get_table_names(self, schema=None, table_type='table', order_by=None): meta = MetaData(testing.db) - (users, addresses) = createTables(meta, schemaname) + (users, addresses) = createTables(meta, schema) meta.create_all() - createViews(meta.bind, schemaname) + createViews(meta.bind, schema) try: insp = Inspector(meta.bind) if table_type == 'view': - table_names = insp.get_view_names(schemaname) + table_names = insp.get_view_names(schema) table_names.sort() answer = ['email_addresses_v', 'users_v'] else: - table_names = insp.get_table_names(schemaname, + table_names = insp.get_table_names(schema, order_by=order_by) table_names.sort() if order_by == 'foreign_key': @@ -113,7 +115,7 @@ class ReflectionTest(TestBase): answer = ['email_addresses', 'users'] self.assertEqual(table_names, answer) finally: - dropViews(meta.bind, schemaname) + dropViews(meta.bind, schema) addresses.drop() users.drop() @@ -132,21 +134,21 @@ class ReflectionTest(TestBase): def test_get_view_names_with_schema(self): self._test_get_table_names(getSchema(), table_type='view') - def _test_get_columns(self, schemaname=None, table_type='table'): + def _test_get_columns(self, schema=None, table_type='table'): meta = MetaData(testing.db) - (users, addresses) = createTables(meta, schemaname) + (users, addresses) = createTables(meta, schema) table_names = ['users', 'email_addresses'] meta.create_all() if table_type == 'view': - createViews(meta.bind, schemaname) + createViews(meta.bind, schema) table_names = ['users_v', 'email_addresses_v'] try: insp = Inspector(meta.bind) - for (tablename, table) in zip(table_names, (users, addresses)): - schema_name = schemaname - if schemaname and testing.against('oracle'): - 
schema_name = schemaname.upper() - cols = insp.get_columns(tablename, schemaname=schema_name) + for (table_name, table) in zip(table_names, (users, addresses)): + schema_name = schema + if schema and testing.against('oracle'): + schema_name = schema.upper() + cols = insp.get_columns(table_name, schema=schema_name) self.assert_(len(cols) > 0, len(cols)) # should be in order for (i, col) in enumerate(table.columns): @@ -172,7 +174,7 @@ class ReflectionTest(TestBase): ctype))) finally: if table_type == 'view': - dropViews(meta.bind, schemaname) + dropViews(meta.bind, schema) addresses.drop() users.drop() @@ -180,25 +182,25 @@ class ReflectionTest(TestBase): self._test_get_columns() def test_get_columns_with_schema(self): - self._test_get_columns(schemaname=getSchema()) + self._test_get_columns(schema=getSchema()) def test_get_view_columns(self): self._test_get_columns(table_type='view') def test_get_view_columns_with_schema(self): - self._test_get_columns(schemaname=getSchema(), table_type='view') + self._test_get_columns(schema=getSchema(), table_type='view') - def _test_get_primary_keys(self, schemaname=None): + def _test_get_primary_keys(self, schema=None): meta = MetaData(testing.db) - (users, addresses) = createTables(meta, schemaname) + (users, addresses) = createTables(meta, schema) meta.create_all() insp = Inspector(meta.bind) try: users_pkeys = insp.get_primary_keys(users.name, - schemaname=schemaname) + schema=schema) self.assertEqual(users_pkeys, ['user_id']) addr_pkeys = insp.get_primary_keys(addresses.name, - schemaname=schemaname) + schema=schema) self.assertEqual(addr_pkeys, ['address_id']) finally: @@ -209,21 +211,21 @@ class ReflectionTest(TestBase): self._test_get_primary_keys() def test_get_primary_keys_with_schema(self): - self._test_get_primary_keys(schemaname=getSchema()) + self._test_get_primary_keys(schema=getSchema()) - def _test_get_foreign_keys(self, schemaname=None): + def _test_get_foreign_keys(self, schema=None): meta = MetaData(testing.db) - (users, addresses) = createTables(meta, schemaname) + (users, addresses) = createTables(meta, schema) meta.create_all() insp = Inspector(meta.bind) try: - expected_schema = schemaname - if schemaname is None: + expected_schema = schema + if schema is None: expected_schema = meta.bind.dialect.get_default_schema_name( meta.bind) # users users_fkeys = insp.get_foreign_keys(users.name, - schemaname=schemaname) + schema=schema) fkey1 = users_fkeys[0] self.assert_(fkey1['name'] is not None) self.assertEqual(fkey1['referred_schema'], expected_schema) @@ -232,7 +234,7 @@ class ReflectionTest(TestBase): self.assertEqual(fkey1['constrained_columns'], ['parent_user_id']) #addresses addr_fkeys = insp.get_foreign_keys(addresses.name, - schemaname=schemaname) + schema=schema) fkey1 = addr_fkeys[0] self.assert_(fkey1['name'] is not None) self.assertEqual(fkey1['referred_schema'], expected_schema) @@ -247,16 +249,16 @@ class ReflectionTest(TestBase): self._test_get_foreign_keys() def test_get_foreign_keys_with_schema(self): - self._test_get_foreign_keys(schemaname=getSchema()) + self._test_get_foreign_keys(schema=getSchema()) - def _test_get_indexes(self, schemaname=None): + def _test_get_indexes(self, schema=None): meta = MetaData(testing.db) - (users, addresses) = createTables(meta, schemaname) + (users, addresses) = createTables(meta, schema) meta.create_all() - createIndexes(meta.bind, schemaname) + createIndexes(meta.bind, schema) try: insp = Inspector(meta.bind) - indexes = insp.get_indexes('users', schemaname=schemaname) + indexes = 
insp.get_indexes('users', schema=schema) indexes.sort() if testing.against('oracle'): expected_indexes = [ @@ -277,23 +279,23 @@ class ReflectionTest(TestBase): self._test_get_indexes() def test_get_indexes_with_schema(self): - self._test_get_indexes(schemaname=getSchema()) + self._test_get_indexes(schema=getSchema()) - def _test_get_view_definition(self, schemaname=None): + def _test_get_view_definition(self, schema=None): meta = MetaData(testing.db) - (users, addresses) = createTables(meta, schemaname) + (users, addresses) = createTables(meta, schema) meta.create_all() - createViews(meta.bind, schemaname) + createViews(meta.bind, schema) view_name1 = 'users_v' view_name2 = 'email_addresses_v' try: insp = Inspector(meta.bind) - v1 = insp.get_view_definition(view_name1, schemaname=schemaname) + v1 = insp.get_view_definition(view_name1, schema=schema) self.assert_(v1) - v2 = insp.get_view_definition(view_name2, schemaname=schemaname) + v2 = insp.get_view_definition(view_name2, schema=schema) self.assert_(v2) finally: - dropViews(meta.bind, schemaname) + dropViews(meta.bind, schema) addresses.drop() users.drop() @@ -301,7 +303,26 @@ class ReflectionTest(TestBase): self._test_get_view_definition() def test_get_view_definition_with_schema(self): - self._test_get_view_definition(schemaname=getSchema()) + self._test_get_view_definition(schema=getSchema()) + + def _test_get_table_oid(self, table_name, schema=None): + if testing.against('postgres'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + try: + insp = create_inspector(meta.bind) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, int)) + finally: + addresses.drop() + users.drop() + + def test_get_table_oid(self): + self._test_get_table_oid('users') + + def test_get_table_oid_with_schema(self): + self._test_get_table_oid('users', schema=getSchema()) if __name__ == "__main__": testenv.main() -- 2.47.3
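
Note on the caching change: the old @reflection.caches decorator and the DefaultInfoCache class are replaced by a single @reflection.cache decorator that memoizes onto a plain dict passed through as the `info_cache` keyword argument. The decorator's full body is not visible in the hunk above, so the following is only an illustrative sketch of that behavior, not the actual implementation (the real one is built on sqlalchemy.util.decorator and its exact cache-key construction may differ):

    import functools

    def cache(fn):
        # Sketch only: memoize a dialect reflection method on the
        # 'info_cache' dict that callers (the Inspector) pass in as a
        # keyword argument.
        @functools.wraps(fn)
        def go(self, connection, *args, **kw):
            info_cache = kw.get('info_cache', None)
            if info_cache is None:
                # no cache supplied; call straight through
                return fn(self, connection, *args, **kw)
            # key on the method name plus its arguments, excluding the
            # cache dict itself
            key = (fn.__name__, args,
                   tuple(sorted((k, v) for k, v in kw.items()
                                if k != 'info_cache')))
            if key not in info_cache:
                info_cache[key] = fn(self, connection, *args, **kw)
            return info_cache[key]
        return go

Because each Inspector instance owns a single info_cache dict, repeated calls such as get_columns() and get_primary_keys() against the same table can share expensive lookups (for example the Postgres table OID) for the lifetime of that Inspector.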
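
The pattern named in the subject line, as implemented for Postgres above (PGInspector plus the PGDialect.inspector attribute), generalizes to any dialect. The names below (MyInspector, MyDialect, get_special_info) are hypothetical and only illustrate the wiring; a real dialect would query its backend's catalog tables in the dialect-level method:

    from sqlalchemy.engine import default, reflection

    class MyInspector(reflection.Inspector):
        """Dialect-specific inspector exposing an extra reflection helper."""

        def get_special_info(self, table_name, schema=None):
            # delegate to the dialect, threading the shared info_cache
            # through so repeated lookups are answered from the cache
            return self.dialect.get_special_info(self.conn, table_name,
                                                 schema,
                                                 info_cache=self.info_cache)

    class MyDialect(default.DefaultDialect):
        name = 'mydialect'
        # Inspector.from_engine() checks for this attribute and, when
        # present, instantiates the dialect's subclass instead of the
        # plain Inspector.
        inspector = MyInspector

        @reflection.cache
        def get_special_info(self, connection, table_name, schema=None, **kw):
            # hypothetical dialect-level lookup; a real implementation
            # would query the backend's system catalogs here
            return {'table': table_name, 'schema': schema}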
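
From application or test code the entry point is Inspector.from_engine(), which returns the dialect's registered subclass when one exists (PGInspector on postgres) and a plain Inspector otherwise. A brief usage sketch; the connection URL is an assumption, the 'users' table is assumed to exist as in the test suite above, and get_table_oid() is only available on the Postgres inspector:

    from sqlalchemy import create_engine
    from sqlalchemy.engine.reflection import Inspector

    # assumed URL; any database reachable from the test environment will do
    engine = create_engine('postgres://scott:tiger@localhost/test')

    insp = Inspector.from_engine(engine)

    print insp.get_schema_names()        # all schemas, if the dialect supports it
    print insp.get_table_names()         # tables in the default schema
    print insp.get_view_names()          # views in the default schema
    print insp.get_columns('users')      # dicts: name, type, nullable, default, attrs
    print insp.get_primary_keys('users') # list of column names
    print insp.get_foreign_keys('users') # dicts incl. constrained/referred columns
    print insp.get_indexes('users')      # dicts: name, column_names, unique

    # Postgres-specific helper added by PGInspector
    if hasattr(insp, 'get_table_oid'):
        print insp.get_table_oid('users')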