From: Mike Bayer Date: Tue, 4 Jun 2013 18:30:29 +0000 (-0400) Subject: rewriting scheme now works. X-Git-Tag: rel_0_9_0b1~294^2~11 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=9998e9e0131ff83a4e38e3c17a835a0854789174;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git rewriting scheme now works. --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 502b54f297..8ab81cb130 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1889,12 +1889,9 @@ class Query(object): aliased_entity = right_mapper and \ not right_is_aliased and \ ( - # TODO: this produces queries that fail the - # compiler transformation in test_polymorphic_rel - isinstance(right_mapper._with_polymorphic_selectable, expression.Alias) - - # current - # right_mapper.with_polymorphic + isinstance( + right_mapper._with_polymorphic_selectable, + expression.Alias) or overlap # test for overlap: # orm/inheritance/relationships.py diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3e159b1126..d245c781a3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1078,86 +1078,58 @@ class SQLCompiler(engine.Compiled): return None def _transform_select_for_nested_joins(self, select): - adapters = [] - stop_on = [] - - # test for "unconditional" - any statement with - # no_replacement_traverse setup, i.e. query.statement, from_self(), etc. - #traverse_options = {"cloned": {}, "unconditional": True} - traverse_options = {"unconditional": True} + """Rewrite any "a JOIN (b JOIN c)" expression as + "a JOIN (select * from b JOIN c) AS anon", to support + databases that can't parse a parenthesized join correctly + (i.e. sqlite the main one). + """ cloned = {} - def thing(element, **kw): - if element in cloned: - return cloned[element] - - newelem = cloned[element] = element._clone() - - if newelem.__visit_name__ == 'join' and \ - isinstance(newelem.right, sql.FromGrouping): - selectable = sql.select([newelem.right.element], use_labels=True) - selectable = selectable.alias() - newelem.right = selectable - stop_on.append(selectable) - for c in selectable.c: - c._label = c._key_label = c.name - adapter = sql_util.ClauseAdapter(selectable, - traverse_options=traverse_options) - adapter.magic_flag = True - adapters.append(adapter) - else: - newelem._copy_internals(clone=thing, **kw) - - return newelem + column_translate = [{}] - elem = thing(select) - while adapters: - adapt = adapters.pop(-1) - adapt.__traverse_options__['stop_on'].extend(stop_on) - elem = adapt.traverse(elem) - return elem + join_name = sql.Join.__visit_name__ + select_name = sql.Select.__visit_name__ + def visit(element, **kw): + if element in column_translate[-1]: + return column_translate[-1][element] - def _transform_select_for_nested_joins(self, select): - adapters = [] - stop_on = [] - - # test for "unconditional" - any statement with - # no_replacement_traverse setup, i.e. query.statement, from_self(), etc. - #traverse_options = {"cloned": {}, "unconditional": True} - traverse_options = {"unconditional": True} + elif element in cloned: + return cloned[element] - def visit_join(elem): - if isinstance(elem.right, sql.FromGrouping): - selectable = sql.select([elem.right.element], use_labels=True) - selectable = selectable.alias() + newelem = cloned[element] = element._clone() - while adapters: - adapt = adapters.pop(-1) - selectable = adapt.traverse(selectable) - #stop_on.append(selectable) + if newelem.__visit_name__ is join_name and \ + isinstance(newelem.right, sql.FromGrouping): - # test: see test_subquery_relations: - # CyclicalInheritingEagerTestTwo.test_integrate - stop_on.append(elem.left) + newelem._reset_exported() + newelem.left = visit(newelem.left, **kw) + selectable = sql.select( + [newelem.right.element], + use_labels=True).alias() for c in selectable.c: c._label = c._key_label = c.name + translate_dict = dict( + zip(newelem.right.element.c, selectable.c) + ) + translate_dict[newelem.right.element.left] = selectable + translate_dict[newelem.right.element.right] = selectable + column_translate[-1].update(translate_dict) - elem.right = selectable - adapter = sql_util.ClauseAdapter(selectable, - traverse_options=traverse_options) - adapter.__traverse_options__['stop_on'].extend(stop_on) - adapters.append(adapter) - + newelem.right = selectable + newelem.onclause = visit(newelem.onclause, **kw) + elif newelem.__visit_name__ is select_name: + column_translate.append({}) + newelem._copy_internals(clone=visit, **kw) + del column_translate[-1] + else: + newelem._copy_internals(clone=visit, **kw) - select = visitors.cloned_traverse(select, - traverse_options, {"join": visit_join}) + return newelem - for adap in reversed(adapters): - select = adap.traverse(select) - return select + return visit(select) def _transform_result_map_for_nested_joins(self, select, transformed_select): d = dict(zip(transformed_select.inner_columns, select.inner_columns)) @@ -1172,10 +1144,12 @@ class SQLCompiler(engine.Compiled): positional_names=None, nested_join_translation=False, **kwargs): + needs_nested_translation = \ + not nested_join_translation and \ + not self.stack and \ + not self.dialect.supports_right_nested_joins - if self.dialect.supports_right_nested_joins: - nested_join_translation = True - if not nested_join_translation: + if needs_nested_translation: transformed_select = self._transform_select_for_nested_joins(select) text = self.visit_select( transformed_select, asfrom=asfrom, parens=parens, @@ -1186,8 +1160,6 @@ class SQLCompiler(engine.Compiled): nested_join_translation=True, **kwargs ) - - entry = self.stack and self.stack[-1] or {} populate_result_map = force_result_map or ( @@ -1197,7 +1169,7 @@ class SQLCompiler(engine.Compiled): ) ) - if not nested_join_translation: + if needs_nested_translation: if populate_result_map: self._transform_result_map_for_nested_joins( select, transformed_select) diff --git a/test/sql/test_join_rewriting.py b/test/sql/test_join_rewriting.py index 30cc109d59..ba54acb054 100644 --- a/test/sql/test_join_rewriting.py +++ b/test/sql/test_join_rewriting.py @@ -39,6 +39,12 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL): def test_one(self): j1 = b.join(c) j2 = a.join(j1) + # TODO: if we remove 'b' or 'c', shouldn't we get just + # the subset of cols from anon_1 ? + + # TODO: do this test also with individual cols, things change + # lots based on how you go with this + s = select([a, b, c], use_labels=True).\ select_from(j2).\ where(b.c.id == 2).\ @@ -58,6 +64,8 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_two_froms_overlapping_joins(self): + # test here we're emulating is + # test.orm.inheritance.test_polymorphic_rel:PolymorphicJoinsTest.test_multi_join j1 = b.join(c) j2 = b.join(c).select(use_labels=True).alias() j3 = a.join(j1) @@ -85,35 +93,20 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL): ORDER BY anon_1.b_id """ - """ - SELECT a.id AS a_id, a_1.id AS a_1_id, anon_1.b_id AS b_id, - anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, anon_1.c_b_id AS c_b_id, - anon_2.b_id AS anon_2_b_id, anon_2.b_a_id AS anon_2_b_a_id, - anon_2.c_id AS anon_2_c_id, anon_2.c_b_id AS anon_2_c_b_id - - FROM - - a JOIN ( - SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id, - c.b_id AS c_b_id - FROM b JOIN c ON b.id = c.b_id) AS anon_2 ON a.id = anon_2.b_a_id, - - a AS a_1 JOIN ( - SELECT anon_2.b_id AS anon_2_b_id, anon_2.b_a_id AS anon_2_b_a_id, - anon_2.c_id AS anon_2_c_id, anon_2.c_b_id AS anon_2_c_b_id - FROM ( - SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id, - c.b_id AS c_b_id FROM b JOIN c ON b.id = c.b_id) - AS anon_2 JOIN ( - SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id, c.b_id AS c_b_id - FROM b JOIN c ON b.id = c.b_id) AS anon_2 - ON anon_2.b_id = anon_2.c_b_id) AS anon_1 ON a_1.id = anon_1.b_a_id - - ORDER BY anon_1.b_id - - """ self.assert_compile( s, - "" + "SELECT a.id AS a_id, a_1.id AS a_1_id, anon_1.b_id AS b_id, " + "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, " + "anon_1.c_b_id AS c_b_id, anon_2.b_id AS anon_2_b_id, " + "anon_2.b_a_id AS anon_2_b_a_id, anon_2.c_id AS anon_2_c_id, " + "anon_2.c_b_id AS anon_2_c_b_id FROM a " + "JOIN (SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id, " + "c.b_id AS c_b_id FROM b JOIN c ON b.id = c.b_id) AS anon_1 " + "ON a.id = anon_1.b_a_id, " + "a AS a_1 JOIN " + "(SELECT b.id AS b_id, b.a_id AS b_a_id, " + "c.id AS c_id, c.b_id AS c_b_id " + "FROM b JOIN c ON b.id = c.b_id) AS anon_2 " + "ON a_1.id = anon_2.b_a_id ORDER BY anon_2.b_id" )