]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added 'remote_side' functionality to lazy clause generation
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Dec 2006 23:36:59 +0000 (23:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Dec 2006 23:36:59 +0000 (23:36 +0000)
lib/sqlalchemy/orm/strategies.py
test/orm/cycles.py

index 28fbdcfd31165d8e245eeb8b0f87140c4875d263..a3a92841de07a7e5d0ab03fa9501346493ab35f6 100644 (file)
@@ -129,7 +129,8 @@ class AbstractRelationLoader(LoaderStrategy):
         self.cascade = self.parent_property.cascade
         self.attributeext = self.parent_property.attributeext
         self.order_by = self.parent_property.order_by
-
+        self.remote_side = self.parent_property.remote_side
+        
     def _init_instance_attribute(self, instance, callable_=None):
         return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, self.uselist, cascade=self.cascade,  trackparent=True, callable_=callable_)
         
@@ -151,7 +152,7 @@ NoLoader.logger = logging.class_logger(NoLoader)
 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
-        (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey)
+        (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey, self.remote_side)
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere)
@@ -235,12 +236,15 @@ class LazyLoader(AbstractRelationLoader):
                 # to load data into it.
                 sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
 
-    def _create_lazy_clause(self, table, primaryjoin, secondaryjoin, foreignkey):
+    def _create_lazy_clause(self, table, primaryjoin, secondaryjoin, foreignkey, remote_side):
         binds = {}
         reverse = {}
         def column_in_table(table, column):
             return table.corresponding_column(column, raiseerr=False, keys_ok=False) is not None
 
+        if remote_side is None or len(remote_side) == 0:
+            remote_side = foreignkey
+            
         def find_column_in_expr(expr):
             if not isinstance(expr, sql.ColumnElement):
                 return None
@@ -259,13 +263,13 @@ class LazyLoader(AbstractRelationLoader):
             if leftcol is None or rightcol is None:
                 return
             circular = leftcol.table is rightcol.table
-            if ((not circular and column_in_table(table, leftcol)) or (circular and rightcol in foreignkey)):
+            if ((not circular and column_in_table(table, leftcol)) or (circular and rightcol in remote_side)):
                 col = leftcol
                 binary.left = binds.setdefault(leftcol,
                         sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type))
                 reverse[rightcol] = binds[col]
 
-            if (leftcol is not rightcol) and ((not circular and column_in_table(table, rightcol)) or (circular and leftcol in foreignkey)):
+            if (leftcol is not rightcol) and ((not circular and column_in_table(table, rightcol)) or (circular and leftcol in remote_side)):
                 col = rightcol
                 binary.right = binds.setdefault(rightcol,
                         sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type))
index b5ad6ce332ce27803e793d48ff9f42e8899b9124..5ec04264891c8ffdc95c4ed3becf2f17a93c1d0a 100644 (file)
@@ -147,6 +147,22 @@ class SelfReferentialNoPKTest(AssertMixin):
         s.clear()
         t = s.query(TT).get_by(id=t1.id)
         assert t.children[0].parent_uuid == t1.uuid
+    def testlazyclause(self):
+        class TT(object):
+            def __init__(self):
+                self.uuid = hex(id(self))
+        mapper(TT, table, properties={'children':relation(TT, remote_side=[table.c.parent_uuid], backref=backref('parent', remote_side=[table.c.uuid]))})
+        s = create_session()
+        t1 = TT()
+        t2 = TT()
+        t1.children.append(t2)
+        s.save(t1)
+        s.flush()
+        s.clear()
+
+        t = s.query(TT).get_by(id=t2.id)
+        assert t.uuid == t2.uuid
+        assert t.parent.uuid == t1.uuid
         
 class InheritTestOne(AssertMixin):
     def setUpAll(self):