From: Randall Smith Date: Sun, 15 Feb 2009 07:52:27 +0000 (+0000) Subject: finished oracle - all tests pass X-Git-Tag: rel_0_6_6~278 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a660493f1fcf30f6c85aae94d3adfc03b88fe9e6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git finished oracle - all tests pass --- diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 25b50c550e..de9d14265b 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -583,6 +583,22 @@ class OracleDialect(default.DefaultDialect): owner = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) return (actual_name, owner, dblink, synonym) + def get_schema_names(self, connection, info_cache=None): + s = "SELECT username FROM all_users ORDER BY username" + cursor = connection.execute(s,) + return [self._normalize_name(row[0]) for row in cursor] + + def get_table_names(self, connection, schemaname=None, info_cache=None): + schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) + return self.table_names(connection, schemaname) + + def get_view_names(self, connection, schemaname=None, info_cache=None): + schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection)) + s = "select view_name from all_views where OWNER = :owner" + cursor = connection.execute(s, + {'owner':self._denormalize_name(schemaname)}) + return [self._normalize_name(row[0]) for row in cursor] + def get_columns(self, connection, tablename, schemaname=None, info_cache=None, resolve_synonyms=False, dblink=''): @@ -642,6 +658,51 @@ class OracleDialect(default.DefaultDialect): info_cache.setColumns(columns, tablename, schemaname) return columns + def get_indexes(self, connection, tablename, schemaname=None, + info_cache=None, resolve_synonyms=False, dblink=''): + + + (tablename, schemaname, dblink, synonym) = \ + self._prepare_reflection_args(connection, tablename, schemaname, + resolve_synonyms, dblink) + if info_cache: + indexes = info_cache.getIndexes(tablename, schemaname) + if indexes: + return indexes + indexes = [] + q = """ + SELECT a.INDEX_NAME, a.COLUMN_NAME, b.UNIQUENESS + FROM ALL_IND_COLUMNS%(dblink)s a + INNER JOIN ALL_INDEXES%(dblink)s b + ON a.INDEX_NAME = b.INDEX_NAME + AND a.TABLE_OWNER = b.TABLE_OWNER + AND a.TABLE_NAME = b.TABLE_NAME + WHERE a.TABLE_NAME = :tablename + AND a.TABLE_OWNER = :schemaname + ORDER BY a.INDEX_NAME, a.COLUMN_POSITION + """ % dict(dblink=dblink) + rp = connection.execute(q, + dict(tablename=self._denormalize_name(tablename), + schemaname=self._denormalize_name(schemaname))) + indexes = [] + last_index_name = None + pkeys = self.get_primary_keys(connection, tablename, schemaname, + info_cache, resolve_synonyms, dblink) + uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + for rset in rp: + # don't include the primary key columns + if rset.column_name in [s.upper() for s in pkeys]: + continue + if rset.index_name != last_index_name: + index = dict(name=rset.index_name, column_names=[]) + indexes.append(index) + index['unique'] = uniqueness.get(rset.uniqueness, False) + index['column_names'].append(rset.column_name) + last_index_name = rset.index_name + if info_cache: + info_cache.setIndexes(indexes, tablename, schemaname) + return indexes + def _get_constraint_data(self, connection, tablename, schemaname=None, info_cache=None, dblink=''): @@ -738,11 +799,42 @@ class OracleDialect(default.DefaultDialect): fk[1].append(remote_column) for (name, value) in fks.items(): if remote_table and value[1]: - fkeys.append((name, value[0], remote_owner, remote_table, value[1])) + fkey_d = { + 'name' : name, + 'constrained_columns' : value[0], + 'referred_schema' : remote_owner, + 'referred_table' : remote_table, + 'referred_columns' : value[1] + } + fkeys.append(fkey_d) if info_cache: info_cache.setForeignKeys(fkeys, tablename, schemaname) return fkeys + def get_view_definition(self, connection, viewname, schemaname=None, + info_cache=None, resolve_synonyms=False, dblink=''): + (viewname, schemaname, dblink, synonym) = \ + self._prepare_reflection_args(connection, viewname, schemaname, + resolve_synonyms, dblink) + if info_cache: + view_cache = info_cache.getView(viewname, schemaname) + if view_cache and 'definition' in view_cache: + return view_cache['definition'] + s = """ + SELECT text FROM all_views + WHERE owner = :schemaname + AND view_name = :viewname + """ + rp = connection.execute(sql.text(s), + viewname=viewname, schemaname=schemaname) + if rp: + view_def = rp.scalar().decode(self.encoding) + if info_cache: + view = info_cache.getView(viewname, schemaname, + create=True) + view['definition'] = view_def + return view_def + def reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer info_cache = OracleInfoCache() @@ -780,8 +872,12 @@ class OracleDialect(default.DefaultDialect): fkeys = self.get_foreign_keys(connection, actual_name, owner, info_cache, resolve_synonyms, dblink) refspecs = [] - for (conname, constrained_columns, referred_schema, referred_table, - referred_columns) in fkeys: + for fkey_d in fkeys: + conname = fkey_d['name'] + constrained_columns = fkey_d['constrained_columns'] + referred_schema = fkey_d['referred_schema'] + referred_table = fkey_d['referred_table'] + referred_columns = fkey_d['referred_columns'] for (i, ref_col) in enumerate(referred_columns): if not table.schema and self._denormalize_name(referred_schema) == self._denormalize_name(owner): t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4c56dfb70e..b50411c0cf 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -186,6 +186,15 @@ class DefaultInfoCache(object): return self._setTableData('foreign_keys', fkeys, tablename, schemaname) + def getIndexes(self, tablename, schemaname=None): + """Return indexes list or None.""" + + return self._getTableData('indexes', tablename, schemaname) + + def setIndexes(self, indexes, tablename, schemaname=None): + """Add list of indexes to table cache.""" + + return self._setTableData('indexes', indexes, tablename, schemaname) class DefaultDialect(base.Dialect): """Default implementation of Dialect""" diff --git a/test/reflection.py b/test/reflection.py index 0e95a6dbf4..39240487c5 100644 --- a/test/reflection.py +++ b/test/reflection.py @@ -4,6 +4,7 @@ import testenv; testenv.configure_for_tests() import sqlalchemy as sa +from sqlalchemy import types as sql_types from sqlalchemy.engine.reflection import Inspector from testlib.sa import MetaData, Table, Column from testlib import TestBase, testing, engines @@ -13,7 +14,7 @@ if 'set' not in dir(__builtins__): def getSchema(): if testing.against('oracle'): - return 'scott' + return 'test' else: return 'test_schema' @@ -35,6 +36,7 @@ def createTables(meta, schema=None): Column('test3', sa.Text), Column('test4', sa.Numeric, nullable = False), Column('test5', sa.DateTime), + Column('test5-1', sa.TIMESTAMP), parent_user_id, Column('test6', sa.DateTime, nullable=False), Column('test7', sa.Text), @@ -143,7 +145,7 @@ class ReflectionTest(TestBase): for (tablename, table) in zip(table_names, (users, addresses)): schema_name = schemaname if schemaname and testing.against('oracle'): - schema_name = schema.upper() + schema_name = schemaname.upper() cols = insp.get_columns(tablename, schemaname=schema_name) self.assert_(len(cols) > 0, len(cols)) # should be in order @@ -152,14 +154,22 @@ class ReflectionTest(TestBase): # coltype is tricky # It may not inherit from col.type while they share # the same base. - coltype = cols[i]['type'].__class__ + ctype = cols[i]['type'].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + # Oracle returns Date for DateTime. + if testing.against('oracle') \ + and ctype_def in (sql_types.Date, sql_types.DateTime): + ctype_def = sql_types.Date self.assert_( - issubclass(coltype, col.type.__class__) or \ + issubclass(ctype, ctype_def) or \ len( set( - coltype.__bases__ - ).intersection(col.type.__class__.__bases__)) > 0 - ,("%s, %s", (col.type, coltype))) + ctype.__bases__ + ).intersection(ctype_def.__bases__)) > 0 + ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'], + ctype))) finally: if table_type == 'view': dropViews(meta.bind, schemaname) @@ -248,10 +258,16 @@ class ReflectionTest(TestBase): insp = Inspector(meta.bind) indexes = insp.get_indexes('users', schemaname=schemaname) indexes.sort() - expected_indexes = [ - {'unique': False, - 'column_names': ['test1', 'test2'], - 'name': 'users_t_idx'}] + if testing.against('oracle'): + expected_indexes = [ + {'unique': False, + 'column_names': ['TEST1', 'TEST2'], + 'name': 'USERS_T_IDX'}] + else: + expected_indexes = [ + {'unique': False, + 'column_names': ['test1', 'test2'], + 'name': 'users_t_idx'}] self.assertEqual(indexes, expected_indexes) finally: addresses.drop()