]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- joins along a relation() from a mapped
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Aug 2008 20:31:14 +0000 (20:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Aug 2008 20:31:14 +0000 (20:31 +0000)
class to a mapped subclass, where the mapped
subclass is configured with single table
inheritance, will include an
IN clause which limits the subtypes of the
joined class to those requsted, within the
ON clause of the join.  This takes effect for
eager load joins as well as query.join().
Note that in some scenarios the IN clause will
appear in the WHERE clause of the query
as well since this discrimination has multiple
trigger points.

CHANGES
lib/sqlalchemy/orm/properties.py
test/orm/inheritance/single.py

diff --git a/CHANGES b/CHANGES
index 9c88030c632d42fd91958cb8404b0a1fc2a0b223..2a189ce7c9ecac27584a57794813601cfd02fccb 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -21,7 +21,20 @@ CHANGES
       list of entities.  In particular scalar subqueries
       should not "leak" their inner FROM objects out
       into the enclosing query.
-    
+
+    - joins along a relation() from a mapped
+      class to a mapped subclass, where the mapped
+      subclass is configured with single table
+      inheritance, will include an 
+      IN clause which limits the subtypes of the
+      joined class to those requsted, within the
+      ON clause of the join.  This takes effect for 
+      eager load joins as well as query.join().    
+      Note that in some scenarios the IN clause will 
+      appear in the WHERE clause of the query 
+      as well since this discrimination has multiple
+      trigger points.
+      
     - Improved the behavior of query.join()
       when joining to joined-table inheritance
       subclasses, using explicit join criteria
index f46fd722d0c4c4eeab9dbdd888301b3749eab765..3f61f3ebd24ae0358b54e0b22b0c4d996056a0e8 100644 (file)
@@ -790,6 +790,21 @@ class PropertyLoader(StrategizedProperty):
         aliased = aliased or bool(source_selectable)
 
         primaryjoin, secondaryjoin, secondary = self.primaryjoin, self.secondaryjoin, self.secondary
+        
+        # adjust the join condition for single table inheritance,
+        # in the case that the join is to a subclass
+        # this is analgous to the "_adjust_for_single_table_inheritance()"
+        # method in Query.
+        if self.mapper.single and self.mapper.inherits and self.mapper.polymorphic_on and self.mapper.polymorphic_identity is not None:
+            crit = self.mapper.polymorphic_on.in_(
+                m.polymorphic_identity
+                for m in self.mapper.polymorphic_iterator())
+            if secondaryjoin:
+                secondaryjoin = secondaryjoin & crit
+            else:
+                primaryjoin = primaryjoin & crit
+            
+
         if aliased:
             if secondary:
                 secondary = secondary.alias()
index ba2f930d0af99b745e2bc0b1914f8d17ded7e53a..65e8a2c79b811bae98bf11d54296f18ddcad587e 100644 (file)
@@ -136,6 +136,122 @@ class SingleInheritanceTest(MappedTest):
         
         self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
         self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
+
+class RelationToSingleTest(MappedTest):
+    def define_tables(self, metadata):
+        Table('employees', metadata,
+            Column('employee_id', Integer, primary_key=True),
+            Column('name', String(50)),
+            Column('manager_data', String(50)),
+            Column('engineer_info', String(50)),
+            Column('type', String(20)),
+            Column('company_id', Integer, ForeignKey('companies.company_id'))
+        )
+        
+        Table('companies', metadata,
+            Column('company_id', Integer, primary_key=True),
+            Column('name', String(50)),
+        )
+    
+    def setup_classes(self):
+        class Company(ComparableEntity):
+            pass
+            
+        class Employee(ComparableEntity):
+            pass
+        class Manager(Employee):
+            pass
+        class Engineer(Employee):
+            pass
+        class JuniorEngineer(Engineer):
+            pass
+
+    @testing.resolve_artifact_names
+    def test_relation_to_subclass(self):
+        mapper(Company, companies, properties={
+            'engineers':relation(Engineer)
+        })
+        mapper(Employee, employees, polymorphic_on=employees.c.type, properties={
+            'company':relation(Company)
+        })
+        mapper(Manager, inherits=Employee, polymorphic_identity='manager')
+        mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
+        mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
+        sess = sessionmaker()()
+        
+        c1 = Company(name='c1')
+        c2 = Company(name='c2')
+        
+        m1 = Manager(name='Tom', manager_data='data1', company=c1)
+        m2 = Manager(name='Tom2', manager_data='data2', company=c2)
+        e1 = Engineer(name='Kurt', engineer_info='knows how to hack', company=c2)
+        e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1)
+        sess.add_all([c1, c2, m1, m2, e1, e2])
+        sess.commit()
+
+        self.assertEquals(c1.engineers, [e2])
+        self.assertEquals(c2.engineers, [e1])
+        
+        sess.clear()
+        self.assertEquals(sess.query(Company).order_by(Company.name).all(), 
+            [
+                Company(name='c1', engineers=[JuniorEngineer(name='Ed')]),
+                Company(name='c2', engineers=[Engineer(name='Kurt')])
+            ]
+        )
+
+        # eager load join should limit to only "Engineer"
+        sess.clear()
+        self.assertEquals(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), 
+            [
+                Company(name='c1', engineers=[JuniorEngineer(name='Ed')]),
+                Company(name='c2', engineers=[Engineer(name='Kurt')])
+            ]
+       )
+
+        # join() to Company.engineers, Employee as the requested entity
+        sess.clear()
+        self.assertEquals(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(),
+            [
+                (Company(name='c1'), JuniorEngineer(name='Ed')),
+                (Company(name='c2'), Engineer(name='Kurt'))
+            ]
+        )
+        
+        # join() to Company.engineers, Engineer as the requested entity.
+        # this actually applies the IN criterion twice which is less than ideal.
+        sess.clear()
+        self.assertEquals(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(),
+            [
+                (Company(name='c1'), JuniorEngineer(name='Ed')),
+                (Company(name='c2'), Engineer(name='Kurt'))
+            ]
+        )
+
+        # join() to Company.engineers without any Employee/Engineer entity
+        sess.clear()
+        self.assertEquals(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(),
+            [
+                Company(name='c2')
+            ]
+        )
+
+        # this however fails as it does not limit the subtypes to just "Engineer".  
+        # with joins constructed by filter(), we seem to be following a policy where
+        # we don't try to make decisions on how to join to the target class, whereas when using join() we
+        # seem to have a lot more capabilities.
+        # we might want to document "advantages of join() vs. straight filtering", or add a large
+        # section to "inheritance" laying out all the various behaviors Query has.
+        @testing.fails_on_everything_except()
+        def go():
+            sess.clear()
+            self.assertEquals(sess.query(Company).\
+                filter(Company.company_id==Engineer.company_id).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(),
+                [
+                    Company(name='c2')
+                ]
+            )
+        go()
         
 class SingleOnJoinedTest(ORMTest):
     def define_tables(self, metadata):