]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed bugs in determining proper sync clauses from custom inherit
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 30 Aug 2007 15:55:41 +0000 (15:55 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 30 Aug 2007 15:55:41 +0000 (15:55 +0000)
  conditions [ticket:769]

CHANGES
lib/sqlalchemy/orm/sync.py
test/orm/inheritance/basic.py

diff --git a/CHANGES b/CHANGES
index 239ffaa7b51fea956219862265059349791f50c7..81a00cd314d03510b624701ee0fd6f9ce2cacd36 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -5,6 +5,9 @@ CHANGES
 0.4.0beta5
 ----------
 
+- fixed bugs in determining proper sync clauses from custom inherit
+  conditions [ticket:769]
+  
 - Extended 'engine_from_config' coercion for QueuePool size / overflow.
   [ticket:763]
 
index 795d76b1af363279be546f3363362bfef5ff59e5..d858428611a02a78e397224e17c9d1f292c90fb9 100644 (file)
@@ -9,7 +9,7 @@ attributes between two objects in a manner corresponding to a SQL
 clause that compares column values.
 """
 
-from sqlalchemy import schema, exceptions
+from sqlalchemy import schema, exceptions, util
 from sqlalchemy.sql import visitors, operators
 from sqlalchemy import logging
 from sqlalchemy.orm import util as mapperutil
@@ -53,10 +53,10 @@ class ClauseSynchronizer(object):
                 if binary.left.table == binary.right.table:
                     raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync")
 
-                if binary.left in [f.column for f in binary.right.foreign_keys]:
+                if binary.left in util.Set([f.column for f in binary.right.foreign_keys]):
                     dest_column = binary.right
                     source_column = binary.left
-                elif binary.right in [f.column for f in binary.left.foreign_keys]:
+                elif binary.right in util.Set([f.column for f in binary.left.foreign_keys]):
                     dest_column = binary.left
                     source_column = binary.right
             else:
index ae709b7c5fb141d5019969168fb56a7d9d39a962..870f3e86922cee932ceb42ad5ea63f4110d75b23 100644 (file)
@@ -405,6 +405,73 @@ class DistinctPKTest(ORMTest):
             assert alice1.name == alice2.name == 'alice'
             assert bob.name == 'bob'
 
+class SyncCompileTest(ORMTest):
+    """test that syncrules compile properly on custom inherit conds"""
+    def define_tables(self, metadata):
+        global _a_table, _b_table, _c_table
+        
+        _a_table = Table('a', metadata,
+           Column('id', Integer, primary_key=True),
+           Column('data1', String)
+        )
+
+        _b_table = Table('b', metadata,
+           Column('a_id', Integer, ForeignKey('a.id'), primary_key=True),
+           Column('data2', String)
+        )
+
+        _c_table = Table('c', metadata,
+        #   Column('a_id', Integer, ForeignKey('b.a_id'), primary_key=True), #works
+           Column('b_a_id', Integer, ForeignKey('b.a_id'), primary_key=True),
+           Column('data3', String)
+        )
+    
+    def test_joins(self):
+        for j1 in (None, _b_table.c.a_id==_a_table.c.id, _a_table.c.id==_b_table.c.a_id):
+            for j2 in (None, _b_table.c.a_id==_c_table.c.b_a_id, _c_table.c.b_a_id==_b_table.c.a_id):
+                self._do_test(j1, j2)
+                for t in _a_table.metadata.table_iterator(reverse=True):
+                    t.delete().execute().close()
+                
+    def _do_test(self, j1, j2):
+        class A(object):
+           def __init__(self, **kwargs):
+               for key, value in kwargs.items():
+                    setattr(self, key, value)
+
+        class B(A):
+            pass
+
+        class C(B):
+            pass
+        
+        mapper(A, _a_table)
+        mapper(B, _b_table, inherits=A, 
+               inherit_condition=j1
+               )
+        mapper(C, _c_table, inherits=B, 
+               inherit_condition=j2
+               )
+
+        session = create_session()
+
+        a = A(data1='a1')
+        session.save(a)
+
+        b = B(data1='b1', data2='b2')
+        session.save(b)
+
+        c = C(data1='c1', data2='c2', data3='c3')
+        session.save(c)
+
+        session.flush()
+        session.clear()
+
+        assert len(session.query(A).all()) == 3
+        assert len(session.query(B).all()) == 2
+        assert len(session.query(C).all()) == 1
+
+        
 
 if __name__ == "__main__":    
     testbase.main()