]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- lazy loader can now handle a join condition where the "bound"
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 9 Feb 2008 01:48:19 +0000 (01:48 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 9 Feb 2008 01:48:19 +0000 (01:48 +0000)
column (i.e. the one that gets the parent id sent as a bind
parameter) appears more than once in the join condition.
Specifically this allows the common task of a relation()
which contains a parent-correlated subquery, such as "select
only the most recent child item". [ticket:946]
- col_is_part_of_mappings made more strict, seems to be OK
with tests
- memusage will dump out the size list in an assertion fail

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
test/orm/lazy_relations.py
test/orm/memusage.py

diff --git a/CHANGES b/CHANGES
index 605569ead152e75a04384f71b8b743b51172c58b..90124c05a15254bbad92a281521f3acec7264304 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -138,6 +138,13 @@ CHANGES
       relation.  This affects how many arguments need to be sent
       to query.get(), among other things.  [ticket:933]
 
+    - lazy loader can now handle a join condition where the "bound"
+      column (i.e. the one that gets the parent id sent as a bind
+      parameter) appears more than once in the join condition.
+      Specifically this allows the common task of a relation()
+      which contains a parent-correlated subquery, such as "select
+      only the most recent child item". [ticket:946]
+      
     - Fixed bug in polymorphic inheritance where incorrect
       exception is raised when base polymorphic_on column does not
       correspond to any columns within the local selectable of an
index a7932bfc55a163d46d93c07d20f90d9f22472ce5..d08dd712471bdeffbb36c03c4e4550a6c4945218 100644 (file)
@@ -494,19 +494,19 @@ class PropertyLoader(StrategizedProperty):
             if vis.result:
                 raise exceptions.ArgumentError("In relationship '%s', primary and secondary join conditions must not include columns from the polymorphic 'select_table' argument as of SA release 0.3.4.  Construct join conditions using the base tables of the related mappers." % (str(self)))
 
+    def _col_is_part_of_mappings(self, column):
+        if self.secondary is None:
+            return self.parent.mapped_table.c.contains_column(column) or \
+                self.target.c.contains_column(column)
+        else:
+            return self.parent.mapped_table.c.contains_column(column) or \
+                self.target.c.contains_column(column) or \
+                self.secondary.c.contains_column(column) is not None
+        
     def _determine_fks(self):
         if self._legacy_foreignkey and not self._is_self_referential():
             self.foreign_keys = self._legacy_foreignkey
 
-        def col_is_part_of_mappings(col):
-            if self.secondary is None:
-                return self.parent.mapped_table.corresponding_column(col) is not None or \
-                    self.target.corresponding_column(col) is not None
-            else:
-                return self.parent.mapped_table.corresponding_column(col) is not None or \
-                    self.target.corresponding_column(col) is not None or \
-                    self.secondary.corresponding_column(col) is not None
-
         if self.foreign_keys:
             self._opposite_side = util.Set()
             def visit_binary(binary):
@@ -529,7 +529,7 @@ class PropertyLoader(StrategizedProperty):
                 # this check is for when the user put the "view_only" flag on and has tables that have nothing
                 # to do with the relationship's parent/child mappings in the join conditions.  we dont want cols
                 # or clauses related to those external tables dealt with.  see orm.relationships.ViewOnlyTest
-                if not col_is_part_of_mappings(binary.left) or not col_is_part_of_mappings(binary.right):
+                if not self._col_is_part_of_mappings(binary.left) or not self._col_is_part_of_mappings(binary.right):
                     return
 
                 for f in binary.left.foreign_keys:
index 3b3c86d1a6796d5e986954bebd229ea52f97156d..bdc8ab9a99ed752fa1841a1c7419bd68778c42bb 100644 (file)
@@ -265,7 +265,7 @@ NoLoader.logger = logging.class_logger(NoLoader)
 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
-        (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self)
+        (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self.parent_property)
         
         self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
 
@@ -366,6 +366,9 @@ class LazyLoader(AbstractRelationLoader):
         equated_columns = {}
 
         def should_bind(targetcol, othercol):
+            if not prop._col_is_part_of_mappings(targetcol):
+                return False
+                
             if reverse_direction and not secondaryjoin:
                 return targetcol in remote_side
             else:
@@ -381,13 +384,20 @@ class LazyLoader(AbstractRelationLoader):
             equated_columns[leftcol] = rightcol
 
             if should_bind(leftcol, rightcol):
-                binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
+                if leftcol in binds:
+                    binary.left = binds[leftcol]
+                else:
+                    binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
 
             # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
             # which can happen in rare cases (test/orm/relationships.py RelationTest2)
             if leftcol is not rightcol and should_bind(rightcol, leftcol):
-                binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
+                if rightcol in binds:
+                    binary.right = binds[rightcol]
+                else:
+                    binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
 
+                
         lazywhere = primaryjoin
         
         if not secondaryjoin or not reverse_direction:
index 4bd3e71e194d0eca0825e77d919b00d8a13aafce..55d79fd32b94f8d458445a2db31d2db4525e7fb6 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy.orm import *
 from testlib import *
 from testlib.fixtures import *
 from query import QueryTest
+import datetime
 
 class LazyTest(FixtureTest):
     keep_mappers = False
@@ -335,5 +336,60 @@ class M2OGetTest(FixtureTest):
             assert ad3.user is None
         self.assert_sql_count(testing.db, go, 1)
 
+class CorrelatedTest(ORMTest):
+    keep_mappers = False
+    keep_data = False
+    
+    def define_tables(self, meta):
+        global user_t, stuff
+        
+        user_t = Table('users', meta,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(50))
+            )
+
+        stuff = Table('stuff', meta,
+            Column('id', Integer, primary_key=True),
+            Column('date', Date),
+            Column('user_id', Integer, ForeignKey('users.id')))
+    
+    def insert_data(self):
+        user_t.insert().execute(
+            {'id':1, 'name':'user1'},
+            {'id':2, 'name':'user2'},
+            {'id':3, 'name':'user3'},
+        )
+
+        stuff.insert().execute(
+            {'id':1, 'user_id':1, 'date':datetime.date(2007, 10, 15)},
+            {'id':2, 'user_id':1, 'date':datetime.date(2007, 12, 15)},
+            {'id':3, 'user_id':1, 'date':datetime.date(2007, 11, 15)},
+            {'id':4, 'user_id':2, 'date':datetime.date(2008, 1, 15)},
+            {'id':5, 'user_id':3, 'date':datetime.date(2007, 6, 15)},
+        )        
+        
+    def test_correlated_lazyload(self):
+        class User(Base):
+            pass
+
+        class Stuff(Base):
+            pass
+            
+        mapper(Stuff, stuff)
+
+        stuff_view = select([stuff.c.id]).where(stuff.c.user_id==user_t.c.id).correlate(user_t).order_by(desc(stuff.c.date)).limit(1)
+
+        mapper(User, user_t, properties={
+            'stuff':relation(Stuff, primaryjoin=and_(user_t.c.id==stuff.c.user_id, stuff.c.id==(stuff_view.as_scalar())))
+        })
+
+        sess = create_session()
+
+        self.assertEquals(sess.query(User).all(), [
+            User(name='user1', stuff=[Stuff(date=datetime.date(2007, 12, 15), id=2)]), 
+            User(name='user2', stuff=[Stuff(id=4, date=datetime.date(2008, 1 , 15))]), 
+            User(name='user3', stuff=[Stuff(id=5, date=datetime.date(2007, 6, 15))])
+        ])
+
 if __name__ == '__main__':
     testenv.main()
index 7f8392ed392282b514183eb8b839e465307384fa..87471842fde4ee276220aadc87b6651266bc212e 100644 (file)
@@ -28,7 +28,7 @@ def profile_memory(func):
             if i < len(samples) - 1 and samples[i+1] <= x:
                 break
         else:
-            assert False
+            assert False, repr(samples)
         assert True
     return profile