]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refined "local_remote_pairs" a bit to account for the same columns repeated multiple...
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Apr 2008 20:21:03 +0000 (20:21 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Apr 2008 20:21:03 +0000 (20:21 +0000)
CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
test/orm/relationships.py

diff --git a/CHANGES b/CHANGES
index 750410db26877f79c5b0094ef7b34622590bbe0a..897b84188a62ec04dca9271c80ad461aab43999f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -4,6 +4,12 @@ CHANGES
 
 0.4.6
 =====
+- orm
+    - Fix to the recent relation() refactoring which fixes
+      exotic relations which join between local and remote table
+      multiple times, with a common column shared between the 
+      joins.
+      
 - sql
     - Fixed bug with union() when applied to non-Table connected
       select statements
index a171923fe948637a1acfdb6a4c1d721603cc8498..f18a6bddecbc4b5bee37fa2f55bedabcda675a88 100644 (file)
@@ -572,7 +572,7 @@ class PropertyLoader(StrategizedProperty):
         self.synchronize_pairs = eq_pairs
         
         if self.secondaryjoin:
-            sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys)
+            sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly)
             sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
             
             if not sq_pairs:
@@ -590,24 +590,23 @@ class PropertyLoader(StrategizedProperty):
         else:
             self.secondary_synchronize_pairs = None
     
-    def local_remote_pairs(self):
-        return zip(self.local_side, self.remote_side)
-    local_remote_pairs = property(local_remote_pairs)
-    
     def __determine_remote_side(self):
         if self.remote_side:
             if self.direction is MANYTOONE:
                 eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True)
             else:
                 eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True)
-
-            if self.secondaryjoin:
-                sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True)
-                sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
-                eq_pairs += sq_pairs
         else:
-            eq_pairs = zip(self._opposite_side, self.foreign_keys)
-
+            if self.viewonly:
+                eq_pairs = self.synchronize_pairs
+            else:
+                eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True)
+                if self.secondaryjoin:
+                    sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True)
+                    eq_pairs += sq_pairs
+                eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+        
+        self.local_remote_pairs = eq_pairs
         if self.direction is MANYTOONE:
             self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
         else:
index e4c85b5d712fcf9a68b60ab2e0495909851ca583..3bbb380d134ab1c884c4068a0b26a70a07cd3458 100644 (file)
@@ -354,13 +354,13 @@ class LazyLoader(AbstractRelationLoader):
         equated_columns = {}
 
         secondaryjoin = prop.secondaryjoin
-        equated = dict(prop.local_remote_pairs)
+        local = prop.local_side
         
         def should_bind(targetcol, othercol):
             if reverse_direction and not secondaryjoin:
-                return othercol in equated
+                return othercol in local
             else:
-                return targetcol in equated
+                return targetcol in local
 
         def visit_binary(binary):
             leftcol = binary.left
index 89fd86f27578e8cacb060be2272097243bb5d9f1..31d8ab49d6d3172e3be9661ce6395e017208f970 100644 (file)
@@ -863,6 +863,7 @@ class ViewOnlyTest2(ORMTest):
         assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id])
 
 class ViewOnlyTest3(ORMTest):
+    """test relating on a join that has no equated columns"""
     def define_tables(self, metadata):
         global foos, bars
         foos = Table('foos', metadata, Column('id', Integer, primary_key=True))
@@ -893,6 +894,88 @@ class ViewOnlyTest3(ORMTest):
         self.assertEquals(sess.query(Foo).filter_by(id=4).one(), Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)]))
         self.assertEquals(sess.query(Foo).filter_by(id=9).one(), Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)]))
 
+class ViewOnlyTest4(ORMTest):
+    """test relating on a join that contains the same 'remote' column twice"""
+    def define_tables(self, metadata):
+        global foos, bars
+        foos = Table('foos', metadata, Column('id', Integer, primary_key=True), 
+            Column('bid1', Integer,ForeignKey('bars.id')),
+            Column('bid2', Integer,ForeignKey('bars.id')))
+            
+        bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)))
+        
+    def test_relation_on_or(self):
+        class Foo(fixtures.Base):
+            pass
+        class Bar(fixtures.Base):
+            pass
+
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=or_(bars.c.id==foos.c.bid1, bars.c.id==foos.c.bid2), uselist=True, viewonly=True)
+        })
+
+        mapper(Bar, bars)
+        sess = create_session()
+        b1 = Bar(id=1, data='b1')
+        b2 = Bar(id=2, data='b2')
+        b3 = Bar(id=3, data='b3')
+        f1 = Foo(bid1=1, bid2=2)
+        f2 = Foo(bid1=3, bid2=None)
+        sess.save(b1)
+        sess.save(b2)
+        sess.save(b3)
+        sess.save(f1)
+        sess.save(f2)
+        sess.flush()
+        
+        sess.clear()
+        self.assertEquals(sess.query(Foo).filter_by(id=f1.id).one(), Foo(bars=[Bar(data='b1'), Bar(data='b2')]))
+        self.assertEquals(sess.query(Foo).filter_by(id=f2.id).one(), Foo(bars=[Bar(data='b3')]))
+
+class ViewOnlyTest5(ORMTest):
+    """test relating on a join that contains the same 'local' column twice"""
+    def define_tables(self, metadata):
+        global foos, bars
+        foos = Table('foos', metadata, Column('id', Integer, primary_key=True), 
+            Column('data', String(50))
+            )
+
+        bars = Table('bars', metadata, Column('id', Integer, primary_key=True), 
+                Column('fid1', Integer, ForeignKey('foos.id')),
+                Column('fid2', Integer, ForeignKey('foos.id')),
+                Column('data', String(50))
+            )
+
+    def test_relation_on_or(self):
+        class Foo(fixtures.Base):
+            pass
+        class Bar(fixtures.Base):
+            pass
+
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=or_(bars.c.fid1==foos.c.id, bars.c.fid2==foos.c.id), viewonly=True)
+        })
+
+        mapper(Bar, bars)
+        sess = create_session()
+        f1 = Foo(id=1, data='f1')
+        f2 = Foo(id=2, data='f2')
+        b1 = Bar(fid1=1, data='b1')
+        b2 = Bar(fid2=1, data='b2')
+        b3 = Bar(fid1=2, data='b3')
+        b4 = Bar(fid1=1, fid2=2, data='b4')
+        sess.save(b1)
+        sess.save(b2)
+        sess.save(b3)
+        sess.save(b4)
+        sess.save(f1)
+        sess.save(f2)
+        sess.flush()
+
+        sess.clear()
+        self.assertEquals(sess.query(Foo).filter_by(id=f1.id).one(), Foo(bars=[Bar(data='b1'), Bar(data='b2'), Bar(data='b4')]))
+        self.assertEquals(sess.query(Foo).filter_by(id=f2.id).one(), Foo(bars=[Bar(data='b3'), Bar(data='b4')]))
+    
 class InvalidRelationEscalationTest(ORMTest):
     def define_tables(self, metadata):
         global foos, bars, Foo, Bar