--- /dev/null
+"""Provides an abstraction for obtaining database schema information.
+
+Development Notes:
+
+I'm still trying to decide upon conventions for both the Inspector interface as well as the dialect interface the Inspector is to consume. Below are some of the current conventions.
+
+ 1. Inspector methods should return lists of dicts in most cases for the
+ following reasons:
+ * They're both simple standard types.
+ * Using a dict instead of a tuple allows easy expansion of attributes.
+ * Using a list for the outer structure maintains order and is easy to work
+ with (e.g. list comprehension [d['name'] for d in cols]).
+ * Being consistent is just good.
+ 2. Records that contain a name, such as the column name in a column record
+ should use the key 'name' in the dict. This allows the user to expect a
+ 'name' key and to know what it will reference.
+
+
+"""
+import sqlalchemy
+from sqlalchemy.types import TypeEngine
+
+class Inspector(object):
+ """performs database introspection
+
+ """
+
+ def __init__(self, conn):
+ """
+
+ conn
+ [sqlalchemy.engine.base.#Connectable]
+
+ """
+ self.info_cache = {}
+ self.conn = conn
+ # set the engine
+ if hasattr(conn, 'engine'):
+ self.engine = conn.engine
+ else:
+ self.engine = conn
+
+ def default_schema_name(self):
+ return self.engine.dialect.get_default_schema_name(self.conn)
+ default_schema_name = property(default_schema_name)
+
+ def get_schema_names(self):
+ """Return all schema names.
+
+ """
+ if hasattr(self.engine.dialect, 'get_schema_names'):
+ return self.engine.dialect.get_schema_names(self.conn,
+ self.info_cache)
+ return []
+
+ def get_table_names(self, schemaname=None):
+ """Return all table names in `schemaname`.
+ schemaname:
+ Optional, retrieve names from a non-default schema.
+
+ This should probably not return view names or maybe it should return
+ them with an indicator t or v.
+
+ """
+ if hasattr(self.engine.dialect, 'get_table_names'):
+ return self.engine.dialect.get_table_names(self.conn, schemaname,
+ self.info_cache)
+ return self.engine.table_names(schemaname)
+
+ def get_view_names(self, schemaname=None):
+ """Return all view names in `schemaname`.
+ schemaname:
+ Optional, retrieve names from a non-default schema.
+
+ """
+ return self.engine.dialect.get_view_names(self.conn, schemaname,
+ self.info_cache)
+
+ def get_view_definition(self, view_name, schemaname=None):
+ """Return definition for `view_name`.
+ schemaname:
+ Optional, retrieve names from a non-default schema.
+
+ """
+ return self.engine.dialect.get_view_definition(
+ self.conn, view_name, schemaname, self.info_cache)
+
+ def get_columns(self, tablename, schemaname=None):
+ """Return information about columns in `tablename`.
+
+ Given a string `tablename` and an optional string `schemaname`, return
+ column information as a list of dicts with these keys:
+
+ name
+ the column's name
+
+ type
+ [sqlalchemy.types#TypeEngine]
+
+ nullable
+ boolean
+
+ default
+ the column's default value
+
+ attrs
+ dict containing optional column attributes
+
+ """
+
+ col_defs = self.engine.dialect.get_columns(self.conn, tablename,
+ schemaname,
+ self.info_cache)
+ for col_def in col_defs:
+ # make this easy and only return instances for coltype
+ coltype = col_def['type']
+ if not isinstance(coltype, TypeEngine):
+ col_def['type'] = coltype()
+ return col_defs
+
+ def get_primary_keys(self, tablename, schemaname=None):
+ """Return information about primary keys in `tablename`.
+
+ Given a string `tablename`, and an optional string `schemaname`, return
+ primary key information as a list of column names:
+
+ """
+
+ pkeys = self.engine.dialect.get_primary_keys(self.conn, tablename,
+ schemaname,
+ self.info_cache)
+
+ return pkeys
+
+ def get_foreign_keys(self, tablename, schemaname=None):
+ """Return information about foreign_keys in `tablename`.
+
+ Given a string `tablename`, and an optional string `schemaname`, return
+ foreign key information as a list of dicts with these keys:
+
+ constrained_columns
+ a list of column names that make up the foreign key
+
+ referred_schema
+ the name of the referred schema
+
+ referred_table
+ the name of the referred table
+
+ referred_columns
+ a list of column names in the referred table that correspond to
+ constrained_columns
+
+ """
+
+ fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename,
+ schemaname,
+ self.info_cache)
+ for fk_def in fk_defs:
+ referred_schema = fk_def['referred_schema']
+ # always set the referred_schema.
+ if referred_schema is None and schemaname is None:
+ referred_schema = self.engine.dialect.get_default_schema_name(
+ self.conn)
+ fk_def['referred_schema'] = referred_schema
+ return fk_defs
+
+ def get_indexes(self, tablename, schemaname=None):
+ """Return information about indexes in `tablename`.
+
+ Given a string `tablename` and an optional string `schemaname`, return
+ index information as a list of dicts with these keys:
+
+ name
+ the index's name
+
+ column_names
+ list of column names in order
+
+ unique
+ boolean
+
+ """
+
+ indexes = self.engine.dialect.get_indexes(self.conn, tablename,
+ schemaname,
+ self.info_cache)
+ return indexes
--- /dev/null
+"""tests for sqlalchemy.engine.reflection
+
+"""
+
+import testenv; testenv.configure_for_tests()
+import sqlalchemy as sa
+from sqlalchemy.engine.reflection import Inspector
+from testlib.sa import MetaData, Table, Column
+from testlib import TestBase, testing, engines
+
+if 'set' not in dir(__builtins__):
+ from sets import Set as set
+
+def getSchema():
+ if testing.against('oracle'):
+ return u'scott'
+ else:
+ return u'test_schema'
+
+def createTables(meta, schema=None):
+ if schema:
+ parent_user_id = Column('parent_user_id', sa.Integer,
+ sa.ForeignKey('%s.users.user_id' % schema)
+ )
+ else:
+ parent_user_id = Column('parent_user_id', sa.Integer,
+ sa.ForeignKey('users.user_id')
+ )
+
+ users = Table('users', meta,
+ Column('user_id', sa.INT, primary_key=True),
+ Column('user_name', sa.VARCHAR(20), nullable=False),
+ Column('test1', sa.CHAR(5), nullable=False),
+ Column('test2', sa.Float(5), nullable=False),
+ Column('test3', sa.Text),
+ Column('test4', sa.Numeric, nullable = False),
+ Column('test5', sa.DateTime),
+ parent_user_id,
+ Column('test6', sa.DateTime, nullable=False),
+ Column('test7', sa.Text),
+ Column('test8', sa.Binary),
+ Column('test_passivedefault2', sa.Integer, server_default='5'),
+ Column('test9', sa.Binary(100)),
+ Column('test_numeric', sa.Numeric()),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ addresses = Table(u'email_addresses', meta,
+ Column('address_id', sa.Integer, primary_key = True),
+ Column('remote_user_id', sa.Integer,
+ sa.ForeignKey(users.c.user_id)),
+ Column('email_address', sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ return (users, addresses)
+
+def createIndexes(con, schema=None):
+ fullname = 'users'
+ if schema:
+ fullname = "%s.%s" % (schema, 'users')
+ query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname
+ con.execute(sa.sql.text(query))
+
+def createViews(con, schema=None):
+ for tablename in (u'users', u'email_addresses'):
+ fullname = tablename
+ if schema:
+ fullname = "%s.%s" % (schema, tablename)
+ view_name = fullname + '_v'
+ query = "CREATE OR REPLACE VIEW %s AS SELECT * FROM %s" % (view_name,
+ fullname)
+ con.execute(sa.sql.text(query))
+
+def dropViews(con, schema=None):
+ for tablename in (u'email_addresses', 'users'):
+ fullname = tablename
+ if schema:
+ fullname = "%s.%s" % (schema, tablename)
+ view_name = fullname + '_v'
+ query = "DROP VIEW %s" % view_name
+ con.execute(sa.sql.text(query))
+
+
+class ReflectionTest(TestBase):
+
+ def test_get_schema_names(self):
+ meta = MetaData(testing.db)
+ insp = Inspector(meta.bind)
+ self.assert_(getSchema() in insp.get_schema_names())
+
+ def _test_get_table_names(self, schemaname=None, table_type='table'):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schemaname)
+ meta.create_all()
+ createViews(meta.bind, schemaname)
+ try:
+ insp = Inspector(meta.bind)
+ if table_type == 'view':
+ table_names = insp.get_view_names(schemaname)
+ table_names.sort()
+ answer = [u'email_addresses_v', u'users_v']
+ else:
+ table_names = insp.get_table_names(schemaname)
+ table_names.sort()
+ answer = [u'email_addresses', 'users']
+ self.assertEqual(table_names, answer)
+ finally:
+ dropViews(meta.bind, schemaname)
+ addresses.drop()
+ users.drop()
+
+ def test_get_table_names(self):
+ self._test_get_table_names()
+
+ def test_get_table_names_with_schema(self):
+ self._test_get_table_names(getSchema())
+
+ def test_get_view_names(self):
+ self._test_get_table_names(table_type='view')
+
+ def test_get_view_names_with_schema(self):
+ self._test_get_table_names(getSchema(), table_type='view')
+
+ def _test_get_columns(self, schemaname=None, table_type='table'):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schemaname)
+ table_names = [u'users', u'email_addresses']
+ meta.create_all()
+ if table_type == 'view':
+ createViews(meta.bind, schemaname)
+ table_names = [u'users_v', u'email_addresses_v']
+ try:
+ insp = Inspector(meta.bind)
+ for (tablename, table) in zip(table_names, (users, addresses)):
+ schema_name = schemaname
+ if schemaname and testing.against('oracle'):
+ schema_name = schema.upper()
+ cols = insp.get_columns(tablename, schemaname=schema_name)
+ self.assert_(len(cols) > 0, len(cols))
+ # should be in order
+ for (i, col) in enumerate(table.columns):
+ self.assertEqual(col.name, cols[i]['name'])
+ # coltype is tricky
+ # It may not inherit from col.type while they share
+ # the same base.
+ coltype = cols[i]['type'].__class__
+ self.assert_(
+ issubclass(coltype, col.type.__class__) or \
+ len(
+ set(
+ coltype.__bases__
+ ).intersection(col.type.__class__.__bases__)) > 0
+ ,("%s, %s", (col.type, coltype)))
+ finally:
+ if table_type == 'view':
+ dropViews(meta.bind, schemaname)
+ addresses.drop()
+ users.drop()
+
+ def test_get_columns(self):
+ self._test_get_columns()
+
+ def test_get_columns_with_schema(self):
+ self._test_get_columns(schemaname=getSchema())
+
+ def test_get_view_columns(self):
+ self._test_get_columns(table_type='view')
+
+ def test_get_view_columns_with_schema(self):
+ self._test_get_columns(schemaname=getSchema(), table_type='view')
+
+ def _test_get_primary_keys(self, schemaname=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schemaname)
+ meta.create_all()
+ insp = Inspector(meta.bind)
+ try:
+ users_pkeys = insp.get_primary_keys(unicode(users.name),
+ schemaname=schemaname)
+ self.assertEqual(users_pkeys, ['user_id'])
+ addr_pkeys = insp.get_primary_keys(unicode(addresses.name),
+ schemaname=schemaname)
+ self.assertEqual(addr_pkeys, ['address_id'])
+
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_primary_keys(self):
+ self._test_get_primary_keys()
+
+ def test_get_primary_keys_with_schema(self):
+ self._test_get_primary_keys(schemaname=getSchema())
+
+ def _test_get_foreign_keys(self, schemaname=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schemaname)
+ meta.create_all()
+ insp = Inspector(meta.bind)
+ try:
+ expected_schema = schemaname
+ if schemaname is None:
+ expected_schema = meta.bind.dialect.get_default_schema_name(
+ meta.bind)
+ # users
+ users_fkeys = insp.get_foreign_keys(unicode(users.name),
+ schemaname=schemaname)
+ fkey1 = users_fkeys[0]
+ self.assert_(fkey1['name'] is not None)
+ self.assertEqual(fkey1['referred_schema'], expected_schema)
+ self.assertEqual(fkey1['referred_table'], users.name)
+ self.assertEqual(fkey1['referred_columns'], ['user_id', ])
+ self.assertEqual(fkey1['constrained_columns'], ['parent_user_id'])
+ #addresses
+ addr_fkeys = insp.get_foreign_keys(addresses.name,
+ schemaname=schemaname)
+ fkey1 = addr_fkeys[0]
+ self.assert_(fkey1['name'] is not None)
+ self.assertEqual(fkey1['referred_schema'], expected_schema)
+ self.assertEqual(fkey1['referred_table'], users.name)
+ self.assertEqual(fkey1['referred_columns'], ['user_id', ])
+ self.assertEqual(fkey1['constrained_columns'], ['remote_user_id'])
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_foreign_keys(self):
+ self._test_get_foreign_keys()
+
+ def test_get_foreign_keys_with_schema(self):
+ self._test_get_foreign_keys(schemaname=getSchema())
+
+ def _test_get_indexes(self, schemaname=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schemaname)
+ meta.create_all()
+ createIndexes(meta.bind, schemaname)
+ try:
+ insp = Inspector(meta.bind)
+ indexes = insp.get_indexes(u'users', schemaname=schemaname)
+ indexes.sort()
+ expected_indexes = [
+ {'unique': False,
+ 'column_names': ['test1', 'test2'],
+ 'name': 'users_t_idx'}]
+ self.assertEqual(indexes, expected_indexes)
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_indexes(self):
+ self._test_get_indexes()
+
+ def test_get_indexes_with_schema(self):
+ self._test_get_indexes(schemaname=getSchema())
+
+ def _test_get_view_definition(self, schemaname=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schemaname)
+ meta.create_all()
+ createViews(meta.bind, schemaname)
+ view_name1 = u'users_v'
+ view_name2 = u'email_addresses_v'
+ try:
+ insp = Inspector(meta.bind)
+ v1 = insp.get_view_definition(view_name1, schemaname=schemaname)
+ self.assert_(v1)
+ v2 = insp.get_view_definition(view_name2, schemaname=schemaname)
+ self.assert_(v2)
+ finally:
+ dropViews(meta.bind, schemaname)
+ addresses.drop()
+ users.drop()
+
+ def test_get_view_definition(self):
+ self._test_get_view_definition()
+
+ def test_get_view_definition_with_schema(self):
+ self._test_get_view_definition(schemaname=getSchema())
+
+if __name__ == "__main__":
+ testenv.main()