return value
-class MSInfoCache(reflection.DefaultInfoCache):
-
- def __init__(self, *args, **kwargs):
- reflection.DefaultInfoCache.__init__(self, *args, **kwargs)
-
-
class MSDialect(default.DefaultDialect):
name = 'mssql'
supports_default_values = True
ddl_compiler = MSDDLCompiler
type_compiler = MSTypeCompiler
preparer = MSIdentifierPreparer
- info_cache = MSInfoCache
def __init__(self,
auto_identity_insert=True, query_timeout=None,
row = c.fetchone()
return row is not None
- @reflection.caches
+ @reflection.cache
def get_schema_names(self, connection, info_cache=None):
import sqlalchemy.dialects.information_schema as ischema
s = sql.select([self.uppercase_table(ischema.schemata).c.schema_name],
schema_names = [r[0] for r in connection.execute(s)]
return schema_names
- @reflection.caches
+ @reflection.cache
def get_table_names(self, connection, schemaname, info_cache=None):
import sqlalchemy.dialects.information_schema as ischema
current_schema = schemaname or self.get_default_schema_name(connection)
table_names = [r[0] for r in connection.execute(s)]
return table_names
- @reflection.caches
+ @reflection.cache
def get_view_names(self, connection, schemaname=None, info_cache=None):
import sqlalchemy.dialects.information_schema as ischema
current_schema = schemaname or self.get_default_schema_name(connection)
view_names = [r[0] for r in connection.execute(s)]
return view_names
- @reflection.caches
+ @reflection.cache
def get_indexes(self, connection, tablename, schemaname=None,
info_cache=None):
current_schema = schemaname or self.get_default_schema_name(connection)
})
return indexes
- @reflection.caches
+ @reflection.cache
def get_view_definition(self, connection, viewname, schemaname=None,
info_cache=None):
import sqlalchemy.dialects.information_schema as ischema
view_def = rp.scalar()
return view_def
- @reflection.caches
+ @reflection.cache
def get_columns(self, connection, tablename, schemaname=None,
info_cache=None):
# Get base columns
cols.append(cdict)
return cols
- @reflection.caches
+ @reflection.cache
def get_primary_keys(self, connection, tablename, schemaname=None,
info_cache=None):
import sqlalchemy.dialects.information_schema as ischema
pkeys.append(row[0])
return pkeys
- @reflection.caches
+ @reflection.cache
def get_foreign_keys(self, connection, tablename, schemaname=None,
info_cache=None):
import sqlalchemy.dialects.information_schema as ischema
import datetime, random, re
-from sqlalchemy import util, sql, schema, log
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import util, sql, log
from sqlalchemy.engine import default, base, reflection
from sqlalchemy.sql import compiler, visitors, expression
from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
name = re.sub(r'^_+', '', savepoint.ident)
return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
-class OracleInfoCache(reflection.DefaultInfoCache):
- pass
-
class OracleDialect(default.DefaultDialect):
name = 'oracle'
supports_alter = True
type_compiler = OracleTypeCompiler
preparer = OracleIdentifierPreparer
defaultrunner = OracleDefaultRunner
- info_cache = OracleInfoCache
def __init__(self,
use_ansi=True,
else:
return None, None, None, None
- def _prepare_reflection_args(self, connection, tablename, schemaname=None,
+ def _prepare_reflection_args(self, connection, table_name, schema=None,
resolve_synonyms=False, dblink=''):
if resolve_synonyms:
- actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(schemaname), desired_synonym=self._denormalize_name(tablename))
+ actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(schema), desired_synonym=self._denormalize_name(table_name))
else:
actual_name, owner, dblink, synonym = None, None, None, None
if not actual_name:
- actual_name = self._denormalize_name(tablename)
+ actual_name = self._denormalize_name(table_name)
if not dblink:
dblink = ''
if not owner:
- owner = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
+ owner = self._denormalize_name(schema or self.get_default_schema_name(connection))
return (actual_name, owner, dblink, synonym)
- @reflection.caches
- def get_schema_names(self, connection, info_cache=None):
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
s = "SELECT username FROM all_users ORDER BY username"
cursor = connection.execute(s,)
return [self._normalize_name(row[0]) for row in cursor]
- @reflection.caches
- def get_table_names(self, connection, schemaname=None, info_cache=None):
- schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
- return self.table_names(connection, schemaname)
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ schema = self._denormalize_name(schema or self.get_default_schema_name(connection))
+ return self.table_names(connection, schema)
- @reflection.caches
- def get_view_names(self, connection, schemaname=None, info_cache=None):
- schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ schema = self._denormalize_name(schema or self.get_default_schema_name(connection))
s = "select view_name from all_views where OWNER = :owner"
cursor = connection.execute(s,
- {'owner':self._denormalize_name(schemaname)})
+ {'owner':self._denormalize_name(schema)})
return [self._normalize_name(row[0]) for row in cursor]
- @reflection.caches
- def get_columns(self, connection, tablename, schemaname=None,
- info_cache=None, resolve_synonyms=False, dblink=''):
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None,
+ resolve_synonyms=False, dblink='', **kw):
- (tablename, schemaname, dblink, synonym) = \
- self._prepare_reflection_args(connection, tablename, schemaname,
+ (table_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, table_name, schema,
resolve_synonyms, dblink)
columns = []
- c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':tablename, 'owner':schemaname})
+ c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':table_name, 'owner':schema})
while True:
row = c.fetchone()
colargs = []
if default is not None:
- colargs.append(schema.DefaultClause(sql.text(default)))
+ colargs.append(sa_schema.DefaultClause(sql.text(default)))
cdict = {
'name': colname,
'type': coltype,
columns.append(cdict)
return columns
- @reflection.caches
- def get_indexes(self, connection, tablename, schemaname=None,
- info_cache=None, resolve_synonyms=False, dblink=''):
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None,
+ resolve_synonyms=False, dblink='', **kw):
- (tablename, schemaname, dblink, synonym) = \
- self._prepare_reflection_args(connection, tablename, schemaname,
+ (table_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, table_name, schema,
resolve_synonyms, dblink)
indexes = []
q = """
ON a.INDEX_NAME = b.INDEX_NAME
AND a.TABLE_OWNER = b.TABLE_OWNER
AND a.TABLE_NAME = b.TABLE_NAME
- WHERE a.TABLE_NAME = :tablename
- AND a.TABLE_OWNER = :schemaname
+ WHERE a.TABLE_NAME = :table_name
+ AND a.TABLE_OWNER = :schema
ORDER BY a.INDEX_NAME, a.COLUMN_POSITION
""" % dict(dblink=dblink)
rp = connection.execute(q,
- dict(tablename=self._denormalize_name(tablename),
- schemaname=self._denormalize_name(schemaname)))
+ dict(table_name=self._denormalize_name(table_name),
+ schema=self._denormalize_name(schema)))
indexes = []
last_index_name = None
- pkeys = self.get_primary_keys(connection, tablename, schemaname,
- info_cache, resolve_synonyms, dblink)
+ pkeys = self.get_primary_keys(connection, table_name, schema,
+ resolve_synonyms, dblink,
+ info_cache=info_cache)
uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
for rset in rp:
# don't include the primary key columns
last_index_name = rset.index_name
return indexes
- def _get_constraint_data(self, connection, tablename, schemaname=None,
- info_cache=None, dblink=''):
+ @reflection.cache
+ def _get_constraint_data(self, connection, table_name, schema=None,
+ dblink='', **kw):
rp = connection.execute("""SELECT
ac.constraint_name,
AND ac.r_constraint_name = rem.constraint_name(+)
-- order multiple primary keys correctly
ORDER BY ac.constraint_name, loc.position, rem.position"""
- % {'dblink':dblink}, {'table_name' : tablename, 'owner' : schemaname})
+ % {'dblink':dblink}, {'table_name' : table_name, 'owner' : schema})
constraint_data = rp.fetchall()
return constraint_data
- @reflection.caches
- def get_primary_keys(self, connection, tablename, schemaname=None,
- info_cache=None, resolve_synonyms=False, dblink=''):
- (tablename, schemaname, dblink, synonym) = \
- self._prepare_reflection_args(connection, tablename, schemaname,
+ @reflection.cache
+ def get_primary_keys(self, connection, table_name, schema=None,
+ resolve_synonyms=False, dblink='', **kw):
+ (table_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, table_name, schema,
resolve_synonyms, dblink)
pkeys = []
- constraint_data = self._get_constraint_data(connection, tablename,
- schemaname, info_cache, dblink)
+ constraint_data = self._get_constraint_data(connection, table_name,
+ schema, dblink,
+ info_cache=kw.get('info_cache'))
for row in constraint_data:
#print "ROW:" , row
(cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
pkeys.append(local_column)
return pkeys
- @reflection.caches
- def get_foreign_keys(self, connection, tablename, schemaname=None,
- info_cache=None, resolve_synonyms=False, dblink=''):
- (tablename, schemaname, dblink, synonym) = \
- self._prepare_reflection_args(connection, tablename, schemaname,
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None,
+ resolve_synonyms=False, dblink='', **kw):
+ (table_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, table_name, schema,
resolve_synonyms, dblink)
- constraint_data = self._get_constraint_data(connection, tablename,
- schemaname, info_cache, dblink)
+ constraint_data = self._get_constraint_data(connection, table_name,
+ schema, dblink,
+ info_cache=kw.get('info_cache'))
fkeys = []
fks = {}
for row in constraint_data:
fkeys.append(fkey_d)
return fkeys
- @reflection.caches
- def get_view_definition(self, connection, viewname, schemaname=None,
- info_cache=None, resolve_synonyms=False, dblink=''):
- (viewname, schemaname, dblink, synonym) = \
- self._prepare_reflection_args(connection, viewname, schemaname,
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None,
+ resolve_synonyms=False, dblink='', **kw):
+ (view_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, view_name, schema,
resolve_synonyms, dblink)
s = """
SELECT text FROM all_views
- WHERE owner = :schemaname
- AND view_name = :viewname
+ WHERE owner = :schema
+ AND view_name = :view_name
"""
rp = connection.execute(sql.text(s),
- viewname=viewname, schemaname=schemaname)
+ view_name=view_name, schema=schema)
if rp:
view_def = rp.scalar().decode(self.encoding)
return view_def
def reflecttable(self, connection, table, include_columns):
preparer = self.identifier_preparer
- info_cache = OracleInfoCache()
+ info_cache = {}
resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
resolve_synonyms)
# columns
- columns = self.get_columns(connection, actual_name, owner, info_cache,
- dblink)
+ columns = self.get_columns(connection, actual_name, owner, dblink,
+ info_cache=info_cache)
for cdict in columns:
colname = cdict['name']
coltype = cdict['type']
colargs = cdict['attrs']
if include_columns and colname not in include_columns:
continue
- table.append_column(schema.Column(colname, coltype,
+ table.append_column(sa_schema.Column(colname, coltype,
nullable=nullable, *colargs))
if not table.columns:
raise AssertionError("Couldn't find any column information for table %s" % actual_name)
# primary keys
for pkcol in self.get_primary_keys(connection, actual_name, owner,
- info_cache, dblink):
+ dblink, info_cache=info_cache):
if pkcol in table.c:
table.primary_key.add(table.c[pkcol])
fks = {}
fkeys = []
fkeys = self.get_foreign_keys(connection, actual_name, owner,
- info_cache, resolve_synonyms, dblink)
+ resolve_synonyms, dblink,
+ info_cache=info_cache)
refspecs = []
for fkey_d in fkeys:
conname = fkey_d['name']
referred_columns = fkey_d['referred_columns']
for (i, ref_col) in enumerate(referred_columns):
if not table.schema and self._denormalize_name(referred_schema) == self._denormalize_name(owner):
- t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
+ t = sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
refspec = ".".join([referred_table, ref_col])
else:
refspec = '.'.join([x for x in [referred_schema,
referred_table, ref_col] if x is not None])
- t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
+ t = sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
refspecs.append(refspec)
table.append_constraint(
- schema.ForeignKeyConstraint(constrained_columns, refspecs,
+ sa_schema.ForeignKeyConstraint(constrained_columns, refspecs,
name=conname, link_to_name=True))
import re
+from sqlalchemy import schema as sa_schema
from sqlalchemy import sql, schema, exc, util
from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler, expression
value = value[1:-1].replace('""','"')
return value
+class PGInspector(reflection.Inspector):
+
+ def __init__(self, conn):
+ reflection.Inspector.__init__(self, conn)
+
+ def get_table_oid(self, table_name, schema=None):
+ """Return the oid from `table_name` and `schema`."""
+
+ return self.dialect.get_table_oid(self.conn, table_name, schema,
+ info_cache=self.info_cache)
+
+
class PGDialect(default.DefaultDialect):
name = 'postgres'
supports_alter = True
type_compiler = PGTypeCompiler
preparer = PGIdentifierPreparer
defaultrunner = PGDefaultRunner
+ inspector = PGInspector
def do_begin_twophase(self, connection, xid):
return tuple([int(x) for x in m.group(1, 2, 3)])
@reflection.cache
- def get_table_oid(self, connection, tablename, schemaname=None, **kw):
- """Fetch the oid for schemaname.tablename.
+ def get_table_oid(self, connection, table_name, schema=None, **kw):
+ """Fetch the oid for schema.table_name.
Several reflection methods require the table oid. The idea for using
this method is that it can be fetched one time and cached for
"""
table_oid = None
- if schemaname is not None:
+ if schema is not None:
schema_where_clause = "n.nspname = :schema"
else:
schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
WHERE (%s)
AND c.relname = :table_name AND c.relkind in ('r','v')
""" % schema_where_clause
- # Since we're binding to unicode, tablename and schemaname must be
+ # Since we're binding to unicode, table_name and schema_name must be
# unicode.
- tablename = unicode(tablename)
- if schemaname is not None:
- schemaname = unicode(schemaname)
+ table_name = unicode(table_name)
+ if schema is not None:
+ schema = unicode(schema)
s = sql.text(query, bindparams=[
sql.bindparam('table_name', type_=sqltypes.Unicode),
sql.bindparam('schema', type_=sqltypes.Unicode)
],
typemap={'oid':sqltypes.Integer}
)
- c = connection.execute(s, table_name=tablename, schema=schemaname)
+ c = connection.execute(s, table_name=table_name, schema=schema)
table_oid = c.scalar()
if table_oid is None:
- raise exc.NoSuchTableError(tablename)
+ raise exc.NoSuchTableError(table_name)
return table_oid
@reflection.cache
return schema_names
@reflection.cache
- def get_table_names(self, connection, schemaname=None, **kw):
- if schemaname is not None:
- current_schema = schemaname
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
else:
current_schema = self.get_default_schema_name(connection)
table_names = self.table_names(connection, current_schema)
return table_names
@reflection.cache
- def get_view_names(self, connection, schemaname=None, **kw):
- if schemaname is not None:
- current_schema = schemaname
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
else:
current_schema = self.get_default_schema_name(connection)
s = """
return view_names
@reflection.cache
- def get_view_definition(self, connection, viewname, schemaname=None, **kw):
- if schemaname is not None:
- current_schema = schemaname
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
else:
current_schema = self.get_default_schema_name(connection)
s = """
SELECT definition FROM pg_views
- WHERE schemaname = :schemaname
- AND viewname = :viewname
+ WHERE schemaname = :schema
+ AND viewname = :view_name
"""
rp = connection.execute(sql.text(s),
- viewname=viewname, schemaname=current_schema)
+ view_name=view_name, schema=current_schema)
if rp:
view_def = rp.scalar().decode(self.encoding)
return view_def
@reflection.cache
- def get_columns(self, connection, tablename, schemaname=None, **kw):
+ def get_columns(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(connection, tablename, schemaname,
+ table_oid = self.get_table_oid(connection, table_name, schema,
info_cache=kw.get('info_cache'))
SQL_COLS = """
SELECT a.attname,
return columns
@reflection.cache
- def get_primary_keys(self, connection, tablename, schemaname=None, **kw):
- table_oid = self.get_table_oid(connection, tablename, schemaname,
+ def get_primary_keys(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(connection, table_name, schema,
info_cache=kw.get('info_cache'))
PK_SQL = """
SELECT attname FROM pg_attribute
return primary_keys
@reflection.cache
- def get_foreign_keys(self, connection, tablename, schemaname=None, **kw):
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
preparer = self.identifier_preparer
- table_oid = self.get_table_oid(connection, tablename, schemaname,
+ table_oid = self.get_table_oid(connection, table_name, schema,
info_cache=kw.get('info_cache'))
FK_SQL = """
SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
if referred_schema:
referred_schema = preparer._unquote_identifier(referred_schema)
- elif schemaname is not None and schemaname == self.get_default_schema_name(connection):
+ elif schema is not None and schema == self.get_default_schema_name(connection):
# no schema (i.e. its the default schema), and the table we're
# reflecting has the default schema explicit, then use that.
# i.e. try to use the user's conventions
- referred_schema = schemaname
+ referred_schema = schema
referred_table = preparer._unquote_identifier(referred_table)
referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]
fkey_d = {
return fkeys
@reflection.cache
- def get_indexes(self, connection, tablename, schemaname, **kw):
- table_oid = self.get_table_oid(connection, tablename, schemaname,
+ def get_indexes(self, connection, table_name, schema, **kw):
+ table_oid = self.get_table_oid(connection, table_name, schema,
info_cache=kw.get('info_cache'))
IDX_SQL = """
SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
def reflecttable(self, connection, table, include_columns):
preparer = self.identifier_preparer
- schemaname = table.schema
- tablename = table.name
+ schema = table.schema
+ table_name = table.name
info_cache = {}
# Py2K
- if isinstance(schemaname, str):
- schemaname = schemaname.decode(self.encoding)
- if isinstance(tablename, str):
- tablename = tablename.decode(self.encoding)
+ if isinstance(schema, str):
+ schema = schema.decode(self.encoding)
+ if isinstance(table_name, str):
+ table_name = table_name.decode(self.encoding)
# end Py2K
- for col_d in self.get_columns(connection, tablename, schemaname,
+ for col_d in self.get_columns(connection, table_name, schema,
info_cache=info_cache):
name = col_d['name']
coltype = col_d['type']
match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
if match is not None:
# the default is related to a Sequence
- sch = schemaname
+ sch = schema
if '.' not in match.group(2) and sch is not None:
# unconditionally quote the schema name. this could
# later be enhanced to obey quoting rules / "quote schema"
default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
- colargs.append(schema.DefaultClause(sql.text(default)))
- table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
+ colargs.append(sa_schema.DefaultClause(sql.text(default)))
+ table.append_column(sa_schema.Column(name, coltype, nullable=nullable, *colargs))
# Now we have the table oid cached.
- table_oid = self.get_table_oid(connection, tablename, schemaname,
+ table_oid = self.get_table_oid(connection, table_name, schema,
info_cache=info_cache)
# Primary keys
- for pk in self.get_primary_keys(connection, tablename, schemaname,
+ for pk in self.get_primary_keys(connection, table_name, schema,
info_cache=info_cache):
if pk in table.c:
col = table.c[pk]
if col.default is None:
col.autoincrement = False
# Foreign keys
- fkeys = self.get_foreign_keys(connection, tablename, schemaname,
+ fkeys = self.get_foreign_keys(connection, table_name, schema,
info_cache=info_cache)
for fkey_d in fkeys:
conname = fkey_d['name']
referred_columns = fkey_d['referred_columns']
refspec = []
if referred_schema is not None:
- schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
+ sa_schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
autoload_with=connection)
for column in referred_columns:
refspec.append(".".join([referred_schema, referred_table, column]))
else:
- schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
+ sa_schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
for column in referred_columns:
refspec.append(".".join([referred_table, column]))
- table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
+ table.append_constraint(sa_schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
# Indexes
- indexes = self.get_indexes(connection, tablename, schemaname,
+ indexes = self.get_indexes(connection, table_name, schema,
info_cache=info_cache)
for index_d in indexes:
name = index_d['name']
columns = index_d['column_names']
unique = index_d['unique']
- schema.Index(name, *[table.columns[c] for c in columns],
+ sa_schema.Index(name, *[table.columns[c] for c in columns],
**dict(unique=unique))
def _load_domains(self, connection):
'vacuum', 'values', 'view', 'virtual', 'when', 'where',
])
-class SQLiteInfoCache(reflection.DefaultInfoCache):
- pass
-
class SQLiteDialect(default.DefaultDialect):
name = 'sqlite'
supports_alter = False
preparer = SQLiteIdentifierPreparer
ischema_names = ischema_names
colspecs = colspecs
- info_cache = SQLiteInfoCache
def table_names(self, connection, schema):
if schema is not None:
return (row is not None)
- @reflection.caches
+ @reflection.cache
def get_columns(self, connection, tablename, schemaname=None,
info_cache=None):
quote = self.identifier_preparer.quote_identifier
})
return columns
- @reflection.caches
+ @reflection.cache
def get_foreign_keys(self, connection, tablename, schemaname=None,
info_cache=None):
quote = self.identifier_preparer.quote_identifier
raise NotImplementedError()
+ def get_table_names(self, connection, schema=None):
+ """Return a list of table names for `schema`."""
+
+ raise NotImplementedError
+
def do_begin(self, connection):
"""Provide an implementation of *connection.begin()*, given a DB-API connection."""
"""
-import inspect
import sqlalchemy
from sqlalchemy import util
from sqlalchemy.types import TypeEngine
info_cache[key] = ret
return ret
-# keeping this around until all dialects are fixed
-@util.decorator
-def caches(fn, self, con, *args, **kw):
- # what are we caching?
- fn_name = fn.__name__
- if not fn_name.startswith('get_'):
- # don't recognize this.
- return fn(self, con, *args, **kw)
- else:
- attr_to_cache = fn_name[4:]
- # The first arguments will always be self and con.
- # Assuming *args and *kw will be acceptable to info_cache method.
- if 'info_cache' in kw:
- kw_cp = kw.copy()
- info_cache = kw_cp.pop('info_cache')
- methodname = "%s_%s" % ('get', attr_to_cache)
- # fixme.
- for bad_kw in ('dblink', 'resolve_synonyms'):
- if bad_kw in kw_cp:
- del kw_cp[bad_kw]
- information = getattr(info_cache, methodname)(*args, **kw_cp)
- if information:
- return information
- information = fn(self, con, *args, **kw)
- if 'info_cache' in locals():
- methodname = "%s_%s" % ('set', attr_to_cache)
- getattr(info_cache, methodname)(information, *args, **kw_cp)
- return information
-
-class DefaultInfoCache(object):
- """Default implementation of InfoCache
-
- InfoCache provides a means for dialects to cache information obtained for
- reflection and a convenient interface for setting and retrieving cached
- data.
-
- """
-
- def __init__(self):
- self._cache = dict(schemas={})
- self.tables_are_complete = False
- self.schemas_are_complete = False
- self.views_are_complete = False
-
- def clear(self):
- """Clear the cache."""
- self._cache = dict(schemas={})
-
- # schemas
-
- def get_schemas(self):
- """Return the schemas dict."""
- return self._cache.get('schemas')
+class Inspector(object):
+ """Performs database schema inspection.
- def get_schema(self, schemaname, create=False):
- """Return cached schema and optionally create it if it does not exist.
+ The Inspector acts as a proxy to the dialects' reflection methods and
+ provides higher level functions for accessing database schema information.
+ """
+
+ def __init__(self, conn):
"""
- schema = self._cache['schemas'].get(schemaname)
- if schema is not None:
- return schema
- elif create:
- return self.add_schema(schemaname)
- return None
- def add_schema(self, schemaname):
- self._cache['schemas'][schemaname] = dict(tables={}, views={})
- return self.get_schema(schemaname)
-
- def get_schema_names(self, check_complete=True):
- """Return cached schema names.
-
- By default, only return them if they're complete.
+ conn
+ [sqlalchemy.engine.base.#Connectable]
"""
- if check_complete and self.schemas_are_complete:
- return self.get_schemas().keys()
- elif not check_complete:
- return self.get_schemas().keys()
+ self.conn = conn
+ # set the engine
+ if hasattr(conn, 'engine'):
+ self.engine = conn.engine
else:
- return None
-
- def set_schema_names(self, schemanames):
- for schemaname in schemanames:
- self.add_schema(schemaname)
- self.schemas_are_complete = True
+ self.engine = conn
+ self.dialect = self.engine.dialect
+ self.info_cache = {}
- # tables
+ @classmethod
+ def from_engine(cls, engine):
+ if hasattr(engine.dialect, 'inspector'):
+ return engine.dialect.inspector(engine)
+ return Inspector(engine)
- def get_table(self, tablename, schemaname=None, create=False,
- table_type='table'):
- """Return cached table and optionally create it if it does not exist.
+ def default_schema_name(self):
+ return self.dialect.get_default_schema_name(self.conn)
+ default_schema_name = property(default_schema_name)
+ def get_schema_names(self):
+ """Return all schema names.
"""
- cache = self._cache
- schema = self.get_schema(schemaname, create=create)
- if schema is None:
- return None
- if table_type == 'view':
- table = schema['views'].get(tablename)
- else:
- table = schema['tables'].get(tablename)
- if table is not None:
- return table
- elif create:
- return self.add_table(tablename, schemaname, table_type=table_type)
- return None
+ if hasattr(self.dialect, 'get_schema_names'):
+ return self.dialect.get_schema_names(self.conn,
+ info_cache=self.info_cache)
+ return []
- def get_table_names(self, schemaname=None, check_complete=True,
- table_type='table'):
- """Return cached table names.
+ def get_table_names(self, schema=None, order_by=None):
+ """Return all table names in `schema`.
+ schema:
+ Optional, retrieve names from a non-default schema.
- By default, only return them if they're complete.
+ This should probably not return view names or maybe it should return
+ them with an indicator t or v.
"""
- if table_type == 'view':
- complete = self.views_are_complete
- else:
- complete = self.tables_are_complete
- if check_complete and complete:
- return self.get_tables(schemaname, table_type=table_type).keys()
- elif not check_complete:
- return self.get_tables(schemaname, table_type=table_type).keys()
+ if hasattr(self.dialect, 'get_table_names'):
+ tnames = self.dialect.get_table_names(self.conn,
+ schema,
+ info_cache=self.info_cache)
else:
- return None
-
- def add_table(self, tablename, schemaname=None, table_type='table'):
- schema = self.get_schema(schemaname, create=True)
- if table_type == 'table':
- schema['tables'][tablename] = dict(columns={})
- else:
- schema['views'][tablename] = dict(columns={})
- return self.get_table(tablename, schemaname, table_type=table_type)
-
- def set_table_names(self, tablenames, schemaname=None, table_type='table'):
- for tablename in tablenames:
- self.add_table(tablename, schemaname, table_type)
- if table_type == 'view':
- self.views_are_complete = True
- else:
- self.tables_are_complete = True
-
- # views
-
- def get_view(self, viewname, schemaname=None, create=False):
- return self.get_table(viewname, schemaname, create, 'view')
-
- def get_view_names(self, schemaname=None, check_complete=True):
- return self.get_table_names(schemaname, check_complete, 'view')
-
- def add_view(self, viewname, schemaname=None):
- return self.add_table(viewname, schemaname, 'view')
-
- def set_view_names(self, viewnames, schemaname=None):
- return self.set_table_names(viewnames, schemaname, 'view')
-
- def get_view_definition(self, viewname, schemaname=None):
- view_cache = self.get_view(viewname, schemaname)
- if view_cache and 'definition' in view_cache:
- return view_cache['definition']
-
- def set_view_definition(self, definition, viewname, schemaname=None):
- view_cache = self.get_view(viewname, schemaname, create=True)
- view_cache['definition'] = definition
-
- # table data
-
- def _get_table_data(self, key, tablename, schemaname=None):
- table_cache = self.get_table(tablename, schemaname)
- if table_cache is not None and key in table_cache.keys():
- return table_cache[key]
-
- def _set_table_data(self, key, data, tablename, schemaname=None):
- """Cache data for schemaname.tablename using key.
-
- It will create a schema and table entry in the cache if needed.
+ tnames = self.engine.table_names(schema)
+ if order_by == 'foreign_key':
+ ordered_tnames = tnames[:]
+ # Order based on foreign key dependencies.
+ for tname in tnames:
+ table_pos = tnames.index(tname)
+ fkeys = self.get_foreign_keys(tname, schema)
+ for fkey in fkeys:
+ rtable = fkey['referred_table']
+ if rtable in ordered_tnames:
+ ref_pos = ordered_tnames.index(rtable)
+ # Make sure it's lower in the list than anything it
+ # references.
+ if table_pos > ref_pos:
+ ordered_tnames.pop(table_pos) # rtable moves up 1
+ # insert just below rtable
+ ordered_tnames.index(ref_pos, tname)
+ tnames = ordered_tnames
+ return tnames
+
+ def get_view_names(self, schema=None):
+ """Return all view names in `schema`.
+ schema:
+ Optional, retrieve names from a non-default schema.
"""
- table_cache = self.get_table(tablename, schemaname, create=True)
- table_cache[key] = data
-
- # columns
-
- def get_columns(self, tablename, schemaname=None):
- """Return columns list or None."""
-
- return self._get_table_data('columns', tablename, schemaname)
-
- def set_columns(self, columns, tablename, schemaname=None):
- """Add list of columns to table cache."""
-
- return self._set_table_data('columns', columns, tablename, schemaname)
-
- # primary keys
+ return self.dialect.get_view_names(self.conn, schema,
+ info_cache=self.info_cache)
- def get_primary_keys(self, tablename, schemaname=None):
- """Return primary key list or None."""
-
- return self._get_table_data('primary_keys', tablename, schemaname)
+ def get_view_definition(self, view_name, schema=None):
+ """Return definition for `view_name`.
+ schema:
+ Optional, retrieve names from a non-default schema.
- def set_primary_keys(self, pkeys, tablename, schemaname=None):
- """Add list of primary keys to table cache."""
+ """
+ return self.dialect.get_view_definition(
+ self.conn, view_name, schema, info_cache=self.info_cache)
- return self._set_table_data('primary_keys', pkeys, tablename, schemaname)
+ def get_columns(self, table_name, schema=None):
+ """Return information about columns in `table_name`.
- # foreign keys
+ Given a string `table_name` and an optional string `schema`, return
+ column information as a list of dicts with these keys:
- def get_foreign_keys(self, tablename, schemaname=None):
- """Return foreign key list or None."""
-
- return self._get_table_data('foreign_keys', tablename, schemaname)
+ name
+ the column's name
- def set_foreign_keys(self, fkeys, tablename, schemaname=None):
- """Add list of foreign keys to table cache."""
+ type
+ [sqlalchemy.types#TypeEngine]
- return self._set_table_data('foreign_keys', fkeys, tablename, schemaname)
+ nullable
+ boolean
- # indexes
+ default
+ the column's default value
- def get_indexes(self, tablename, schemaname=None):
- """Return indexes list or None."""
-
- return self._get_table_data('indexes', tablename, schemaname)
+ attrs
+ dict containing optional column attributes
- def set_indexes(self, indexes, tablename, schemaname=None):
- """Add list of indexes to table cache."""
+ """
- return self._set_table_data('indexes', indexes, tablename, schemaname)
+ col_defs = self.dialect.get_columns(self.conn, table_name,
+ schema,
+ info_cache=self.info_cache)
+ for col_def in col_defs:
+ # make this easy and only return instances for coltype
+ coltype = col_def['type']
+ if not isinstance(coltype, TypeEngine):
+ col_def['type'] = coltype()
+ return col_defs
-class Inspector(object):
- """Performs database introspection.
+ def get_primary_keys(self, table_name, schema=None):
+ """Return information about primary keys in `table_name`.
- The Inspector acts as a proxy to the dialects' reflection methods and
- provides higher level functions for accessing database schema information.
+ Given a string `table_name`, and an optional string `schema`, return
+ primary key information as a list of column names.
- """
-
- def __init__(self, conn):
"""
- conn
- [sqlalchemy.engine.base.#Connectable]
-
- Upon initialization, new members are added corresponding to the
- refection members of the current dialect.
+ pkeys = self.dialect.get_primary_keys(self.conn, table_name,
+ schema,
+ info_cache=self.info_cache)
- Dev Notes:
-
- I used attribute assignment rather than __getattr__ because
- I want the Inspector to be inspectable including providing proper
- documentation strings for the methods is supports.
-
- The primary reason for this approach:
+ return pkeys
- 1. DRY.
- 2. Provides access to dialect specific reflection methods.
+ def get_foreign_keys(self, table_name, schema=None):
+ """Return information about foreign_keys in `table_name`.
- """
- self.conn = conn
- # set the engine
- if hasattr(conn, 'engine'):
- self.engine = conn.engine
- else:
- self.engine = conn
- # fixme. This is just until all dialects are converted
- if hasattr(self, 'info_cache'):
- self.info_cache = self.engine.dialect.info_cache()
- else:
- self.info_cache = {}
- # add methods from dialect
- def filter_reflect_members(m):
- if inspect.ismethod(m) and m.__name__.startswith('get_'):
- argspec = inspect.getargspec(m)
- if isinstance(argspec, tuple) and 'connection' in argspec[0]:
- return True
- return False
- reflection_members = inspect.getmembers(self.engine.dialect,
- filter_reflect_members)
- def wrap_reflection_method(fn):
- def decorated(*args, **kwargs):
- args = (self.conn,) + args
- kwargs['info_cache'] = self.info_cache
- return fn(*args, **kwargs)
- return decorated
- for (member_name, member) in reflection_members:
- if not hasattr(self, member_name):
- doc = "This method mirrors the dialect method %s." % member_name
- wrapped_member = wrap_reflection_method(member)
- wrapped_member.__doc__ = "%s\n\n%s" % (doc, member.__doc__)
- setattr(self, member_name, wrapped_member)
-
- @property
- def default_schema_name(self):
- return self.engine.dialect.get_default_schema_name(self.conn)
-
- def get_foreign_keys(self, tablename, schemaname=None):
- """Return information about foreign_keys in `tablename`.
-
- Given a string `tablename`, and an optional string `schemaname`, return
+ Given a string `table_name`, and an optional string `schema`, return
foreign key information as a list of dicts with these keys:
constrained_columns
"""
- fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename,
- schemaname,
+ fk_defs = self.dialect.get_foreign_keys(self.conn, table_name,
+ schema,
info_cache=self.info_cache)
for fk_def in fk_defs:
referred_schema = fk_def['referred_schema']
# always set the referred_schema.
- if referred_schema is None and schemaname is None:
- referred_schema = self.engine.dialect.get_default_schema_name(
+ if referred_schema is None and schema is None:
+ referred_schema = self.dialect.get_default_schema_name(
self.conn)
fk_def['referred_schema'] = referred_schema
return fk_defs
- def get_relation_map(self, schemaname=None):
- """Provide a mapping of the relations between all tables in schemaname.
+ def get_indexes(self, table_name, schema=None):
+ """Return information about indexes in `table_name`.
- This is an example of a higher level function where Inspector can be
- very useful.
+ Given a string `table_name` and an optional string `schema`, return
+ index information as a list of dicts with these keys:
+
+ name
+ the index's name
+
+ column_names
+ list of column names in order
+
+ unique
+ boolean
"""
- #todo
- pass
+
+ indexes = self.dialect.get_indexes(self.conn, table_name,
+ schema,
+ info_cache=self.info_cache)
+ return indexes
from testlib.sa import MetaData, Table, Column
from testlib import TestBase, testing, engines
+create_inspector = Inspector.from_engine
+
if 'set' not in dir(__builtins__):
from sets import Set as set
con.execute(sa.sql.text(query))
def createViews(con, schema=None):
- for tablename in ('users', 'email_addresses'):
- fullname = tablename
+ for table_name in ('users', 'email_addresses'):
+ fullname = table_name
if schema:
- fullname = "%s.%s" % (schema, tablename)
+ fullname = "%s.%s" % (schema, table_name)
view_name = fullname + '_v'
query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name,
fullname)
con.execute(sa.sql.text(query))
def dropViews(con, schema=None):
- for tablename in ('email_addresses', 'users'):
- fullname = tablename
+ for table_name in ('email_addresses', 'users'):
+ fullname = table_name
if schema:
- fullname = "%s.%s" % (schema, tablename)
+ fullname = "%s.%s" % (schema, table_name)
view_name = fullname + '_v'
query = "DROP VIEW %s" % view_name
con.execute(sa.sql.text(query))
insp = Inspector(meta.bind)
self.assert_(getSchema() in insp.get_schema_names())
- def _test_get_table_names(self, schemaname=None, table_type='table',
+ def _test_get_table_names(self, schema=None, table_type='table',
order_by=None):
meta = MetaData(testing.db)
- (users, addresses) = createTables(meta, schemaname)
+ (users, addresses) = createTables(meta, schema)
meta.create_all()
- createViews(meta.bind, schemaname)
+ createViews(meta.bind, schema)
try:
insp = Inspector(meta.bind)
if table_type == 'view':
- table_names = insp.get_view_names(schemaname)
+ table_names = insp.get_view_names(schema)
table_names.sort()
answer = ['email_addresses_v', 'users_v']
else:
- table_names = insp.get_table_names(schemaname,
+ table_names = insp.get_table_names(schema,
order_by=order_by)
table_names.sort()
if order_by == 'foreign_key':
answer = ['email_addresses', 'users']
self.assertEqual(table_names, answer)
finally:
- dropViews(meta.bind, schemaname)
+ dropViews(meta.bind, schema)
addresses.drop()
users.drop()
def test_get_view_names_with_schema(self):
self._test_get_table_names(getSchema(), table_type='view')
- def _test_get_columns(self, schemaname=None, table_type='table'):
+ def _test_get_columns(self, schema=None, table_type='table'):
meta = MetaData(testing.db)
- (users, addresses) = createTables(meta, schemaname)
+ (users, addresses) = createTables(meta, schema)
table_names = ['users', 'email_addresses']
meta.create_all()
if table_type == 'view':
- createViews(meta.bind, schemaname)
+ createViews(meta.bind, schema)
table_names = ['users_v', 'email_addresses_v']
try:
insp = Inspector(meta.bind)
- for (tablename, table) in zip(table_names, (users, addresses)):
- schema_name = schemaname
- if schemaname and testing.against('oracle'):
- schema_name = schemaname.upper()
- cols = insp.get_columns(tablename, schemaname=schema_name)
+ for (table_name, table) in zip(table_names, (users, addresses)):
+ schema_name = schema
+ if schema and testing.against('oracle'):
+ schema_name = schema.upper()
+ cols = insp.get_columns(table_name, schema=schema_name)
self.assert_(len(cols) > 0, len(cols))
# should be in order
for (i, col) in enumerate(table.columns):
ctype)))
finally:
if table_type == 'view':
- dropViews(meta.bind, schemaname)
+ dropViews(meta.bind, schema)
addresses.drop()
users.drop()
self._test_get_columns()
def test_get_columns_with_schema(self):
- self._test_get_columns(schemaname=getSchema())
+ self._test_get_columns(schema=getSchema())
def test_get_view_columns(self):
self._test_get_columns(table_type='view')
def test_get_view_columns_with_schema(self):
- self._test_get_columns(schemaname=getSchema(), table_type='view')
+ self._test_get_columns(schema=getSchema(), table_type='view')
- def _test_get_primary_keys(self, schemaname=None):
+ def _test_get_primary_keys(self, schema=None):
meta = MetaData(testing.db)
- (users, addresses) = createTables(meta, schemaname)
+ (users, addresses) = createTables(meta, schema)
meta.create_all()
insp = Inspector(meta.bind)
try:
users_pkeys = insp.get_primary_keys(users.name,
- schemaname=schemaname)
+ schema=schema)
self.assertEqual(users_pkeys, ['user_id'])
addr_pkeys = insp.get_primary_keys(addresses.name,
- schemaname=schemaname)
+ schema=schema)
self.assertEqual(addr_pkeys, ['address_id'])
finally:
self._test_get_primary_keys()
def test_get_primary_keys_with_schema(self):
- self._test_get_primary_keys(schemaname=getSchema())
+ self._test_get_primary_keys(schema=getSchema())
- def _test_get_foreign_keys(self, schemaname=None):
+ def _test_get_foreign_keys(self, schema=None):
meta = MetaData(testing.db)
- (users, addresses) = createTables(meta, schemaname)
+ (users, addresses) = createTables(meta, schema)
meta.create_all()
insp = Inspector(meta.bind)
try:
- expected_schema = schemaname
- if schemaname is None:
+ expected_schema = schema
+ if schema is None:
expected_schema = meta.bind.dialect.get_default_schema_name(
meta.bind)
# users
users_fkeys = insp.get_foreign_keys(users.name,
- schemaname=schemaname)
+ schema=schema)
fkey1 = users_fkeys[0]
self.assert_(fkey1['name'] is not None)
self.assertEqual(fkey1['referred_schema'], expected_schema)
self.assertEqual(fkey1['constrained_columns'], ['parent_user_id'])
#addresses
addr_fkeys = insp.get_foreign_keys(addresses.name,
- schemaname=schemaname)
+ schema=schema)
fkey1 = addr_fkeys[0]
self.assert_(fkey1['name'] is not None)
self.assertEqual(fkey1['referred_schema'], expected_schema)
self._test_get_foreign_keys()
def test_get_foreign_keys_with_schema(self):
- self._test_get_foreign_keys(schemaname=getSchema())
+ self._test_get_foreign_keys(schema=getSchema())
- def _test_get_indexes(self, schemaname=None):
+ def _test_get_indexes(self, schema=None):
meta = MetaData(testing.db)
- (users, addresses) = createTables(meta, schemaname)
+ (users, addresses) = createTables(meta, schema)
meta.create_all()
- createIndexes(meta.bind, schemaname)
+ createIndexes(meta.bind, schema)
try:
insp = Inspector(meta.bind)
- indexes = insp.get_indexes('users', schemaname=schemaname)
+ indexes = insp.get_indexes('users', schema=schema)
indexes.sort()
if testing.against('oracle'):
expected_indexes = [
self._test_get_indexes()
def test_get_indexes_with_schema(self):
- self._test_get_indexes(schemaname=getSchema())
+ self._test_get_indexes(schema=getSchema())
- def _test_get_view_definition(self, schemaname=None):
+ def _test_get_view_definition(self, schema=None):
meta = MetaData(testing.db)
- (users, addresses) = createTables(meta, schemaname)
+ (users, addresses) = createTables(meta, schema)
meta.create_all()
- createViews(meta.bind, schemaname)
+ createViews(meta.bind, schema)
view_name1 = 'users_v'
view_name2 = 'email_addresses_v'
try:
insp = Inspector(meta.bind)
- v1 = insp.get_view_definition(view_name1, schemaname=schemaname)
+ v1 = insp.get_view_definition(view_name1, schema=schema)
self.assert_(v1)
- v2 = insp.get_view_definition(view_name2, schemaname=schemaname)
+ v2 = insp.get_view_definition(view_name2, schema=schema)
self.assert_(v2)
finally:
- dropViews(meta.bind, schemaname)
+ dropViews(meta.bind, schema)
addresses.drop()
users.drop()
self._test_get_view_definition()
def test_get_view_definition_with_schema(self):
- self._test_get_view_definition(schemaname=getSchema())
+ self._test_get_view_definition(schema=getSchema())
+
+ def _test_get_table_oid(self, table_name, schema=None):
+ if testing.against('postgres'):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ try:
+ insp = create_inspector(meta.bind)
+ oid = insp.get_table_oid(table_name, schema)
+ self.assert_(isinstance(oid, int))
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_table_oid(self):
+ self._test_get_table_oid('users')
+
+ def test_get_table_oid_with_schema(self):
+ self._test_get_table_oid('users', schema=getSchema())
if __name__ == "__main__":
testenv.main()