]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
finished oracle - all tests pass
authorRandall Smith <randall@tnr.cc>
Sun, 15 Feb 2009 07:52:27 +0000 (07:52 +0000)
committerRandall Smith <randall@tnr.cc>
Sun, 15 Feb 2009 07:52:27 +0000 (07:52 +0000)
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/engine/default.py
test/reflection.py

index 25b50c550eed6253488eaabd0118daf8e4492450..de9d14265ba28f140e539b2d94cb7ebb15096443 100644 (file)
@@ -583,6 +583,22 @@ class OracleDialect(default.DefaultDialect):
             owner = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
         return (actual_name, owner, dblink, synonym)
 
+    def get_schema_names(self, connection, info_cache=None):
+        s = "SELECT username FROM all_users ORDER BY username"
+        cursor = connection.execute(s,)
+        return [self._normalize_name(row[0]) for row in cursor]
+
+    def get_table_names(self, connection, schemaname=None, info_cache=None):
+        schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
+        return self.table_names(connection, schemaname)
+
+    def get_view_names(self, connection, schemaname=None, info_cache=None):
+        schemaname = self._denormalize_name(schemaname or self.get_default_schema_name(connection))
+        s = "select view_name from all_views where OWNER = :owner"
+        cursor = connection.execute(s,
+                {'owner':self._denormalize_name(schemaname)})
+        return [self._normalize_name(row[0]) for row in cursor]
+
     def get_columns(self, connection, tablename, schemaname=None,
                     info_cache=None, resolve_synonyms=False, dblink=''):
 
@@ -642,6 +658,51 @@ class OracleDialect(default.DefaultDialect):
             info_cache.setColumns(columns, tablename, schemaname)
         return columns
 
+    def get_indexes(self, connection, tablename, schemaname=None,
+                    info_cache=None, resolve_synonyms=False, dblink=''):
+
+        
+        (tablename, schemaname, dblink, synonym) = \
+            self._prepare_reflection_args(connection, tablename, schemaname,
+                                          resolve_synonyms, dblink)
+        if info_cache:
+            indexes = info_cache.getIndexes(tablename, schemaname)
+            if indexes:
+                return indexes
+        indexes = []
+        q = """
+        SELECT a.INDEX_NAME, a.COLUMN_NAME, b.UNIQUENESS
+        FROM ALL_IND_COLUMNS%(dblink)s a
+        INNER JOIN ALL_INDEXES%(dblink)s b
+            ON a.INDEX_NAME = b.INDEX_NAME
+            AND a.TABLE_OWNER = b.TABLE_OWNER
+            AND a.TABLE_NAME = b.TABLE_NAME
+        WHERE a.TABLE_NAME = :tablename
+        AND a.TABLE_OWNER = :schemaname
+        ORDER BY a.INDEX_NAME, a.COLUMN_POSITION
+        """ % dict(dblink=dblink)
+        rp = connection.execute(q,
+            dict(tablename=self._denormalize_name(tablename),
+                 schemaname=self._denormalize_name(schemaname)))
+        indexes = []
+        last_index_name = None
+        pkeys = self.get_primary_keys(connection, tablename, schemaname,
+                                      info_cache, resolve_synonyms, dblink)
+        uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
+        for rset in rp:
+            # don't include the primary key columns
+            if rset.column_name in [s.upper() for s in pkeys]:
+                continue
+            if rset.index_name != last_index_name:
+                index = dict(name=rset.index_name, column_names=[])
+                indexes.append(index)
+            index['unique'] = uniqueness.get(rset.uniqueness, False)
+            index['column_names'].append(rset.column_name)
+            last_index_name = rset.index_name
+        if info_cache:
+            info_cache.setIndexes(indexes, tablename, schemaname)
+        return indexes
+
     def _get_constraint_data(self, connection, tablename, schemaname=None,
                              info_cache=None, dblink=''):
 
@@ -738,11 +799,42 @@ class OracleDialect(default.DefaultDialect):
                     fk[1].append(remote_column)
         for (name, value) in fks.items():
             if remote_table and value[1]:
-                fkeys.append((name, value[0], remote_owner, remote_table, value[1]))
+                fkey_d = {
+                    'name' : name,
+                    'constrained_columns' : value[0],
+                    'referred_schema' : remote_owner,
+                    'referred_table' : remote_table,
+                    'referred_columns' : value[1]
+                }
+                fkeys.append(fkey_d)
         if info_cache:
             info_cache.setForeignKeys(fkeys, tablename, schemaname)
         return fkeys
 
+    def get_view_definition(self, connection, viewname, schemaname=None,
+                            info_cache=None, resolve_synonyms=False, dblink=''):
+        (viewname, schemaname, dblink, synonym) = \
+            self._prepare_reflection_args(connection, viewname, schemaname,
+                                          resolve_synonyms, dblink)
+        if info_cache:
+            view_cache = info_cache.getView(viewname, schemaname)
+            if view_cache and 'definition' in view_cache:
+                return view_cache['definition']
+        s = """
+        SELECT text FROM all_views
+        WHERE owner = :schemaname
+        AND view_name = :viewname
+        """
+        rp = connection.execute(sql.text(s),
+                                viewname=viewname, schemaname=schemaname)
+        if rp:
+            view_def = rp.scalar().decode(self.encoding)
+            if info_cache:
+                view = info_cache.getView(viewname, schemaname,
+                                          create=True)
+                view['definition'] = view_def
+            return view_def
+
     def reflecttable(self, connection, table, include_columns):
         preparer = self.identifier_preparer
         info_cache = OracleInfoCache()
@@ -780,8 +872,12 @@ class OracleDialect(default.DefaultDialect):
         fkeys = self.get_foreign_keys(connection, actual_name, owner,
                                       info_cache, resolve_synonyms, dblink)
         refspecs = []
-        for (conname, constrained_columns, referred_schema, referred_table,
-             referred_columns) in fkeys:
+        for fkey_d in fkeys:
+            conname = fkey_d['name']
+            constrained_columns = fkey_d['constrained_columns']
+            referred_schema = fkey_d['referred_schema']
+            referred_table = fkey_d['referred_table']
+            referred_columns = fkey_d['referred_columns']
             for (i, ref_col) in enumerate(referred_columns):
                 if not table.schema and self._denormalize_name(referred_schema) == self._denormalize_name(owner):
                     t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
index 4c56dfb70e1629d183acadf82fbcb585b9169070..b50411c0cff2c369faec02fde4f925a409d3d17e 100644 (file)
@@ -186,6 +186,15 @@ class DefaultInfoCache(object):
 
         return self._setTableData('foreign_keys', fkeys, tablename, schemaname)
 
+    def getIndexes(self, tablename, schemaname=None):
+        """Return indexes list or None."""
+        
+        return self._getTableData('indexes', tablename, schemaname)
+
+    def setIndexes(self, indexes, tablename, schemaname=None):
+        """Add list of indexes to table cache."""
+
+        return self._setTableData('indexes', indexes, tablename, schemaname)
 
 class DefaultDialect(base.Dialect):
     """Default implementation of Dialect"""
index 0e95a6dbf454f4d6a3886b49ca061b8a12f87b08..39240487c552f0698fb69111f8e58e344bece861 100644 (file)
@@ -4,6 +4,7 @@
 
 import testenv; testenv.configure_for_tests()
 import sqlalchemy as sa
+from sqlalchemy import types as sql_types
 from sqlalchemy.engine.reflection import Inspector
 from testlib.sa import MetaData, Table, Column
 from testlib import TestBase, testing, engines
@@ -13,7 +14,7 @@ if 'set' not in dir(__builtins__):
 
 def getSchema():
     if testing.against('oracle'):
-        return 'scott'
+        return 'test'
     else:
         return 'test_schema'
 
@@ -35,6 +36,7 @@ def createTables(meta, schema=None):
         Column('test3', sa.Text),
         Column('test4', sa.Numeric, nullable = False),
         Column('test5', sa.DateTime),
+        Column('test5-1', sa.TIMESTAMP),
         parent_user_id,
         Column('test6', sa.DateTime, nullable=False),
         Column('test7', sa.Text),
@@ -143,7 +145,7 @@ class ReflectionTest(TestBase):
             for (tablename, table) in zip(table_names, (users, addresses)):
                 schema_name = schemaname
                 if schemaname and testing.against('oracle'):
-                    schema_name = schema.upper()
+                    schema_name = schemaname.upper()
                 cols = insp.get_columns(tablename, schemaname=schema_name)
                 self.assert_(len(cols) > 0, len(cols))
                 # should be in order
@@ -152,14 +154,22 @@ class ReflectionTest(TestBase):
                     # coltype is tricky
                     # It may not inherit from col.type while they share
                     # the same base.
-                    coltype = cols[i]['type'].__class__
+                    ctype = cols[i]['type'].__class__
+                    ctype_def = col.type
+                    if isinstance(ctype_def, sa.types.TypeEngine):
+                        ctype_def = ctype_def.__class__
+                    # Oracle returns Date for DateTime.
+                    if testing.against('oracle') \
+                        and ctype_def in (sql_types.Date, sql_types.DateTime):
+                            ctype_def = sql_types.Date
                     self.assert_(
-                        issubclass(coltype, col.type.__class__) or \
+                        issubclass(ctype, ctype_def) or \
                         len(
                             set(
-                                coltype.__bases__
-                            ).intersection(col.type.__class__.__bases__)) > 0
-                    ,("%s, %s", (col.type, coltype)))
+                                ctype.__bases__
+                            ).intersection(ctype_def.__bases__)) > 0
+                    ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'],
+                                          ctype)))
         finally:
             if table_type == 'view':
                 dropViews(meta.bind, schemaname)
@@ -248,10 +258,16 @@ class ReflectionTest(TestBase):
             insp = Inspector(meta.bind)
             indexes = insp.get_indexes('users', schemaname=schemaname)
             indexes.sort()
-            expected_indexes = [
-                {'unique': False,
-                 'column_names': ['test1', 'test2'],
-                 'name': 'users_t_idx'}]
+            if testing.against('oracle'):
+                expected_indexes = [
+                    {'unique': False,
+                     'column_names': ['TEST1', 'TEST2'],
+                     'name': 'USERS_T_IDX'}]
+            else:
+                expected_indexes = [
+                    {'unique': False,
+                     'column_names': ['test1', 'test2'],
+                     'name': 'users_t_idx'}]
             self.assertEqual(indexes, expected_indexes)
         finally:
             addresses.drop()