From: Mike Bayer Date: Tue, 23 Mar 2010 22:33:31 +0000 (-0400) Subject: this version actually works for all existing tests plus simple self-referential. X-Git-Tag: rel_0_6beta3~12^2~28 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=4d396e5ff0ea111c81605527d415b251d73629f7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git this version actually works for all existing tests plus simple self-referential. I don't like how difficult it was to get Query() to do it, however. --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index f067172174..2dfefc4332 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -134,7 +134,7 @@ class Query(object): self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter def _set_select_from(self, *obj): - + fa = [] for from_obj in obj: if isinstance(from_obj, expression._SelectBaseMixin): @@ -143,9 +143,8 @@ class Query(object): self._from_obj = tuple(fa) - # TODO: only use this adapter for from_self() ? right - # now its usage is somewhat arbitrary. - if len(self._from_obj) == 1 and isinstance(self._from_obj[0], expression.Alias): + if len(self._from_obj) == 1 and \ + isinstance(self._from_obj[0], expression.Alias): equivs = self.__all_equivs() self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj[0], equivs) @@ -625,7 +624,7 @@ class Query(object): if entities: q._set_entities(entities) return q - + @_generative() def _from_selectable(self, fromclause): for attr in ('_statement', '_criterion', '_order_by', '_group_by', @@ -2139,6 +2138,7 @@ class _MapperEntity(_QueryEntity): self._with_polymorphic = with_polymorphic self._polymorphic_discriminator = None self.is_aliased_class = is_aliased_class + self.disable_aliasing = False if is_aliased_class: self.path_entity = self.entity = self.entity_zero = entity else: @@ -2170,7 +2170,9 @@ class _MapperEntity(_QueryEntity): query._entities.append(self) def _get_entity_clauses(self, query, context): - + if self.disable_aliasing: + return None + adapter = None if not self.is_aliased_class and query._polymorphic_adapters: adapter = query._polymorphic_adapters.get(self.mapper, None) @@ -2251,7 +2253,6 @@ class _MapperEntity(_QueryEntity): def __str__(self): return str(self.mapper) - class _ColumnEntity(_QueryEntity): """Column/expression based entity.""" diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index f507bfbe53..4431b408fc 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -692,11 +692,6 @@ class SubqueryLoader(AbstractRelationshipLoader): for c in leftmost_cols ] - local_attr = [ - self.parent._get_col_to_prop(c).class_attribute - for c in local_cols - ] - # modify the query to just look for parent columns in the # join condition @@ -713,24 +708,44 @@ class SubqueryLoader(AbstractRelationshipLoader): q._attributes[('subquery_path', None)] = subq_path # now select from it as a subquery. - q = q.from_self(self.mapper, *local_attr) + local_attr = [ + self.parent._get_col_to_prop(c).class_attribute + for c in local_cols + ] + + q = q.from_self(self.mapper) + q._entities[0].disable_aliasing = True - # and join to the related thing we want - # to load. - for mapper, key in [(subq_path[i], subq_path[i+1]) - for i in xrange(0, len(subq_path), 2)]: + to_join = [(subq_path[i], subq_path[i+1]) + for i in xrange(0, len(subq_path), 2)] + + for i, (mapper, key) in enumerate(to_join): + alias_join = i < len(to_join) - 1 + second_to_last = i == len(to_join) - 2 + prop = mapper.get_property(key) - q = q.join(prop.class_attribute) + q = q.join(prop.class_attribute, aliased=alias_join) - #join_on = [(subq_path[i], subq_path[i+1]) - # for i in xrange(0, len(subq_path), 2)] - #for i, (mapper, key) in enumerate(join_on): - # aliased = i != len(join_on) - 1 - # prop = mapper.get_property(key) - # q = q.join(prop.class_attribute, aliased=aliased) - - q = q.order_by(*local_attr) + if alias_join and second_to_last: + cols = [ + q._adapt_clause(col, True, False) + for col in local_cols + ] + for col in cols: + q = q.add_column(col) + q = q.order_by(*cols) + if len(to_join) < 2: + local_attr = [ + self.parent._get_col_to_prop(c).class_attribute + for c in local_cols + ] + + for col in local_attr: + q = q.add_column(col) + q = q.order_by(*local_attr) + + # propagate loader options etc. to the new query q = q._with_current_path(subq_path) q = q._conditional_options(*orig_query._with_options) @@ -774,7 +789,6 @@ class SubqueryLoader(AbstractRelationshipLoader): local_cols, remote_cols = self._local_remote_columns(self.parent_property) - local_attr = [self.parent._get_col_to_prop(c).key for c in local_cols] remote_attr = [ self.mapper._get_col_to_prop(c).key for c in remote_cols] diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index e1372fbfe6..1be8156862 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -569,7 +569,7 @@ class OrderBySecondaryTest(_base.MappedTest): ]) self.assert_sql_count(testing.db, go, 2) -class SelfReferentialEagerTest(_base.MappedTest): +class SelfReferentialTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('nodes', metadata, @@ -579,7 +579,7 @@ class SelfReferentialEagerTest(_base.MappedTest): @testing.fails_on('maxdb', 'FIXME: unknown') @testing.resolve_artifact_names - def _test_basic(self): + def test_basic(self): class Node(_base.ComparableEntity): def append(self, node): self.children.append(node) @@ -594,13 +594,13 @@ class SelfReferentialEagerTest(_base.MappedTest): n1.append(Node(data='n11')) n1.append(Node(data='n12')) n1.append(Node(data='n13')) -# n1.children[1].append(Node(data='n121')) -# n1.children[1].append(Node(data='n122')) -# n1.children[1].append(Node(data='n123')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) n2 = Node(data='n2') n2.append(Node(data='n21')) -# n2.children[0].append(Node(data='n211')) -# n2.children[0].append(Node(data='n212')) + n2.children[0].append(Node(data='n211')) + n2.children[0].append(Node(data='n212')) sess.add(n1) sess.add(n2) @@ -612,20 +612,20 @@ class SelfReferentialEagerTest(_base.MappedTest): eq_([Node(data='n1', children=[ Node(data='n11'), Node(data='n12', children=[ -# Node(data='n121'), -# Node(data='n122'), -# Node(data='n123') + Node(data='n121'), + Node(data='n122'), + Node(data='n123') ]), Node(data='n13') ]), Node(data='n2', children=[ Node(data='n21', children=[ -# Node(data='n211'), -# Node(data='n212'), + Node(data='n211'), + Node(data='n212'), ]) ]) ], d) - self.assert_sql_count(testing.db, go, 1) + self.assert_sql_count(testing.db, go, 4)