]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- property.of_type() is now recognized on a single-table
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jan 2009 19:23:56 +0000 (19:23 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jan 2009 19:23:56 +0000 (19:23 +0000)
inheriting target, when used in the context of
prop.of_type(..).any()/has(), as well as
query.join(prop.of_type(...)).

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/util.py
test/orm/inheritance/single.py

diff --git a/CHANGES b/CHANGES
index ebf0f75a2b46aef44aadb2a96fad7558c9a29b99..444d90fe921b069c048d18830ac7cc89c8a7e27e 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -81,6 +81,11 @@ CHANGES
       next compile() call.  This issue occurs frequently
       when using declarative.
 
+    - property.of_type() is now recognized on a single-table
+      inheriting target, when used in the context of 
+      prop.of_type(..).any()/has(), as well as 
+      query.join(prop.of_type(...)).
+      
     - Fixed bug when using weak_instance_map=False where modified
       events would not be intercepted for a flush(). [ticket:1272]
       
index 5c0629dafb9077357351ddf56d112f6854a0985f..31661783523d55e39a9efbc7c9ac9de8fa99574c 100644 (file)
@@ -839,6 +839,19 @@ class Mapper(object):
 
         return from_obj
 
+    @property
+    def _single_table_criterion(self):
+        if self.single and \
+            self.inherits and \
+            self.polymorphic_on and \
+            self.polymorphic_identity is not None:
+            return self.polymorphic_on.in_(
+                m.polymorphic_identity
+                for m in self.polymorphic_iterator())
+        else:
+            return None
+        
+    
     @util.memoized_property
     def _with_polymorphic_mappers(self):
         if not self.with_polymorphic:
index 675b505e78d2e0c99935f7f722f3a4307717571c..343b73f4270ef78e862ac80d932768179e02fa1a 100644 (file)
@@ -386,6 +386,13 @@ class RelationProperty(StrategizedProperty):
                 to_selectable = target_mapper._with_polymorphic_selectable
                 if self.prop._is_self_referential():
                     to_selectable = to_selectable.alias()
+
+                single_crit = target_mapper._single_table_criterion
+                if single_crit:
+                    if criterion is not None:
+                        criterion = single_crit & criterion
+                    else:
+                        criterion = single_crit
             else:
                 to_selectable = None
 
@@ -393,6 +400,7 @@ class RelationProperty(StrategizedProperty):
                 source_selectable = self.__clause_element__()
             else:
                 source_selectable = None
+                
             pj, sj, source, dest, secondary, target_adapter = \
                 self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable)
 
@@ -863,8 +871,8 @@ class RelationProperty(StrategizedProperty):
     def _is_self_referential(self):
         return self.mapper.common_parent(self.parent)
 
-    def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None):
-        key = util.WeakCompositeKey(source_polymorphic, source_selectable, dest_polymorphic, dest_selectable)
+    def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None, of_type=None):
+        key = util.WeakCompositeKey(source_polymorphic, source_selectable, dest_polymorphic, dest_selectable, of_type)
         try:
             return self.__join_cache[key]
         except KeyError:
@@ -896,14 +904,15 @@ class RelationProperty(StrategizedProperty):
         # in the case that the join is to a subclass
         # this is analgous to the "_adjust_for_single_table_inheritance()"
         # method in Query.
-        if self.mapper.single and self.mapper.inherits and self.mapper.polymorphic_on and self.mapper.polymorphic_identity is not None:
-            crit = self.mapper.polymorphic_on.in_(
-                m.polymorphic_identity
-                for m in self.mapper.polymorphic_iterator())
+
+        dest_mapper = of_type or self.mapper
+        
+        single_crit = dest_mapper._single_table_criterion
+        if single_crit:
             if secondaryjoin:
-                secondaryjoin = secondaryjoin & crit
+                secondaryjoin = secondaryjoin & single_crit
             else:
-                primaryjoin = primaryjoin & crit
+                primaryjoin = primaryjoin & single_crit
             
 
         if aliased:
index e733cc0242956f7976dd7ebb3588cdaeffc73281..7a5721761e0b47929cd6c7c0825b9377fdd7ee48 100644 (file)
@@ -887,7 +887,8 @@ class Query(object):
                         self.__currenttables.add(prop.secondary)
                     self.__currenttables.add(prop.table)
 
-                    right_entity = prop.mapper
+                    if not right_entity:
+                        right_entity = prop.mapper
 
             if alias_criterion:
                 right_adapter = ORMAdapter(right_entity,
@@ -1618,14 +1619,12 @@ class Query(object):
 
         """
         for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems():
-            if mapper.single and mapper.inherits and mapper.polymorphic_on and mapper.polymorphic_identity is not None:
-                crit = mapper.polymorphic_on.in_(
-                    m.polymorphic_identity
-                    for m in mapper.polymorphic_iterator())
+            single_crit = mapper._single_table_criterion
+            if single_crit:
                 if adapter:
-                    crit = adapter.traverse(crit)
-                crit = self._adapt_clause(crit, False, False)
-                context.whereclause = sql.and_(context.whereclause, crit)
+                    single_crit = adapter.traverse(single_crit)
+                single_crit = self._adapt_clause(single_crit, False, False)
+                context.whereclause = sql.and_(context.whereclause, single_crit)
 
     def __str__(self):
         return str(self._compile_context().statement)
index 411c827c6759de4d21c70da50577811195863f0b..4f99586da4e0cdde67bae0f69a73059e99643cf6 100644 (file)
@@ -386,7 +386,7 @@ class _ORMJoin(expression.Join):
                 prop = None
 
             if prop:
-                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True)
+                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True, of_type=right_mapper)
 
                 if sj:
                     left = sql.join(left, secondary, pj, isouter)
index 65e8a2c79b811bae98bf11d54296f18ddcad587e..19c9a83f25f9e8c29318179816b2012227cbe441 100644 (file)
@@ -5,17 +5,22 @@ from testlib import *
 from testlib.fixtures import Base
 from orm._base import MappedTest, ComparableEntity
 
+
 class SingleInheritanceTest(MappedTest):
     def define_tables(self, metadata):
-        global employees_table
-        employees_table = Table('employees', metadata,
+        Table('employees', metadata,
             Column('employee_id', Integer, primary_key=True),
             Column('name', String(50)),
             Column('manager_data', String(50)),
             Column('engineer_info', String(50)),
-            Column('type', String(20))
+            Column('type', String(20)))
+
+        Table('reports', metadata,
+              Column('report_id', Integer, primary_key=True),
+              Column('employee_id', ForeignKey('employees.employee_id')),
+              Column('name', String(50)),
         )
-    
+
     def setup_classes(self):
         class Employee(ComparableEntity):
             pass
@@ -28,7 +33,7 @@ class SingleInheritanceTest(MappedTest):
 
     @testing.resolve_artifact_names
     def setup_mappers(self):
-        mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
+        mapper(Employee, employees, polymorphic_on=employees.c.type)
         mapper(Manager, inherits=Employee, polymorphic_identity='manager')
         mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
         mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
@@ -116,7 +121,7 @@ class SingleInheritanceTest(MappedTest):
         sess.flush()
         
         self.assertEquals(
-            sess.query(Manager).select_from(employees_table.select().limit(10)).all(), 
+            sess.query(Manager).select_from(employees.select().limit(10)).all(), 
             [m1, m2]
         )
         
@@ -137,6 +142,42 @@ class SingleInheritanceTest(MappedTest):
         self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
         self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
 
+    @testing.resolve_artifact_names
+    def test_type_filtering(self):
+        class Report(ComparableEntity): pass
+
+        mapper(Report, reports, properties={
+            'employee': relation(Employee, backref='reports')})
+        sess = create_session()
+
+        m1 = Manager(name='Tom', manager_data='data1')
+        r1 = Report(employee=m1)
+        sess.add_all([m1, r1])
+        sess.flush()
+        rq = sess.query(Report)
+
+        assert len(rq.filter(Report.employee.of_type(Manager).has()).all()) == 1
+        assert len(rq.filter(Report.employee.of_type(Engineer).has()).all()) == 0
+
+    @testing.resolve_artifact_names
+    def test_type_joins(self):
+        class Report(ComparableEntity): pass
+
+        mapper(Report, reports, properties={
+            'employee': relation(Employee, backref='reports')})
+        sess = create_session()
+
+        m1 = Manager(name='Tom', manager_data='data1')
+        r1 = Report(employee=m1)
+        sess.add_all([m1, r1])
+        sess.flush()
+
+        rq = sess.query(Report)
+
+        assert len(rq.join(Report.employee.of_type(Manager)).all()) == 1
+        assert len(rq.join(Report.employee.of_type(Engineer)).all()) == 0
+
+
 class RelationToSingleTest(MappedTest):
     def define_tables(self, metadata):
         Table('employees', metadata,
@@ -166,6 +207,42 @@ class RelationToSingleTest(MappedTest):
         class JuniorEngineer(Engineer):
             pass
 
+    @testing.resolve_artifact_names
+    def test_of_type(self):
+        mapper(Company, companies, properties={
+            'employees':relation(Employee, backref='company')
+        })
+        mapper(Employee, employees, polymorphic_on=employees.c.type)
+        mapper(Manager, inherits=Employee, polymorphic_identity='manager')
+        mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
+        mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
+        sess = sessionmaker()()
+        
+        c1 = Company(name='c1')
+        c2 = Company(name='c2')
+        
+        m1 = Manager(name='Tom', manager_data='data1', company=c1)
+        m2 = Manager(name='Tom2', manager_data='data2', company=c2)
+        e1 = Engineer(name='Kurt', engineer_info='knows how to hack', company=c2)
+        e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed', company=c1)
+        sess.add_all([c1, c2, m1, m2, e1, e2])
+        sess.commit()
+        sess.clear()
+        self.assertEquals(
+            sess.query(Company).filter(Company.employees.of_type(JuniorEngineer).any()).all(),
+            [
+                Company(name='c1'),
+            ]
+        )
+
+        self.assertEquals(
+            sess.query(Company).join(Company.employees.of_type(JuniorEngineer)).all(),
+            [
+                Company(name='c1'),
+            ]
+        )
+
+
     @testing.resolve_artifact_names
     def test_relation_to_subclass(self):
         mapper(Company, companies, properties={