]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bugs in Query regarding simultaneous selection of
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 13 Feb 2009 17:14:05 +0000 (17:14 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 13 Feb 2009 17:14:05 +0000 (17:14 +0000)
multiple joined-table inheritance entities with common base
classes, previously the adaption applied to "e2" on
"e1 JOIN e2" would be partially applied to "e1".  Additionally,
comparisons on relations (i.e. Entity2.related==e2)
were not getting adapted correctly.

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
test/orm/inheritance/query.py

diff --git a/CHANGES b/CHANGES
index be009b4efe1d4fd048d0681ec0d8f10add90410f..605bd5501440a47c728a0d79890b84d6ffac65dc 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -32,6 +32,13 @@ CHANGES
       so that the "listen_for_events.py" example works again.
       [ticket:1314]
       
+    - Fixed bugs in Query regarding simultaneous selection of 
+      multiple joined-table inheritance entities with common base 
+      classes, previously the adaption applied to "e2" on 
+      "e1 JOIN e2" would be partially applied to "e1".  Additionally,
+      comparisons on relations (i.e. Entity2.related==e2)
+      were not getting adapted correctly.
+      
 - sql
     - Fixed missing _label attribute on Function object, others
       when used in a select() with use_labels (such as when used
index d5857d965433fab49ad63166f3c7915ec74116e7..42c3bffa3bb0aad65a8b2928936dbd6c8909b517 100644 (file)
@@ -485,11 +485,11 @@ class RelationProperty(StrategizedProperty):
                 if self.property.direction in [ONETOMANY, MANYTOMANY]:
                     return ~self._criterion_exists()
                 else:
-                    return self.property._optimized_compare(None, adapt_source=self.adapter)
+                    return _orm_annotate(self.property._optimized_compare(None, adapt_source=self.adapter))
             elif self.property.uselist:
                 raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.")
             else:
-                return self.property._optimized_compare(other, adapt_source=self.adapter)
+                return _orm_annotate(self.property._optimized_compare(other, adapt_source=self.adapter))
 
         def _criterion_exists(self, criterion=None, **kwargs):
             if getattr(self, '_of_type', None):
index b2c68faa208e9deb6f7cf8ed1464ae81d089dff9..5bef348a1ff5b92898cb768e2bc17af3a691bd2d 100644 (file)
@@ -127,6 +127,7 @@ class Query(object):
 
     def __mapper_loads_polymorphically_with(self, mapper, adapter):
         for m2 in mapper._with_polymorphic_mappers:
+            self._polymorphic_adapters[m2] = adapter
             for m in m2.iterate_to_root():
                 self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
 
@@ -145,6 +146,7 @@ class Query(object):
 
     def _reset_polymorphic_adapter(self, mapper):
         for m2 in mapper._with_polymorphic_mappers:
+            self._polymorphic_adapters.pop(m2, None)
             for m in m2.iterate_to_root():
                 self._polymorphic_adapters.pop(m.mapped_table, None)
                 self._polymorphic_adapters.pop(m.local_table, None)
@@ -1892,10 +1894,7 @@ class _MapperEntity(_QueryEntity):
 
         adapter = None
         if not self.is_aliased_class and query._polymorphic_adapters:
-            for mapper in self.mapper.iterate_to_root():
-                adapter = query._polymorphic_adapters.get(mapper.mapped_table, None)
-                if adapter:
-                    break
+            adapter = query._polymorphic_adapters.get(self.mapper, None)
 
         if not adapter and self.adapter:
             adapter = self.adapter
index fe948931b6acfb12998906a00bf1e4eedcf45ebe..07bf068b72ac62fcdddd8b2381d52de64d8b65d9 100644 (file)
@@ -9,6 +9,8 @@ from sqlalchemy.orm import *
 from sqlalchemy import exc as sa_exc
 from testlib import *
 from testlib import fixtures
+from orm import _base
+from testlib.testing import eq_
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.engine import default
 
@@ -773,7 +775,7 @@ class SelfReferentialTestJoinedToJoined(ORMTest):
         
         mapper(Engineer, engineers, inherits=Person, 
           polymorphic_identity='engineer', properties={
-          'reports_to':relation(Manager, primaryjoin=managers.c.person_id==engineers.c.reports_to_id)
+          'reports_to':relation(Manager, primaryjoin=managers.c.person_id==engineers.c.reports_to_id, backref='engineers')
         })
 
     def test_has(self):
@@ -800,6 +802,33 @@ class SelfReferentialTestJoinedToJoined(ORMTest):
         self.assertEquals(
             sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(), 
             Engineer(name='dilbert'))
+    
+    def test_relation_compare(self):
+        m1 = Manager(name='dogbert')
+        m2 = Manager(name='foo')
+        e1 = Engineer(name='dilbert', primary_language='java', reports_to=m1)
+        e2 = Engineer(name='wally', primary_language='c++', reports_to=m2)
+        e3 = Engineer(name='etc', primary_language='c++')
+        sess = create_session()
+        sess.add(m1)
+        sess.add(m2)
+        sess.add(e1)
+        sess.add(e2)
+        sess.add(e3)
+        sess.flush()
+        sess.expunge_all()
+
+        self.assertEquals(
+            sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(), 
+            []
+        )
+
+        self.assertEquals(
+            sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(), 
+            [m1]
+        )
+        
+        
         
 
 class M2MFilterTest(ORMTest):
@@ -868,6 +897,8 @@ class M2MFilterTest(ORMTest):
         self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
 
 class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL):
+    keep_mappers = True
+    
     def define_tables(self, metadata):
         Base = declarative_base(metadata=metadata)
 
@@ -895,9 +926,50 @@ class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL):
         Child1.left_child2 = relation(Child2, secondary = secondary_table,
                primaryjoin = Parent.id == secondary_table.c.right_id,
                secondaryjoin = Parent.id == secondary_table.c.left_id,
-               uselist = False,
+               uselist = False, backref="right_children"
                                )
 
+    
+    def test_query_crit(self):
+        session = create_session()
+        c11, c12, c13 = Child1(), Child1(), Child1()
+        c21, c22, c23 = Child2(), Child2(), Child2()
+        
+        c11.left_child2 = c22
+        c12.left_child2 = c22
+        c13.left_child2 = c23
+        
+        session.add_all([c11, c12, c13, c21, c22, c23])
+        session.flush()
+        
+        # test that the join to Child2 doesn't alias Child1 in the select
+        eq_(
+            set(session.query(Child1).join(Child1.left_child2)), 
+            set([c11, c12, c13])
+        )
+
+        eq_(
+            set(session.query(Child1, Child2).join(Child1.left_child2)), 
+            set([(c11, c22), (c12, c22), (c13, c23)])
+        )
+
+        # test __eq__() on property is annotating correctly
+        eq_(
+            set(session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22)),
+            set([c22])
+        )
+
+        # test the same again
+        self.assert_compile(
+            session.query(Child2).join(Child2.right_children).filter(Child1.left_child2==c22).with_labels().statement,
+            "SELECT parent.id AS parent_id, child2.id AS child2_id, parent.cls AS parent_cls FROM "
+            "secondary AS secondary_1, parent JOIN child2 ON parent.id = child2.id JOIN secondary AS secondary_2 "
+            "ON parent.id = secondary_2.left_id JOIN (SELECT parent.id AS parent_id, parent.cls AS parent_cls, "
+            "child1.id AS child1_id FROM parent JOIN child1 ON parent.id = child1.id) AS anon_1 ON "
+            "anon_1.parent_id = secondary_2.right_id WHERE anon_1.parent_id = secondary_1.right_id AND :param_1 = secondary_1.left_id",
+            dialect=default.DefaultDialect()
+        )
+
     def test_eager_join(self):
         session = create_session()