]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added reflection (inspector) and tests
authorRandall Smith <randall@tnr.cc>
Wed, 4 Feb 2009 06:42:26 +0000 (06:42 +0000)
committerRandall Smith <randall@tnr.cc>
Wed, 4 Feb 2009 06:42:26 +0000 (06:42 +0000)
lib/sqlalchemy/engine/reflection.py [new file with mode: 0644]
test/reflection.py [new file with mode: 0644]

diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
new file mode 100644 (file)
index 0000000..99e8649
--- /dev/null
@@ -0,0 +1,188 @@
+"""Provides an abstraction for obtaining database schema information.
+
+Development Notes:
+
+I'm still trying to decide upon conventions for both the Inspector interface as well as the dialect interface the Inspector is to consume.  Below are some of the current conventions.
+
+  1. Inspector methods should return lists of dicts in most cases for the 
+     following reasons:
+    * They're both simple standard types.
+    * Using a dict instead of a tuple allows easy expansion of attributes.
+    * Using a list for the outer structure maintains order and is easy to work 
+       with (e.g. list comprehension [d['name'] for d in cols]).
+    * Being consistent is just good.
+  2. Records that contain a name, such as the column name in a column record
+     should use the key 'name' in the dict.  This allows the user to expect a
+     'name' key and to know what it will reference.
+
+
+"""
+import sqlalchemy
+from sqlalchemy.types import TypeEngine
+
+class Inspector(object):
+    """performs database introspection
+
+    """
+    
+    def __init__(self, conn):
+        """
+
+        conn
+          [sqlalchemy.engine.base.#Connectable]
+
+        """
+        self.info_cache = {}
+        self.conn = conn
+        # set the engine
+        if hasattr(conn, 'engine'):
+            self.engine = conn.engine
+        else:
+            self.engine = conn
+
+    def default_schema_name(self):
+        return self.engine.dialect.get_default_schema_name(self.conn)
+    default_schema_name = property(default_schema_name)
+
+    def get_schema_names(self):
+        """Return all schema names.
+
+        """
+        if hasattr(self.engine.dialect, 'get_schema_names'):
+            return self.engine.dialect.get_schema_names(self.conn,
+                                                        self.info_cache)
+        return []
+
+    def get_table_names(self, schemaname=None):
+        """Return all table names in `schemaname`.
+        schemaname:
+          Optional, retrieve names from a non-default schema.
+
+        This should probably not return view names or maybe it should return
+        them with an indicator t or v.
+
+        """
+        if hasattr(self.engine.dialect, 'get_table_names'):
+            return self.engine.dialect.get_table_names(self.conn, schemaname,
+                                                       self.info_cache)
+        return self.engine.table_names(schemaname)
+
+    def get_view_names(self, schemaname=None):
+        """Return all view names in `schemaname`.
+        schemaname:
+          Optional, retrieve names from a non-default schema.
+
+        """
+        return self.engine.dialect.get_view_names(self.conn, schemaname,
+                                                  self.info_cache)
+
+    def get_view_definition(self, view_name, schemaname=None):
+        """Return definition for `view_name`.
+        schemaname:
+          Optional, retrieve names from a non-default schema.
+
+        """
+        return self.engine.dialect.get_view_definition(
+            self.conn, view_name, schemaname, self.info_cache)
+
+    def get_columns(self, tablename, schemaname=None):
+        """Return information about columns in `tablename`.
+
+        Given a string `tablename` and an optional string `schemaname`, return
+        column information as a list of dicts with these keys:
+
+        name
+          the column's name
+
+        type
+          [sqlalchemy.types#TypeEngine]
+
+        nullable
+          boolean
+
+        default
+          the column's default value
+
+        attrs
+          dict containing optional column attributes
+
+        """
+
+        col_defs = self.engine.dialect.get_columns(self.conn, tablename,
+                                                   schemaname,
+                                                   self.info_cache)
+        for col_def in col_defs:
+            # make this easy and only return instances for coltype
+            coltype = col_def['type']
+            if not isinstance(coltype, TypeEngine):
+                col_def['type'] = coltype()
+        return col_defs
+
+    def get_primary_keys(self, tablename, schemaname=None):
+        """Return information about primary keys in `tablename`.
+
+        Given a string `tablename`, and an optional string `schemaname`, return 
+        primary key information as a list of column names:
+
+        """
+
+        pkeys = self.engine.dialect.get_primary_keys(self.conn, tablename,
+                                                     schemaname,
+                                                     self.info_cache)
+
+        return pkeys
+
+    def get_foreign_keys(self, tablename, schemaname=None):
+        """Return information about foreign_keys in `tablename`.
+
+        Given a string `tablename`, and an optional string `schemaname`, return 
+        foreign key information as a list of dicts with these keys:
+
+        constrained_columns
+          a list of column names that make up the foreign key
+
+        referred_schema
+          the name of the referred schema
+
+        referred_table
+          the name of the referred table
+
+        referred_columns
+          a list of column names in the referred table that correspond to
+          constrained_columns
+
+        """
+
+        fk_defs = self.engine.dialect.get_foreign_keys(self.conn, tablename,
+                                                       schemaname,
+                                                       self.info_cache)
+        for fk_def in fk_defs:
+            referred_schema = fk_def['referred_schema']
+            # always set the referred_schema.
+            if referred_schema is None and schemaname is None:
+                referred_schema = self.engine.dialect.get_default_schema_name(
+                                                                    self.conn)
+                fk_def['referred_schema'] = referred_schema
+        return fk_defs
+
+    def get_indexes(self, tablename, schemaname=None):
+        """Return information about indexes in `tablename`.
+
+        Given a string `tablename` and an optional string `schemaname`, return
+        index information as a list of dicts with these keys:
+
+        name
+          the index's name
+
+        column_names
+          list of column names in order
+
+        unique
+          boolean
+
+        """
+
+        indexes = self.engine.dialect.get_indexes(self.conn, tablename,
+                                                  schemaname,
+                                                  self.info_cache)
+        return indexes
diff --git a/test/reflection.py b/test/reflection.py
new file mode 100644 (file)
index 0000000..e3d911a
--- /dev/null
@@ -0,0 +1,283 @@
+"""tests for sqlalchemy.engine.reflection
+
+"""
+
+import testenv; testenv.configure_for_tests()
+import sqlalchemy as sa
+from sqlalchemy.engine.reflection import Inspector
+from testlib.sa import MetaData, Table, Column
+from testlib import TestBase, testing, engines
+
+if 'set' not in dir(__builtins__):
+    from sets import Set as set
+
+def getSchema():
+    if testing.against('oracle'):
+        return u'scott'
+    else:
+        return u'test_schema'
+
+def createTables(meta, schema=None):
+    if schema:
+        parent_user_id = Column('parent_user_id', sa.Integer,
+            sa.ForeignKey('%s.users.user_id' % schema)
+        )
+    else:
+        parent_user_id = Column('parent_user_id', sa.Integer,
+            sa.ForeignKey('users.user_id')
+        )
+
+    users = Table('users', meta,
+        Column('user_id', sa.INT, primary_key=True),
+        Column('user_name', sa.VARCHAR(20), nullable=False),
+        Column('test1', sa.CHAR(5), nullable=False),
+        Column('test2', sa.Float(5), nullable=False),
+        Column('test3', sa.Text),
+        Column('test4', sa.Numeric, nullable = False),
+        Column('test5', sa.DateTime),
+        parent_user_id,
+        Column('test6', sa.DateTime, nullable=False),
+        Column('test7', sa.Text),
+        Column('test8', sa.Binary),
+        Column('test_passivedefault2', sa.Integer, server_default='5'),
+        Column('test9', sa.Binary(100)),
+        Column('test_numeric', sa.Numeric()),
+        schema=schema,
+        test_needs_fk=True,
+    )
+    addresses = Table(u'email_addresses', meta,
+        Column('address_id', sa.Integer, primary_key = True),
+        Column('remote_user_id', sa.Integer,
+               sa.ForeignKey(users.c.user_id)),
+        Column('email_address', sa.String(20)),
+        schema=schema,
+        test_needs_fk=True,
+    )
+    return (users, addresses)
+
+def createIndexes(con, schema=None):
+    fullname = 'users'
+    if schema:
+        fullname = "%s.%s" % (schema, 'users')
+    query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname
+    con.execute(sa.sql.text(query))
+
+def createViews(con, schema=None):
+    for tablename in (u'users', u'email_addresses'):
+        fullname = tablename
+        if schema:
+            fullname = "%s.%s" % (schema, tablename)
+        view_name = fullname + '_v'
+        query = "CREATE OR REPLACE VIEW %s AS SELECT * FROM %s" % (view_name,
+                                                                   fullname)
+        con.execute(sa.sql.text(query))
+
+def dropViews(con, schema=None):
+    for tablename in (u'email_addresses', 'users'):
+        fullname = tablename
+        if schema:
+            fullname = "%s.%s" % (schema, tablename)
+        view_name = fullname + '_v'
+        query = "DROP VIEW %s" % view_name
+        con.execute(sa.sql.text(query))
+
+
+class ReflectionTest(TestBase):
+
+    def test_get_schema_names(self):
+        meta = MetaData(testing.db)
+        insp = Inspector(meta.bind)
+        self.assert_(getSchema() in insp.get_schema_names())
+
+    def _test_get_table_names(self, schemaname=None, table_type='table'):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schemaname)
+        meta.create_all()
+        createViews(meta.bind, schemaname)
+        try:
+            insp = Inspector(meta.bind)
+            if table_type == 'view':
+                table_names = insp.get_view_names(schemaname)
+                table_names.sort()
+                answer = [u'email_addresses_v', u'users_v']
+            else:
+                table_names = insp.get_table_names(schemaname)
+                table_names.sort()
+                answer = [u'email_addresses', 'users']
+            self.assertEqual(table_names, answer)
+        finally:
+            dropViews(meta.bind, schemaname)
+            addresses.drop()
+            users.drop()
+
+    def test_get_table_names(self):
+        self._test_get_table_names()
+
+    def test_get_table_names_with_schema(self):
+        self._test_get_table_names(getSchema())
+
+    def test_get_view_names(self):
+        self._test_get_table_names(table_type='view')
+
+    def test_get_view_names_with_schema(self):
+        self._test_get_table_names(getSchema(), table_type='view')
+
+    def _test_get_columns(self, schemaname=None, table_type='table'):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schemaname)
+        table_names = [u'users', u'email_addresses']
+        meta.create_all()
+        if table_type == 'view':
+            createViews(meta.bind, schemaname)
+            table_names = [u'users_v', u'email_addresses_v']
+        try:
+            insp = Inspector(meta.bind)
+            for (tablename, table) in zip(table_names, (users, addresses)):
+                schema_name = schemaname
+                if schemaname and testing.against('oracle'):
+                    schema_name = schema.upper()
+                cols = insp.get_columns(tablename, schemaname=schema_name)
+                self.assert_(len(cols) > 0, len(cols))
+                # should be in order
+                for (i, col) in enumerate(table.columns):
+                    self.assertEqual(col.name, cols[i]['name'])
+                    # coltype is tricky
+                    # It may not inherit from col.type while they share
+                    # the same base.
+                    coltype = cols[i]['type'].__class__
+                    self.assert_(
+                        issubclass(coltype, col.type.__class__) or \
+                        len(
+                            set(
+                                coltype.__bases__
+                            ).intersection(col.type.__class__.__bases__)) > 0
+                    ,("%s, %s", (col.type, coltype)))
+        finally:
+            if table_type == 'view':
+                dropViews(meta.bind, schemaname)
+            addresses.drop()
+            users.drop()
+
+    def test_get_columns(self):
+        self._test_get_columns()
+
+    def test_get_columns_with_schema(self):
+        self._test_get_columns(schemaname=getSchema())
+
+    def test_get_view_columns(self):
+        self._test_get_columns(table_type='view')
+
+    def test_get_view_columns_with_schema(self):
+        self._test_get_columns(schemaname=getSchema(), table_type='view')
+
+    def _test_get_primary_keys(self, schemaname=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schemaname)
+        meta.create_all()
+        insp = Inspector(meta.bind)
+        try:
+            users_pkeys = insp.get_primary_keys(unicode(users.name),
+                                                schemaname=schemaname)
+            self.assertEqual(users_pkeys,  ['user_id'])
+            addr_pkeys = insp.get_primary_keys(unicode(addresses.name),
+                                               schemaname=schemaname)
+            self.assertEqual(addr_pkeys,  ['address_id'])
+
+        finally:
+            addresses.drop()
+            users.drop()
+
+    def test_get_primary_keys(self):
+        self._test_get_primary_keys()
+
+    def test_get_primary_keys_with_schema(self):
+        self._test_get_primary_keys(schemaname=getSchema())
+
+    def _test_get_foreign_keys(self, schemaname=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schemaname)
+        meta.create_all()
+        insp = Inspector(meta.bind)
+        try:
+            expected_schema = schemaname
+            if schemaname is None:
+                expected_schema = meta.bind.dialect.get_default_schema_name(
+                                    meta.bind)
+            # users
+            users_fkeys = insp.get_foreign_keys(unicode(users.name),
+                                                schemaname=schemaname)
+            fkey1 = users_fkeys[0]
+            self.assert_(fkey1['name'] is not None)
+            self.assertEqual(fkey1['referred_schema'], expected_schema)
+            self.assertEqual(fkey1['referred_table'], users.name)
+            self.assertEqual(fkey1['referred_columns'], ['user_id', ])
+            self.assertEqual(fkey1['constrained_columns'], ['parent_user_id'])
+            #addresses
+            addr_fkeys = insp.get_foreign_keys(addresses.name,
+                                               schemaname=schemaname)
+            fkey1 = addr_fkeys[0]
+            self.assert_(fkey1['name'] is not None)
+            self.assertEqual(fkey1['referred_schema'], expected_schema)
+            self.assertEqual(fkey1['referred_table'], users.name)
+            self.assertEqual(fkey1['referred_columns'], ['user_id', ])
+            self.assertEqual(fkey1['constrained_columns'], ['remote_user_id'])
+        finally:
+            addresses.drop()
+            users.drop()
+
+    def test_get_foreign_keys(self):
+        self._test_get_foreign_keys()
+
+    def test_get_foreign_keys_with_schema(self):
+        self._test_get_foreign_keys(schemaname=getSchema())
+
+    def _test_get_indexes(self, schemaname=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schemaname)
+        meta.create_all()
+        createIndexes(meta.bind, schemaname)
+        try:
+            insp = Inspector(meta.bind)
+            indexes = insp.get_indexes(u'users', schemaname=schemaname)
+            indexes.sort()
+            expected_indexes = [
+                {'unique': False,
+                 'column_names': ['test1', 'test2'],
+                 'name': 'users_t_idx'}]
+            self.assertEqual(indexes, expected_indexes)
+        finally:
+            addresses.drop()
+            users.drop()
+
+    def test_get_indexes(self):
+        self._test_get_indexes()
+
+    def test_get_indexes_with_schema(self):
+        self._test_get_indexes(schemaname=getSchema())
+
+    def _test_get_view_definition(self, schemaname=None):
+        meta = MetaData(testing.db)
+        (users, addresses) = createTables(meta, schemaname)
+        meta.create_all()
+        createViews(meta.bind, schemaname)
+        view_name1 = u'users_v'
+        view_name2 = u'email_addresses_v'
+        try:
+            insp = Inspector(meta.bind)
+            v1 = insp.get_view_definition(view_name1, schemaname=schemaname)
+            self.assert_(v1)
+            v2 = insp.get_view_definition(view_name2, schemaname=schemaname)
+            self.assert_(v2)
+        finally:
+            dropViews(meta.bind, schemaname)
+            addresses.drop()
+            users.drop()
+
+    def test_get_view_definition(self):
+        self._test_get_view_definition()
+
+    def test_get_view_definition_with_schema(self):
+        self._test_get_view_definition(schemaname=getSchema())
+
+if __name__ == "__main__":
+    testenv.main()