From: Mike Bayer Date: Wed, 24 Mar 2010 00:23:01 +0000 (-0400) Subject: getting inheritance to work. some complex cases may have to fail for the time being. X-Git-Tag: rel_0_6beta3~12^2~26 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=90b5ac47cddfc97df05cd30e33149f963090c0f0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git getting inheritance to work. some complex cases may have to fail for the time being. --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5e7a2028e6..43b4e6d77a 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -198,7 +198,13 @@ class Query(object): @_generative() def _adapt_all_clauses(self): self._disable_orm_filtering = True - + + def _adapt_col_list(self, cols): + return [ + self._adapt_clause(expression._literal_as_text(o), True, True) + for o in cols + ] + def _adapt_clause(self, clause, as_filter, orm_only): adapters = [] if as_filter and self._filter_aliases: @@ -773,7 +779,6 @@ class Query(object): return self.filter(sql.and_(*clauses)) - @_generative(_no_statement_condition, _no_limit_offset) @util.accepts_a_list_as_starargs(list_deprecation='deprecated') def order_by(self, *criterion): @@ -782,7 +787,7 @@ class Query(object): if len(criterion) == 1 and criterion[0] is None: self._order_by = None else: - criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion] + criterion = self._adapt_col_list(criterion) if self._order_by is False or self._order_by is None: self._order_by = criterion @@ -796,7 +801,7 @@ class Query(object): criterion = list(chain(*[_orm_columns(c) for c in criterion])) - criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion] + criterion = self._adapt_col_list(criterion) if self._group_by is False: self._group_by = criterion @@ -2147,7 +2152,7 @@ class _MapperEntity(_QueryEntity): self._with_polymorphic = with_polymorphic self._polymorphic_discriminator = None self.is_aliased_class = is_aliased_class - self.disable_aliasing = False + self._subq_aliasing = False if is_aliased_class: self.path_entity = self.entity = self.entity_zero = entity else: @@ -2179,8 +2184,6 @@ class _MapperEntity(_QueryEntity): query._entities.append(self) def _get_entity_clauses(self, query, context): - if self.disable_aliasing: - return None adapter = None if not self.is_aliased_class and query._polymorphic_adapters: @@ -2188,7 +2191,11 @@ class _MapperEntity(_QueryEntity): if not adapter and self.adapter: adapter = self.adapter - + + # special flag set by subquery loader + if self._subq_aliasing: + return adapter + if adapter: if query._from_obj_alias: ret = adapter.wrap(query._from_obj_alias) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index b6ca1090d7..0e5e2efdfc 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -647,7 +647,7 @@ class SubqueryLoader(AbstractRelationshipLoader): if not context.query._enable_eagerloads: return - + path = path + (self.key, ) # build up a path indicating the path from the leftmost @@ -657,7 +657,7 @@ class SubqueryLoader(AbstractRelationshipLoader): subq_path = subq_path + path reduced_path = interfaces._reduce_path(subq_path) - + # check for join_depth or basic recursion, # if the current path was not explicitly stated as # a desired "loaderstrategy" (i.e. via query.options()) @@ -680,11 +680,14 @@ class SubqueryLoader(AbstractRelationshipLoader): orig_query = context.attributes[("orig_query", SubqueryLoader)] - local_cols, remote_cols = self._local_remote_columns(self.parent_property) - leftmost_mapper, leftmost_prop = \ - subq_path[0], subq_path[0].get_property(subq_path[1]) + if self.parent.isa(subq_path[0]) and self.key==subq_path[1]: + leftmost_mapper, leftmost_prop = \ + self.parent, self.parent_property + else: + leftmost_mapper, leftmost_prop = \ + subq_path[0], subq_path[0].get_property(subq_path[1]) leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop) leftmost_attr = [ @@ -692,23 +695,24 @@ class SubqueryLoader(AbstractRelationshipLoader): for c in leftmost_cols ] - # modify the query to just look for parent columns in the - # join condition - # set the original query to only look # for the significant columns, not order # by anything. q = orig_query._clone() q._attributes = {} q._attributes[("orig_query", SubqueryLoader)] = orig_query - q._set_entities(leftmost_attr) + q._set_entities(q._adapt_col_list(leftmost_attr)) if q._limit is None and q._offset is None: q._order_by = None + + q = q.from_self(self.mapper) - q._attributes[('subquery_path', None)] = subq_path + # TODO: this is currently a magic hardcody + # flag on _MapperEntity. we should find + # a way to turn it into public functionality. + q._entities[0]._subq_aliasing = True - q = q.from_self(self.mapper) - q._entities[0].disable_aliasing = True + q._attributes[('subquery_path', None)] = subq_path to_join = [ (subq_path[i], subq_path[i+1]) @@ -726,14 +730,17 @@ class SubqueryLoader(AbstractRelationshipLoader): getattr(parent_alias, self.parent._get_col_to_prop(c).key) for c in local_cols ] - q = q.add_columns(*local_attr) q = q.order_by(*local_attr) - + q = q.add_columns(*local_attr) + for i, (mapper, key) in enumerate(to_join): alias_join = i < len(to_join) - 1 second_to_last = i == len(to_join) - 2 - prop = mapper.get_property(key) + if i == 0: + prop = leftmost_prop + else: + prop = mapper.get_property(key) if second_to_last: q = q.join((parent_alias, prop.class_attribute)) @@ -762,7 +769,7 @@ class SubqueryLoader(AbstractRelationshipLoader): # this key is for the row_processor to pick up # within this same loader. - context.attributes[('subquery', path)] = q + context.attributes[('subquery', interfaces._reduce_path(path))] = q def _local_remote_columns(self, prop): if prop.secondary is None: @@ -777,6 +784,8 @@ class SubqueryLoader(AbstractRelationshipLoader): def create_row_processor(self, context, path, mapper, row, adapter): path = path + (self.key,) + + path = interfaces._reduce_path(path) if ('subquery', path) not in context.attributes: return None, None @@ -825,6 +834,8 @@ class SubqueryLoader(AbstractRelationshipLoader): return execute, None +log.class_logger(SubqueryLoader) + class EagerLoader(AbstractRelationshipLoader): """Strategize a relationship() that loads within the process of the parent object being selected.""" diff --git a/test/orm/inheritance/test_query.py b/test/orm/inheritance/test_query.py index f7eb5d5e40..e1118a3f8a 100644 --- a/test/orm/inheritance/test_query.py +++ b/test/orm/inheritance/test_query.py @@ -187,11 +187,23 @@ def _produce_test(select_type): def test_primary_eager_aliasing(self): sess = create_session() + + # for both eagerload() and subqueryload(), if the original q is not loading + # the subclass table, the eagerload doesn't happen. def go(): eq_(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4)) + # additionally, subqueryload() can't handle from_self() on the union. + # I'm not too concerned about that. + sess = create_session() + + @testing.fails_if(lambda:select_type == 'Unions') + def go(): + eq_(sess.query(Person).options(subqueryload(Engineer.machines)).all(), all_employees) + self.assert_sql_count(testing.db, go, {'':14, 'Unions':3, 'Polymorphic':7}.get(select_type, 8)) + sess = create_session() # assert the JOINs dont over JOIN @@ -199,7 +211,10 @@ def _produce_test(select_type): limit(2).offset(1).with_labels().subquery().count().scalar() == 2 def go(): - eq_(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) + eq_( + sess.query(Person).with_polymorphic('*'). + options(eagerload(Engineer.machines))[1:3], + all_employees[1:3]) self.assert_sql_count(testing.db, go, 3) @@ -489,11 +504,26 @@ def _produce_test(select_type): def go(): # currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer). eagerloader doesn't # pick up on the "of_type()" as of yet. - eq_(sess.query(Company).options(eagerload_all(Company.employees.of_type(Engineer), Engineer.machines)).all(), assert_result) + eq_( + sess.query(Company).options( + eagerload_all(Company.employees.of_type(Engineer), Engineer.machines + )).all(), + assert_result) # in the case of select_type='', the eagerload doesn't take in this case; # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines" self.assert_sql_count(testing.db, go, {'':7, 'Polymorphic':1}.get(select_type, 2)) + + sess = create_session() + @testing.fails_if(lambda: select_type=='Unions') + def go(): + eq_( + sess.query(Company).options( + subqueryload_all(Company.employees.of_type(Engineer), Engineer.machines + )).all(), + assert_result) + + self.assert_sql_count(testing.db, go, {'':9, 'Joins':6,'Unions':3,'Polymorphic':5,'AliasedJoins':6}[select_type]) def test_eagerload_on_subclass(self): sess = create_session() @@ -504,6 +534,14 @@ def _produce_test(select_type): ) self.assert_sql_count(testing.db, go, 1) + sess = create_session() + def go(): + # test load People with subqueryload to engineers + machines + eq_(sess.query(Person).with_polymorphic('*').options(subqueryload(Engineer.machines)).filter(Person.name=='dilbert').all(), + [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])] + ) + self.assert_sql_count(testing.db, go, 2) + def test_query_subclass_join_to_base_relationship(self): sess = create_session() @@ -1147,7 +1185,19 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): assert q.limit(1).with_labels().subquery().count().scalar() == 1 assert q.first() is c1 - + + def test_subquery_load(self): + session = create_session() + + c1 = Child1() + c1.left_child2 = Child2() + session.add(c1) + session.flush() + session.expunge_all() + + for row in session.query(Child1).options(subqueryload('left_child2')).all(): + assert row.left_child2 + class EagerToSubclassTest(_base.MappedTest): """Test eagerloads to subclass mappers"""