From 65d14c191cc65961a7eb671e516f023ab8b4fd43 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 4 Feb 2009 06:42:26 +0000 Subject: [PATCH] added reflection (inspector) and tests --- lib/sqlalchemy/engine/reflection.py | 188 ++++++++++++++++++ test/reflection.py | 283 ++++++++++++++++++++++++++++ 2 files changed, 471 insertions(+) create mode 100644 lib/sqlalchemy/engine/reflection.py create mode 100644 test/reflection.py diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py new file mode 100644 index 0000000000..99e86499e5 --- /dev/null +++ b/lib/sqlalchemy/engine/reflection.py @@ -0,0 +1,188 @@ +"""Provides an abstraction for obtaining database schema information. + +Development Notes: + +I'm still trying to decide upon conventions for both the Inspector interface as well as the dialect interface the Inspector is to consume. Below are some of the current conventions. + + 1. Inspector methods should return lists of dicts in most cases for the + following reasons: + * They're both simple standard types. + * Using a dict instead of a tuple allows easy expansion of attributes. + * Using a list for the outer structure maintains order and is easy to work + with (e.g. list comprehension [d['name'] for d in cols]). + * Being consistent is just good. + 2. Records that contain a name, such as the column name in a column record + should use the key 'name' in the dict. This allows the user to expect a + 'name' key and to know what it will reference. + + +""" +import sqlalchemy +from sqlalchemy.types import TypeEngine + +class Inspector(object): + """performs database introspection + + """ + + def __init__(self, conn): + """ + + conn + [sqlalchemy.engine.base.#Connectable] + + """ + self.info_cache = {} + self.conn = conn + # set the engine + if hasattr(conn, 'engine'): + self.engine = conn.engine + else: + self.engine = conn + + def default_schema_name(self): + return self.engine.dialect.get_default_schema_name(self.conn) + default_schema_name = property(default_schema_name) + + def get_schema_names(self): + """Return all schema names. + + """ + if hasattr(self.engine.dialect, 'get_schema_names'): + return self.engine.dialect.get_schema_names(self.conn, + self.info_cache) + return [] + + def get_table_names(self, schemaname=None): + """Return all table names in `schemaname`. + schemaname: + Optional, retrieve names from a non-default schema. + + This should probably not return view names or maybe it should return + them with an indicator t or v. + + """ + if hasattr(self.engine.dialect, 'get_table_names'): + return self.engine.dialect.get_table_names(self.conn, schemaname, + self.info_cache) + return self.engine.table_names(schemaname) + + def get_view_names(self, schemaname=None): + """Return all view names in `schemaname`. + schemaname: + Optional, retrieve names from a non-default schema. + + """ + return self.engine.dialect.get_view_names(self.conn, schemaname, + self.info_cache) + + def get_view_definition(self, view_name, schemaname=None): + """Return definition for `view_name`. + schemaname: + Optional, retrieve names from a non-default schema. + + """ + return self.engine.dialect.get_view_definition( + self.conn, view_name, schemaname, self.info_cache) + + def get_columns(self, tablename, schemaname=None): + """Return information about columns in `tablename`. + + Given a string `tablename` and an optional string `schemaname`, return + column information as a list of dicts with these keys: + + name + the column's name + + type + [sqlalchemy.types#TypeEngine] + + nullable + boolean + + default + the column's default value + + attrs + dict containing optional column attributes + + """ + + col_defs = self.engine.dialect.get_columns(self.conn, tablename, + schemaname, + self.info_cache) + for col_def in col_defs: + # make this easy and only return instances for coltype + coltype = col_def['type'] + if not isinstance(coltype, TypeEngine): + col_def['type'] = coltype() + return col_defs + + def get_primary_keys(self, tablename, schemaname=None): + """Return information about primary keys in `tablename`. + + Given a string `tablename`, and an optional string `schemaname`, return + primary key information as a list of column names: + + """ + + pkeys = self.engine.dialect.get_primary_keys(self.conn, tablename, + schemaname, + self.info_cache) + + return pkeys + + def get_foreign_keys(self, tablename, schemaname=None): + """Return information about foreign_keys in `tablename`. + + Given a string `tablename`, and an optional string `schemaname`, return + foreign key information as a list of dicts with these keys: + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + + """ + + fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename, + schemaname, + self.info_cache) + for fk_def in fk_defs: + referred_schema = fk_def['referred_schema'] + # always set the referred_schema. + if referred_schema is None and schemaname is None: + referred_schema = self.engine.dialect.get_default_schema_name( + self.conn) + fk_def['referred_schema'] = referred_schema + return fk_defs + + def get_indexes(self, tablename, schemaname=None): + """Return information about indexes in `tablename`. + + Given a string `tablename` and an optional string `schemaname`, return + index information as a list of dicts with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + + """ + + indexes = self.engine.dialect.get_indexes(self.conn, tablename, + schemaname, + self.info_cache) + return indexes diff --git a/test/reflection.py b/test/reflection.py new file mode 100644 index 0000000000..e3d911a4c9 --- /dev/null +++ b/test/reflection.py @@ -0,0 +1,283 @@ +"""tests for sqlalchemy.engine.reflection + +""" + +import testenv; testenv.configure_for_tests() +import sqlalchemy as sa +from sqlalchemy.engine.reflection import Inspector +from testlib.sa import MetaData, Table, Column +from testlib import TestBase, testing, engines + +if 'set' not in dir(__builtins__): + from sets import Set as set + +def getSchema(): + if testing.against('oracle'): + return u'scott' + else: + return u'test_schema' + +def createTables(meta, schema=None): + if schema: + parent_user_id = Column('parent_user_id', sa.Integer, + sa.ForeignKey('%s.users.user_id' % schema) + ) + else: + parent_user_id = Column('parent_user_id', sa.Integer, + sa.ForeignKey('users.user_id') + ) + + users = Table('users', meta, + Column('user_id', sa.INT, primary_key=True), + Column('user_name', sa.VARCHAR(20), nullable=False), + Column('test1', sa.CHAR(5), nullable=False), + Column('test2', sa.Float(5), nullable=False), + Column('test3', sa.Text), + Column('test4', sa.Numeric, nullable = False), + Column('test5', sa.DateTime), + parent_user_id, + Column('test6', sa.DateTime, nullable=False), + Column('test7', sa.Text), + Column('test8', sa.Binary), + Column('test_passivedefault2', sa.Integer, server_default='5'), + Column('test9', sa.Binary(100)), + Column('test_numeric', sa.Numeric()), + schema=schema, + test_needs_fk=True, + ) + addresses = Table(u'email_addresses', meta, + Column('address_id', sa.Integer, primary_key = True), + Column('remote_user_id', sa.Integer, + sa.ForeignKey(users.c.user_id)), + Column('email_address', sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + return (users, addresses) + +def createIndexes(con, schema=None): + fullname = 'users' + if schema: + fullname = "%s.%s" % (schema, 'users') + query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname + con.execute(sa.sql.text(query)) + +def createViews(con, schema=None): + for tablename in (u'users', u'email_addresses'): + fullname = tablename + if schema: + fullname = "%s.%s" % (schema, tablename) + view_name = fullname + '_v' + query = "CREATE OR REPLACE VIEW %s AS SELECT * FROM %s" % (view_name, + fullname) + con.execute(sa.sql.text(query)) + +def dropViews(con, schema=None): + for tablename in (u'email_addresses', 'users'): + fullname = tablename + if schema: + fullname = "%s.%s" % (schema, tablename) + view_name = fullname + '_v' + query = "DROP VIEW %s" % view_name + con.execute(sa.sql.text(query)) + + +class ReflectionTest(TestBase): + + def test_get_schema_names(self): + meta = MetaData(testing.db) + insp = Inspector(meta.bind) + self.assert_(getSchema() in insp.get_schema_names()) + + def _test_get_table_names(self, schemaname=None, table_type='table'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schemaname) + meta.create_all() + createViews(meta.bind, schemaname) + try: + insp = Inspector(meta.bind) + if table_type == 'view': + table_names = insp.get_view_names(schemaname) + table_names.sort() + answer = [u'email_addresses_v', u'users_v'] + else: + table_names = insp.get_table_names(schemaname) + table_names.sort() + answer = [u'email_addresses', 'users'] + self.assertEqual(table_names, answer) + finally: + dropViews(meta.bind, schemaname) + addresses.drop() + users.drop() + + def test_get_table_names(self): + self._test_get_table_names() + + def test_get_table_names_with_schema(self): + self._test_get_table_names(getSchema()) + + def test_get_view_names(self): + self._test_get_table_names(table_type='view') + + def test_get_view_names_with_schema(self): + self._test_get_table_names(getSchema(), table_type='view') + + def _test_get_columns(self, schemaname=None, table_type='table'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schemaname) + table_names = [u'users', u'email_addresses'] + meta.create_all() + if table_type == 'view': + createViews(meta.bind, schemaname) + table_names = [u'users_v', u'email_addresses_v'] + try: + insp = Inspector(meta.bind) + for (tablename, table) in zip(table_names, (users, addresses)): + schema_name = schemaname + if schemaname and testing.against('oracle'): + schema_name = schema.upper() + cols = insp.get_columns(tablename, schemaname=schema_name) + self.assert_(len(cols) > 0, len(cols)) + # should be in order + for (i, col) in enumerate(table.columns): + self.assertEqual(col.name, cols[i]['name']) + # coltype is tricky + # It may not inherit from col.type while they share + # the same base. + coltype = cols[i]['type'].__class__ + self.assert_( + issubclass(coltype, col.type.__class__) or \ + len( + set( + coltype.__bases__ + ).intersection(col.type.__class__.__bases__)) > 0 + ,("%s, %s", (col.type, coltype))) + finally: + if table_type == 'view': + dropViews(meta.bind, schemaname) + addresses.drop() + users.drop() + + def test_get_columns(self): + self._test_get_columns() + + def test_get_columns_with_schema(self): + self._test_get_columns(schemaname=getSchema()) + + def test_get_view_columns(self): + self._test_get_columns(table_type='view') + + def test_get_view_columns_with_schema(self): + self._test_get_columns(schemaname=getSchema(), table_type='view') + + def _test_get_primary_keys(self, schemaname=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schemaname) + meta.create_all() + insp = Inspector(meta.bind) + try: + users_pkeys = insp.get_primary_keys(unicode(users.name), + schemaname=schemaname) + self.assertEqual(users_pkeys, ['user_id']) + addr_pkeys = insp.get_primary_keys(unicode(addresses.name), + schemaname=schemaname) + self.assertEqual(addr_pkeys, ['address_id']) + + finally: + addresses.drop() + users.drop() + + def test_get_primary_keys(self): + self._test_get_primary_keys() + + def test_get_primary_keys_with_schema(self): + self._test_get_primary_keys(schemaname=getSchema()) + + def _test_get_foreign_keys(self, schemaname=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schemaname) + meta.create_all() + insp = Inspector(meta.bind) + try: + expected_schema = schemaname + if schemaname is None: + expected_schema = meta.bind.dialect.get_default_schema_name( + meta.bind) + # users + users_fkeys = insp.get_foreign_keys(unicode(users.name), + schemaname=schemaname) + fkey1 = users_fkeys[0] + self.assert_(fkey1['name'] is not None) + self.assertEqual(fkey1['referred_schema'], expected_schema) + self.assertEqual(fkey1['referred_table'], users.name) + self.assertEqual(fkey1['referred_columns'], ['user_id', ]) + self.assertEqual(fkey1['constrained_columns'], ['parent_user_id']) + #addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, + schemaname=schemaname) + fkey1 = addr_fkeys[0] + self.assert_(fkey1['name'] is not None) + self.assertEqual(fkey1['referred_schema'], expected_schema) + self.assertEqual(fkey1['referred_table'], users.name) + self.assertEqual(fkey1['referred_columns'], ['user_id', ]) + self.assertEqual(fkey1['constrained_columns'], ['remote_user_id']) + finally: + addresses.drop() + users.drop() + + def test_get_foreign_keys(self): + self._test_get_foreign_keys() + + def test_get_foreign_keys_with_schema(self): + self._test_get_foreign_keys(schemaname=getSchema()) + + def _test_get_indexes(self, schemaname=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schemaname) + meta.create_all() + createIndexes(meta.bind, schemaname) + try: + insp = Inspector(meta.bind) + indexes = insp.get_indexes(u'users', schemaname=schemaname) + indexes.sort() + expected_indexes = [ + {'unique': False, + 'column_names': ['test1', 'test2'], + 'name': 'users_t_idx'}] + self.assertEqual(indexes, expected_indexes) + finally: + addresses.drop() + users.drop() + + def test_get_indexes(self): + self._test_get_indexes() + + def test_get_indexes_with_schema(self): + self._test_get_indexes(schemaname=getSchema()) + + def _test_get_view_definition(self, schemaname=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schemaname) + meta.create_all() + createViews(meta.bind, schemaname) + view_name1 = u'users_v' + view_name2 = u'email_addresses_v' + try: + insp = Inspector(meta.bind) + v1 = insp.get_view_definition(view_name1, schemaname=schemaname) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schemaname=schemaname) + self.assert_(v2) + finally: + dropViews(meta.bind, schemaname) + addresses.drop() + users.drop() + + def test_get_view_definition(self): + self._test_get_view_definition() + + def test_get_view_definition_with_schema(self): + self._test_get_view_definition(schemaname=getSchema()) + +if __name__ == "__main__": + testenv.main() -- 2.47.3