]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- simplified __create_lazy_clause to make better usage of the new local/remote pairs...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Apr 2008 18:23:59 +0000 (18:23 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Apr 2008 18:23:59 +0000 (18:23 +0000)
- corrected the direction of local/remote pairs for manytoone
- added new tests which demonstrate lazyloading working when the bind param is embedded inside of a SQL function,
when _local_remote_pairs argument is used; fixes the viewonly version of [ticket:610]
- removed needless kwargs check from visitors.traverse

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/visitors.py
test/orm/relationships.py

diff --git a/CHANGES b/CHANGES
index 37861592d781c9eabc93001af644e90868fc85b1..3bcfd8915902a3eb54457b216f1bdca00029d351 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -17,8 +17,11 @@ CHANGES
       across functions, etc., _local_remote_pairs=[tuples].  
       This complements a complex primaryjoin condition allowing 
       you to provide the individual column pairs which comprise
-      the relation's local and remote sides.
-      
+      the relation's local and remote sides.  Also improved 
+      lazy load SQL generation to handle placing bind params 
+      inside of functions and other expressions.
+      (partial progress towards [ticket:610])
+    
     - removed ancient assertion that mapped selectables require
       "alias names" - the mapper creates its own alias now if
       none is present.  Though in this case you need to use 
index 140fffd892183c7021d12eb86bf6ad27198f6fe1..55f6c9875224d574b0cc9552124986ff8e504bc6 100644 (file)
@@ -621,9 +621,10 @@ class PropertyLoader(StrategizedProperty):
         
         if self.direction is MANYTOONE:
             self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
+            self.local_remote_pairs = [(r, l) for l, r in eq_pairs]
         else:
             self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
-        self.local_remote_pairs = zip(self.local_side, self.remote_side)
+            self.local_remote_pairs = eq_pairs
         
         if self.direction is ONETOMANY:
             for l in self.local_side:
@@ -651,12 +652,15 @@ class PropertyLoader(StrategizedProperty):
                         self.direction = ONETOMANY
                     else:
                         self.direction = MANYTOONE
-
+            elif self._arg_local_remote_pairs:
+                remote = util.Set([r for l, r in self._arg_local_remote_pairs])
+                if self.foreign_keys.intersection(remote):
+                    self.direction = ONETOMANY
+                else:
+                    self.direction = MANYTOONE
             elif self.remote_side:
-                for f in self.foreign_keys:
-                    if f in self.remote_side:
-                        self.direction = ONETOMANY
-                        return
+                if self.foreign_keys.intersection(self.remote_side):
+                    self.direction = ONETOMANY
                 else:
                     self.direction = MANYTOONE
             else:
index 3bbb380d134ab1c884c4068a0b26a70a07cd3458..e758ac08b102d57c0283bdc6fe357bdc62ab63f9 100644 (file)
@@ -351,41 +351,39 @@ class LazyLoader(AbstractRelationLoader):
 
     def __create_lazy_clause(cls, prop, reverse_direction=False):
         binds = {}
+        lookup = {}
         equated_columns = {}
 
-        secondaryjoin = prop.secondaryjoin
-        local = prop.local_side
-        
-        def should_bind(targetcol, othercol):
-            if reverse_direction and not secondaryjoin:
-                return othercol in local
-            else:
-                return targetcol in local
-
-        def visit_binary(binary):
-            leftcol = binary.left
-            rightcol = binary.right
-
-            equated_columns[rightcol] = leftcol
-            equated_columns[leftcol] = rightcol
-
-            if should_bind(leftcol, rightcol):
-                if leftcol not in binds:
-                    binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
-                binary.left = binds[leftcol]
-            elif should_bind(rightcol, leftcol):
-                if rightcol not in binds:
-                    binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
-                binary.right = binds[rightcol]
-
+        if reverse_direction and not prop.secondaryjoin:
+            for l, r in prop.local_remote_pairs:
+                _list = lookup.setdefault(r, [])
+                _list.append((r, l))
+                equated_columns[l] = r
+        else:
+            for l, r in prop.local_remote_pairs:
+                _list = lookup.setdefault(l, [])
+                _list.append((l, r))
+                equated_columns[r] = l
+                
+        def col_to_bind(col):
+            if col in lookup:
+                for tobind, equated in lookup[col]:
+                    if equated in binds:
+                        return None
+                if col not in binds:
+                    binds[col] = sql.bindparam(None, None, type_=col.type)
+                return binds[col]
+            return None
+                    
         lazywhere = prop.primaryjoin
         
         if not prop.secondaryjoin or not reverse_direction:
-            lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary)
+            lazywhere = visitors.traverse(lazywhere, before_clone=col_to_bind, clone=True) 
         
         if prop.secondaryjoin is not None:
+            secondaryjoin = prop.secondaryjoin
             if reverse_direction:
-                secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
+                secondaryjoin = visitors.traverse(secondaryjoin, before_clone=col_to_bind, clone=True)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
     
         bind_to_col = dict([(binds[col].key, col) for col in binds])
index 792391929f3db7afccb774b2395c8c2bcaa5c025..9888a228a39a55a50729f12e0708d1cfc4041ef5 100644 (file)
@@ -177,7 +177,6 @@ def traverse(clause, **kwargs):
         __traverse_options__ = kwargs.pop('traverse_options', {})
     vis = Vis()
     for key in kwargs:
-        if key.startswith('visit_'):
-            setattr(vis, key, kwargs[key])
+        setattr(vis, key, kwargs[key])
     return vis.traverse(clause, clone=clone)
 
index 6e9a0f96e1d1188990fc180769b313db98e54fca..40773f8359599013eaa697bf45996a146d257cea 100644 (file)
@@ -1066,7 +1066,7 @@ class ExplicitLocalRemoteTest(ORMTest):
             Column('t1id', String(50)),
         )
 
-    def test_onetomany(self):
+    def test_onetomany_funcfk(self):
         class T1(fixtures.Base):
             pass
         class T2(fixtures.Base):
@@ -1084,17 +1084,21 @@ class ExplicitLocalRemoteTest(ORMTest):
         
         sess = create_session()
         a1 = T1(id='number1', data='a1')
+        a2 = T1(id='number2', data='a2')
         b1 = T2(data='b1', t1id='NuMbEr1')
         b2 = T2(data='b2', t1id='Number1')
+        b3 = T2(data='b3', t1id='Number2')
         sess.save(a1)
+        sess.save(a2)
         sess.save(b1)
         sess.save(b2)
+        sess.save(b3)
         sess.flush()
         sess.clear()
         
         self.assertEquals(sess.query(T1).first(), T1(id='number1', data='a1', t2s=[T2(data='b1', t1id='NuMbEr1'), T2(data='b2', t1id='Number1')]))
     
-    def test_manytoone(self):
+    def test_manytoone_funcfk(self):
         class T1(fixtures.Base):
             pass
         class T2(fixtures.Base):
@@ -1103,25 +1107,95 @@ class ExplicitLocalRemoteTest(ORMTest):
         mapper(T2, t2, properties={
             't1':relation(T1, primaryjoin=t1.c.id==func.lower(t2.c.t1id),
                 _local_remote_pairs=[(t2.c.t1id, t1.c.id)],
-                foreign_keys=[t2.c.t1id]
+                foreign_keys=[t2.c.t1id],
+                uselist=True
             )
         })
         sess = create_session()
         a1 = T1(id='number1', data='a1')
+        a2 = T1(id='number2', data='a2')
         b1 = T2(data='b1', t1id='NuMbEr1')
         b2 = T2(data='b2', t1id='Number1')
+        b3 = T2(data='b3', t1id='Number2')
         sess.save(a1)
+        sess.save(a2)
         sess.save(b1)
         sess.save(b2)
+        sess.save(b3)
         sess.flush()
         sess.clear()
-        self.assertEquals(sess.query(T2).all(), 
+        self.assertEquals(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), 
             [
-                T2(data='b1', t1=T1(id='number1', data='a1')),
-                T2(data='b2', t1=T1(id='number1', data='a1'))
+                T2(data='b1', t1=[T1(id='number1', data='a1')]),
+                T2(data='b2', t1=[T1(id='number1', data='a1')])
             ]
         )
     
+    def test_onetomany_func_referent(self):
+        class T1(fixtures.Base):
+            pass
+        class T2(fixtures.Base):
+            pass
+        
+        mapper(T1, t1, properties={
+            't2s':relation(T2, primaryjoin=func.lower(t1.c.id)==t2.c.t1id, 
+                _local_remote_pairs=[(t1.c.id, t2.c.t1id)],
+                foreign_keys=[t2.c.t1id]
+            )
+        })
+        mapper(T2, t2)
+        
+        sess = create_session()
+        a1 = T1(id='NuMbeR1', data='a1')
+        a2 = T1(id='NuMbeR2', data='a2')
+        b1 = T2(data='b1', t1id='number1')
+        b2 = T2(data='b2', t1id='number1')
+        b3 = T2(data='b2', t1id='number2')
+        sess.save(a1)
+        sess.save(a2)
+        sess.save(b1)
+        sess.save(b2)
+        sess.save(b3)
+        sess.flush()
+        sess.clear()
+        
+        self.assertEquals(sess.query(T1).first(), T1(id='NuMbeR1', data='a1', t2s=[T2(data='b1', t1id='number1'), T2(data='b2', t1id='number1')]))
+
+    def test_manytoone_func_referent(self):
+        class T1(fixtures.Base):
+            pass
+        class T2(fixtures.Base):
+            pass
+
+        mapper(T1, t1)
+        mapper(T2, t2, properties={
+            't1':relation(T1, primaryjoin=func.lower(t1.c.id)==t2.c.t1id,
+                _local_remote_pairs=[(t2.c.t1id, t1.c.id)],
+                foreign_keys=[t2.c.t1id], uselist=True
+            )
+        })
+
+        sess = create_session()
+        a1 = T1(id='NuMbeR1', data='a1')
+        a2 = T1(id='NuMbeR2', data='a2')
+        b1 = T2(data='b1', t1id='number1')
+        b2 = T2(data='b2', t1id='number1')
+        b3 = T2(data='b3', t1id='number2')
+        sess.save(a1)
+        sess.save(a2)
+        sess.save(b1)
+        sess.save(b2)
+        sess.save(b3)
+        sess.flush()
+        sess.clear()
+
+        self.assertEquals(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), 
+            [
+                T2(data='b1', t1=[T1(id='NuMbeR1', data='a1')]),
+                T2(data='b2', t1=[T1(id='NuMbeR1', data='a1')])
+            ]
+        )
+        
     def test_escalation(self):
         class T1(fixtures.Base):
             pass