]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- brought oracle reflection into the 21st century
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Jul 2009 16:51:42 +0000 (16:51 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Jul 2009 16:51:42 +0000 (16:51 +0000)
- some more test fixes

lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/engine/ddl.py
test/aaa_profiling/test_memusage.py
test/engine/test_reflection.py
test/orm/test_onetoone.py
test/orm/test_query.py

index d27775b63104f1eb19c6a2bbc4da6ee97bac0b95..aacfab9b423e7ae6a572e943f968c105b498aabd 100644 (file)
@@ -564,14 +564,16 @@ class OracleDialect(default.DefaultDialect):
                                           resolve_synonyms, dblink,
                                           info_cache=info_cache)
         columns = []
-        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':table_name, 'owner':schema})
+        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, "
+                                "DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s "
+                                "where TABLE_NAME = :table_name and OWNER = :owner" % 
+                                {'dblink':dblink}, {'table_name':table_name, 'owner':schema}
+                                )
 
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
+        for row in c:
 
-            (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
+            (colname, coltype, length, precision, scale, nullable, default) = \
+                (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
 
             # INTEGER if the scale is 0 and precision is null
             # NUMBER if the scale and precision are both null
@@ -658,7 +660,9 @@ class OracleDialect(default.DefaultDialect):
              loc.column_name AS local_column,
              rem.table_name AS remote_table,
              rem.column_name AS remote_column,
-             rem.owner AS remote_owner
+             rem.owner AS remote_owner,
+             loc.position as loc_pos,
+             rem.position as rem_pos
            FROM all_constraints%(dblink)s ac,
              all_cons_columns%(dblink)s loc,
              all_cons_columns%(dblink)s rem
@@ -669,8 +673,9 @@ class OracleDialect(default.DefaultDialect):
            AND ac.constraint_name = loc.constraint_name
            AND ac.r_owner = rem.owner(+)
            AND ac.r_constraint_name = rem.constraint_name(+)
-           -- order multiple primary keys correctly
-           ORDER BY ac.constraint_name, loc.position, rem.position"""
+           AND (rem.position IS NULL or loc.position=rem.position)
+           ORDER BY ac.constraint_name, loc.position"""
+           
          % {'dblink':dblink}, {'table_name' : table_name, 'owner' : schema})
         constraint_data = rp.fetchall()
         return constraint_data
@@ -699,9 +704,11 @@ class OracleDialect(default.DefaultDialect):
         constraint_data = self._get_constraint_data(connection, table_name,
                                         schema, dblink,
                                         info_cache=kw.get('info_cache'))
+                                        
         for row in constraint_data:
             #print "ROW:" , row
-            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
+            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
+                row[0:2] + tuple([self._normalize_name(x) for x in row[2:6]])
             if cons_type == 'P':
                 pkeys.append(local_column)
         return pkeys
@@ -731,16 +738,23 @@ class OracleDialect(default.DefaultDialect):
         constraint_data = self._get_constraint_data(connection, table_name,
                                                 schema, dblink,
                                                 info_cache=kw.get('info_cache'))
-        fkeys = []
-        fks = {}
+
+        def fkey_rec():
+            return {
+                'name' : None,
+                'constrained_columns' : [],
+                'referred_schema' : None,
+                'referred_table' : None,
+                'referred_columns' : []
+            }
+
+        fkeys = util.defaultdict(fkey_rec)
+        
         for row in constraint_data:
-            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
+            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
+                    row[0:2] + tuple([self._normalize_name(x) for x in row[2:6]])
+
             if cons_type == 'R':
-                try:
-                    fk = fks[cons_name]
-                except KeyError:
-                    fk = ([], [])
-                    fks[cons_name] = fk
                 if remote_table is None:
                     # ticket 363
                     util.warn(
@@ -749,30 +763,31 @@ class OracleDialect(default.DefaultDialect):
                          "proper rights to the table?") % {'dblink':dblink})
                     continue
 
-                if resolve_synonyms:
-                    ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(remote_owner), desired_table=self._denormalize_name(remote_table))
-                    if ref_synonym:
-                        remote_table = self._normalize_name(ref_synonym)
-                        remote_owner = self._normalize_name(ref_remote_owner)
-                if local_column not in fk[0]:
-                    fk[0].append(local_column)
-                if remote_column not in fk[1]:
-                    fk[1].append(remote_column)
-        for (name, value) in fks.items():
-            if remote_table and value[1]:
-                if requested_schema is None and remote_owner is not None:
-                    default_schema = self.get_default_schema_name(connection) 
-                    if remote_owner.lower() == default_schema.lower():
-                        remote_owner = None
-                fkey_d = {
-                    'name' : name,
-                    'constrained_columns' : value[0],
-                    'referred_schema' : remote_owner,
-                    'referred_table' : remote_table,
-                    'referred_columns' : value[1]
-                }
-                fkeys.append(fkey_d)
-        return fkeys
+                rec = fkeys[cons_name]
+                rec['name'] = cons_name
+                local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']
+
+                if not rec['referred_table']:
+                    if resolve_synonyms:
+                        ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \
+                                self._resolve_synonym(
+                                    connection, 
+                                    desired_owner=self._denormalize_name(remote_owner), 
+                                    desired_table=self._denormalize_name(remote_table)
+                                )
+                        if ref_synonym:
+                            remote_table = self._normalize_name(ref_synonym)
+                            remote_owner = self._normalize_name(ref_remote_owner)
+                    
+                    rec['referred_table'] = remote_table
+                    
+                    if requested_schema is not None or self._denormalize_name(remote_owner) != schema:
+                        rec['referred_schema'] = remote_owner
+                
+                local_cols.append(local_column)
+                remote_cols.append(remote_column)
+
+        return fkeys.values()
 
     @reflection.cache
     def get_view_definition(self, connection, view_name, schema=None,
@@ -787,11 +802,12 @@ class OracleDialect(default.DefaultDialect):
         WHERE owner = :schema
         AND view_name = :view_name
         """
-        rp = connection.execute(sql.text(s),
-                                view_name=view_name, schema=schema)
+        rp = connection.execute(s,
+                                view_name=view_name, schema=schema).scalar()
         if rp:
-            view_def = rp.scalar().decode(self.encoding)
-            return view_def
+            return rp.decode(self.encoding)
+        else:
+            return None
 
     def reflecttable(self, connection, table, include_columns):
         insp = reflection.Inspector.from_engine(connection)
index f344a7138af68b2898648d99792fe06f2e976558..6e7253e9a700a653fedcd3100d69a7899c5c90a1 100644 (file)
@@ -87,7 +87,7 @@ class SchemaDropper(DDLBase):
         else:
             tables = metadata.tables.values()
         collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
-
+        
         for listener in metadata.ddl_listeners['before-drop']:
             listener('before-drop', metadata, self.connection, tables=collection)
         
index 104ba4e3c90d2efbfc85df034701923a6ffd1e76..795bd4b575322dfaf6b5c8e642f07d62ff83c236 100644 (file)
@@ -5,8 +5,7 @@ from sqlalchemy.orm.session import _sessions
 import operator
 from sqlalchemy.test import testing
 from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 import sqlalchemy as sa
 from sqlalchemy.sql import column
 from sqlalchemy.test.util import gc_collect
@@ -77,11 +76,11 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)))
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)),
             Column('col3', Integer, ForeignKey("mytable.col1")))
 
@@ -130,11 +129,11 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)))
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)),
             Column('col3', Integer, ForeignKey("mytable.col1")))
 
@@ -185,13 +184,13 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30))
             )
 
         table2 = Table("mytable2", metadata,
             Column('col1', Integer, ForeignKey('mytable.col1'),
-                   primary_key=True),
+                   primary_key=True, test_needs_autoincrement=True),
             Column('col3', String(30)),
             )
 
@@ -245,12 +244,12 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30))
             )
 
         table2 = Table("mytable2", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', String(30)),
             )
 
@@ -309,12 +308,12 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("table1", metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30))
             )
 
         table2 = Table("table2", metadata,
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(30)),
             Column('t1id', Integer, ForeignKey('table1.id'))
             )
@@ -348,7 +347,7 @@ class MemUsageTest(EnsureZeroed):
         metadata = MetaData(testing.db)
 
         table1 = Table("mytable", metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('col2', PickleType(comparator=operator.eq))
             )
         
index 97ec107546ea9f92b586627335710e9fd2a970bb..6b9760d38d7946eea9d920741632b8d2f7a769de 100644 (file)
@@ -58,7 +58,31 @@ class ReflectionTest(TestBase, ComparesTables):
             self.assert_tables_equal(addresses, reflected_addresses)
         finally:
             meta.drop_all()
-
+    
+    def test_two_foreign_keys(self):
+        meta = MetaData(testing.db)
+        t1 = Table('t1', meta, 
+                Column('id', sa.Integer, primary_key=True),
+                Column('t2id', sa.Integer, sa.ForeignKey('t2.id')),
+                Column('t3id', sa.Integer, sa.ForeignKey('t3.id'))
+        )
+        t2 = Table('t2', meta, 
+                Column('id', sa.Integer, primary_key=True)
+        )
+        t3 = Table('t3', meta, 
+                Column('id', sa.Integer, primary_key=True)
+        )
+        meta.create_all()
+        try:
+            meta2 = MetaData()
+            t1r, t2r, t3r = [Table(x, meta2, autoload=True, autoload_with=testing.db) for x in ('t1', 't2', 't3')]
+            
+            assert t1r.c.t2id.references(t2r.c.id)
+            assert t1r.c.t3id.references(t3r.c.id)
+            
+        finally:
+            meta.drop_all()
+            
     def test_include_columns(self):
         meta = MetaData(testing.db)
         foo = Table('foo', meta, *[Column(n, sa.String(30))
index 0d66915ea5d79230bf5cd45a61ec8f0c9dfe0c87..6880f1f747e4eff1e0edc82f5e6157b373ef8136 100644 (file)
@@ -1,8 +1,7 @@
 import sqlalchemy as sa
 from sqlalchemy.test import testing
 from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relation, create_session
 from test.orm import _base
 
@@ -11,13 +10,13 @@ class O2OTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('jack', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('number', String(50)),
               Column('status', String(20)),
               Column('subroom', String(5)))
 
         Table('port', metadata,
-              Column('id', Integer, primary_key=True),
+              Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('name', String(30)),
               Column('description', String(100)),
               Column('jack_id', Integer, ForeignKey("jack.id")))
index 45bd2b9992029c0b7376520e1c2051fb0739f19e..bb15641525e5c27ce615b6f85303ae8d260abe3d 100644 (file)
@@ -471,10 +471,16 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
         sess = create_session()
 
         self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement, 
-            "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1")
+            "SELECT users.id AS users_id, users.name AS users_name FROM users, "
+            "(SELECT users.id AS id, users.name AS name FROM users) AS anon_1",
+            dialect=default.DefaultDialect()
+            )
 
         self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement, 
-            "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users")
+            "SELECT users.id AS users_id, users.name AS users_name, EXISTS "
+            "(SELECT 1 FROM addresses) AS anon_1 FROM users",
+            dialect=default.DefaultDialect()
+            )
 
         # a little tedious here, adding labels to work around Query's auto-labelling.
         # also correlate needed explicitly.  hmmm.....
@@ -793,7 +799,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL):
         
         s = create_session()
         
-        oracle_as = "AS " if not testing.against('oracle') else ""
+        oracle_as = not testing.against('oracle') and "AS " or ""
         
         self.assert_compile(
             s.query(User).options(eagerload(User.addresses)).from_self().statement,
@@ -2510,12 +2516,14 @@ class SelfReferentialTest(_base.MappedTest):
         sess = create_session()
         eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
         eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
-        eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
+        eq_(sess.query(Node).filter(~Node.children.any()).order_by(Node.id).all(), 
+                [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
 
     def test_has(self):
         sess = create_session()
     
-        eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+        eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).order_by(Node.id).all(), 
+            [Node(data='n121'),Node(data='n122'),Node(data='n123')])
         eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
         eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])