]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
getting inheritance to work. some complex cases may have to fail for the time being.
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 24 Mar 2010 00:23:01 +0000 (20:23 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 24 Mar 2010 00:23:01 +0000 (20:23 -0400)
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
test/orm/inheritance/test_query.py

index 5e7a2028e67f5e1224b40ed191a67145e1828e84..43b4e6d77adb5fe3411dc7b97c97315acb5042ed 100644 (file)
@@ -198,7 +198,13 @@ class Query(object):
     @_generative()
     def _adapt_all_clauses(self):
         self._disable_orm_filtering = True
-
+    
+    def _adapt_col_list(self, cols):
+        return [
+                    self._adapt_clause(expression._literal_as_text(o), True, True) 
+                    for o in cols
+                ]
+        
     def _adapt_clause(self, clause, as_filter, orm_only):
         adapters = []
         if as_filter and self._filter_aliases:
@@ -773,7 +779,6 @@ class Query(object):
 
         return self.filter(sql.and_(*clauses))
 
-
     @_generative(_no_statement_condition, _no_limit_offset)
     @util.accepts_a_list_as_starargs(list_deprecation='deprecated')
     def order_by(self, *criterion):
@@ -782,7 +787,7 @@ class Query(object):
         if len(criterion) == 1 and criterion[0] is None:
             self._order_by = None
         else:
-            criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion]
+            criterion = self._adapt_col_list(criterion)
 
             if self._order_by is False or self._order_by is None:
                 self._order_by = criterion
@@ -796,7 +801,7 @@ class Query(object):
 
         criterion = list(chain(*[_orm_columns(c) for c in criterion]))
 
-        criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion]
+        criterion = self._adapt_col_list(criterion)
 
         if self._group_by is False:
             self._group_by = criterion
@@ -2147,7 +2152,7 @@ class _MapperEntity(_QueryEntity):
         self._with_polymorphic = with_polymorphic
         self._polymorphic_discriminator = None
         self.is_aliased_class = is_aliased_class
-        self.disable_aliasing = False
+        self._subq_aliasing = False
         if is_aliased_class:
             self.path_entity = self.entity = self.entity_zero = entity
         else:
@@ -2179,8 +2184,6 @@ class _MapperEntity(_QueryEntity):
         query._entities.append(self)
 
     def _get_entity_clauses(self, query, context):
-        if self.disable_aliasing:
-            return None
             
         adapter = None
         if not self.is_aliased_class and query._polymorphic_adapters:
@@ -2188,7 +2191,11 @@ class _MapperEntity(_QueryEntity):
 
         if not adapter and self.adapter:
             adapter = self.adapter
-
+        
+        # special flag set by subquery loader
+        if self._subq_aliasing:
+            return adapter
+            
         if adapter:
             if query._from_obj_alias:
                 ret = adapter.wrap(query._from_obj_alias)
index b6ca1090d7ec4765c1ce1febbd695163143982b9..0e5e2efdfc47eae8455b2f31f50eb3ae7f142797 100644 (file)
@@ -647,7 +647,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
 
         if not context.query._enable_eagerloads:
             return
-
+        
         path = path + (self.key, )
 
         # build up a path indicating the path from the leftmost
@@ -657,7 +657,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
         subq_path = subq_path + path
 
         reduced_path = interfaces._reduce_path(subq_path)
-
+        
         # check for join_depth or basic recursion,
         # if the current path was not explicitly stated as 
         # a desired "loaderstrategy" (i.e. via query.options())
@@ -680,11 +680,14 @@ class SubqueryLoader(AbstractRelationshipLoader):
             
         orig_query = context.attributes[("orig_query", SubqueryLoader)]
 
-        
         local_cols, remote_cols = self._local_remote_columns(self.parent_property)
         
-        leftmost_mapper, leftmost_prop = \
-                            subq_path[0], subq_path[0].get_property(subq_path[1])
+        if self.parent.isa(subq_path[0]) and self.key==subq_path[1]:
+            leftmost_mapper, leftmost_prop = \
+                                self.parent, self.parent_property
+        else:
+            leftmost_mapper, leftmost_prop = \
+                                subq_path[0], subq_path[0].get_property(subq_path[1])
         leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop)
         
         leftmost_attr = [
@@ -692,23 +695,24 @@ class SubqueryLoader(AbstractRelationshipLoader):
             for c in leftmost_cols
         ]
 
-        # modify the query to just look for parent columns in the 
-        # join condition
-        
         # set the original query to only look
         # for the significant columns, not order
         # by anything.
         q = orig_query._clone()
         q._attributes = {}
         q._attributes[("orig_query", SubqueryLoader)] = orig_query
-        q._set_entities(leftmost_attr)
+        q._set_entities(q._adapt_col_list(leftmost_attr))
         if q._limit is None and q._offset is None:
             q._order_by = None
+            
+        q = q.from_self(self.mapper)
         
-        q._attributes[('subquery_path', None)] = subq_path
+        # TODO: this is currently a magic hardcody
+        # flag on _MapperEntity.  we should find 
+        # a way to turn it into public functionality.
+        q._entities[0]._subq_aliasing = True
 
-        q = q.from_self(self.mapper)
-        q._entities[0].disable_aliasing = True
+        q._attributes[('subquery_path', None)] = subq_path
 
         to_join = [
                     (subq_path[i], subq_path[i+1]) 
@@ -726,14 +730,17 @@ class SubqueryLoader(AbstractRelationshipLoader):
                 getattr(parent_alias, self.parent._get_col_to_prop(c).key)
                 for c in local_cols
             ]
-        q = q.add_columns(*local_attr)
         q = q.order_by(*local_attr)
-            
+        q = q.add_columns(*local_attr)
+        
         for i, (mapper, key) in enumerate(to_join):
             alias_join = i < len(to_join) - 1
             second_to_last = i == len(to_join) - 2
             
-            prop = mapper.get_property(key)
+            if i == 0:
+                prop = leftmost_prop
+            else:
+                prop = mapper.get_property(key)
             
             if second_to_last:
                 q = q.join((parent_alias, prop.class_attribute))
@@ -762,7 +769,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
         
         # this key is for the row_processor to pick up
         # within this same loader.
-        context.attributes[('subquery', path)] = q
+        context.attributes[('subquery', interfaces._reduce_path(path))] = q
     
     def _local_remote_columns(self, prop):
         if prop.secondary is None:
@@ -777,6 +784,8 @@ class SubqueryLoader(AbstractRelationshipLoader):
         
     def create_row_processor(self, context, path, mapper, row, adapter):
         path = path + (self.key,)
+
+        path = interfaces._reduce_path(path)
         
         if ('subquery', path) not in context.attributes:
             return None, None
@@ -825,6 +834,8 @@ class SubqueryLoader(AbstractRelationshipLoader):
             
         return execute, None
 
+log.class_logger(SubqueryLoader)
+
 class EagerLoader(AbstractRelationshipLoader):
     """Strategize a relationship() that loads within the process 
     of the parent object being selected."""
index f7eb5d5e40be6c291a7a9403980f38989013d3fe..e1118a3f8acb8c263cbcf1e64aa35234e3a4204e 100644 (file)
@@ -187,11 +187,23 @@ def _produce_test(select_type):
 
         def test_primary_eager_aliasing(self):
             sess = create_session()
+
+            # for both eagerload() and subqueryload(), if the original q is not loading
+            # the subclass table, the eagerload doesn't happen.
             
             def go():
                 eq_(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
             self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4))
 
+            # additionally, subqueryload() can't handle from_self() on the union.
+            # I'm not too concerned about that.
+            sess = create_session()
+            
+            @testing.fails_if(lambda:select_type == 'Unions')
+            def go():
+                eq_(sess.query(Person).options(subqueryload(Engineer.machines)).all(), all_employees)
+            self.assert_sql_count(testing.db, go, {'':14, 'Unions':3, 'Polymorphic':7}.get(select_type, 8))
+
             sess = create_session()
 
             # assert the JOINs dont over JOIN
@@ -199,7 +211,10 @@ def _produce_test(select_type):
                                     limit(2).offset(1).with_labels().subquery().count().scalar() == 2
 
             def go():
-                eq_(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
+                eq_(
+                    sess.query(Person).with_polymorphic('*').
+                        options(eagerload(Engineer.machines))[1:3], 
+                    all_employees[1:3])
             self.assert_sql_count(testing.db, go, 3)
             
             
@@ -489,11 +504,26 @@ def _produce_test(select_type):
             def go():
                 # currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer).  eagerloader doesn't
                 # pick up on the "of_type()" as of yet.
-                eq_(sess.query(Company).options(eagerload_all(Company.employees.of_type(Engineer), Engineer.machines)).all(), assert_result)
+                eq_(
+                        sess.query(Company).options(
+                                                eagerload_all(Company.employees.of_type(Engineer), Engineer.machines
+                                            )).all(), 
+                                        assert_result)
             
             # in the case of select_type='', the eagerload doesn't take in this case; 
             # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines"            
             self.assert_sql_count(testing.db, go, {'':7, 'Polymorphic':1}.get(select_type, 2))
+            
+            sess = create_session()
+            @testing.fails_if(lambda: select_type=='Unions')
+            def go():
+                eq_(
+                        sess.query(Company).options(
+                                                subqueryload_all(Company.employees.of_type(Engineer), Engineer.machines
+                                            )).all(), 
+                                        assert_result)
+        
+            self.assert_sql_count(testing.db, go, {'':9, 'Joins':6,'Unions':3,'Polymorphic':5,'AliasedJoins':6}[select_type])
     
         def test_eagerload_on_subclass(self):
             sess = create_session()
@@ -504,6 +534,14 @@ def _produce_test(select_type):
                 )
             self.assert_sql_count(testing.db, go, 1)
 
+            sess = create_session()
+            def go():
+                # test load People with subqueryload to engineers + machines
+                eq_(sess.query(Person).with_polymorphic('*').options(subqueryload(Engineer.machines)).filter(Person.name=='dilbert').all(), 
+                [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])]
+                )
+            self.assert_sql_count(testing.db, go, 2)
+
             
         def test_query_subclass_join_to_base_relationship(self):
             sess = create_session()
@@ -1147,7 +1185,19 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL):
         assert q.limit(1).with_labels().subquery().count().scalar() == 1
         
         assert q.first() is c1
-
+    
+    def test_subquery_load(self):
+        session = create_session()
+        
+        c1 = Child1()
+        c1.left_child2 = Child2()
+        session.add(c1)
+        session.flush()
+        session.expunge_all()
+        
+        for row in session.query(Child1).options(subqueryload('left_child2')).all():
+            assert row.left_child2
+        
 class EagerToSubclassTest(_base.MappedTest):
     """Test eagerloads to subclass mappers"""