]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merged r6355 from trunk for #1543
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Sep 2009 19:50:57 +0000 (19:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Sep 2009 19:50:57 +0000 (19:50 +0000)
CHANGES
lib/sqlalchemy/orm/strategies.py
test/orm/inheritance/test_basic.py
test/orm/inheritance/test_polymorph2.py
test/orm/inheritance/test_query.py

diff --git a/CHANGES b/CHANGES
index 79d08eb90bf82eff7dc874d4742e2aa8da0fbfde..e5bb8cd7b9c4007f5ebbe6437fbacc0b226576c9 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -3,6 +3,19 @@
 =======
 CHANGES
 =======
+
+0.5.7
+=====
+- orm
+    - contains_eager() now works with the automatically 
+      generated subquery that results when you say 
+      "query(Parent).join(Parent.somejoinedsubclass)", i.e. 
+      when Parent joins to a joined-table-inheritance subclass.  
+      Previously contains_eager() would erroneously add the 
+      subclass table to the query separately producing a 
+      cartesian product.  An example is in the ticket
+      description.  [ticket:1543]
+      
 0.5.6
 =====
 - orm
index fd767c8655de1d5447f8e0f2882d43157df54df3..bdf99980cd4159bfdfb02405b48a72b398ab02e1 100644 (file)
@@ -836,7 +836,10 @@ class LoadEagerFromAliasOption(PropertyOption):
                 self.alias = prop.target.alias(self.alias)
             query._attributes[("user_defined_eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias)
         else:
-            query._attributes[("user_defined_eager_row_processor", paths[-1])] = None
+            (mapper, propname) = paths[-1][-2:]
+            prop = mapper.get_property(propname, resolve_synonyms=True)
+            adapter = query._polymorphic_adapters.get(prop.mapper, None)
+            query._attributes[("user_defined_eager_row_processor", paths[-1])] = adapter
 
 class _SingleParentValidator(interfaces.AttributeExtension):
     def __init__(self, prop):
index f93dfd5476634441d2fe8191f04b61c48ab8f4c8..acb9f369d5196064627da2f2f040d1c70ce5a6c9 100644 (file)
@@ -398,6 +398,88 @@ class EagerTargetingTest(_base.MappedTest):
         eq_(node, B(id=1, name='b1',b_data='i'))
         eq_(node.children[0], B(id=2, name='b2',b_data='l'))
         
+class EagerToSubclassTest(_base.MappedTest):
+    """Test eagerloads to subclass mappers"""
+    
+    run_setup_classes = 'once'
+    run_setup_mappers = 'once'
+    run_inserts = 'once'
+    run_deletes = None
+    
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('parent', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(10)),
+        )
+        
+        Table('base', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('type', String(10)),
+        )
+
+        Table('sub', metadata,
+            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+            Column('data', String(10)),
+            Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False)
+        )
+    
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_classes(cls):
+        class Parent(_base.ComparableEntity):
+            pass
+        
+        class Base(_base.ComparableEntity):
+            pass
+        
+        class Sub(Base):
+            pass
+    
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Parent, parent, properties={
+            'children':relation(Sub)
+        })
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='s')
+    
+    @classmethod
+    @testing.resolve_artifact_names
+    def insert_data(cls):
+        sess = create_session()
+        p1 = Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')])
+        p2 = Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+        sess.add(p1)
+        sess.add(p2)
+        sess.flush()
+        
+    @testing.resolve_artifact_names
+    def test_eagerload(self):
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(Parent).options(eagerload(Parent.children)).all(), 
+                [
+                    Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
+                    Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+    @testing.resolve_artifact_names
+    def test_contains_eager(self):
+        sess = create_session()
+        def go():
+            eq_(
+                sess.query(Parent).join(Parent.children).options(contains_eager(Parent.children)).all(), 
+                [
+                    Parent(data='p1', children=[Sub(data='s1'), Sub(data='s2'), Sub(data='s3')]),
+                    Parent(data='p2', children=[Sub(data='s4'), Sub(data='s5')])
+                ]
+            )
+        self.assert_sql_count(testing.db, go, 1)
         
 class FlushTest(_base.MappedTest):
     """test dependency sorting among inheriting mappers"""
index 51b6d4970a5a92c6c55e623b69e5bb3e39967b50..834f18fc86f9b28a989b816588e2773a371bfdea 100644 (file)
@@ -436,8 +436,8 @@ class RelationTest5(_base.MappedTest):
            Column('car_id', Integer, primary_key=True),
            Column('owner', Integer, ForeignKey('people.person_id')))
 
-    def testeagerempty(self):
-        """an easy one...test parent object with child relation to an inheriting mapper, using eager loads,
+    def test_eager_empty(self):
+        """test parent object with child relation to an inheriting mapper, using eager loads,
         works when there are no child objects present"""
         class Person(object):
             def __init__(self, **kwargs):
index 5b57e8f4575e2fe2a6ef62b70012b64904fdda31..1dd96ef1b480d681d9869671cdf2ab4ebaae4aa5 100644 (file)
@@ -472,7 +472,7 @@ def _produce_test(select_type):
                 [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])]
                 )
             self.assert_sql_count(testing.db, go, 1)
-            
+
         def test_join_to_subclass(self):
             sess = create_session()
             eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])