From d8a38839483aede934b6cbeb6d0828d362767a4d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 4 Jun 2013 19:44:57 -0400 Subject: [PATCH] - support for a__b_dc, i.e. two levels of nesting --- lib/sqlalchemy/sql/compiler.py | 23 ++++- test/sql/test_join_rewriting.py | 166 +++++++++++++++++++++++++------- 2 files changed, 150 insertions(+), 39 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2024666b63..116fb39716 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1109,17 +1109,32 @@ class SQLCompiler(engine.Compiled): newelem._reset_exported() newelem.left = visit(newelem.left, **kw) + right = visit(newelem.right, **kw) + selectable = sql.select( - [newelem.right.element], + [right.element], use_labels=True).alias() for c in selectable.c: c._label = c._key_label = c.name translate_dict = dict( - zip(newelem.right.element.c, selectable.c) + zip(right.element.c, selectable.c) ) - translate_dict[newelem.right.element.left] = selectable - translate_dict[newelem.right.element.right] = selectable + translate_dict[right.element.left] = selectable + translate_dict[right.element.right] = selectable + + # propagate translations that we've gained + # from nested visit(newelem.right) outwards + # to the enclosing select here. this happens + # only when we have more than one level of right + # join nesting, i.e. "a JOIN (b JOIN (c JOIN d))" + for k, v in list(column_translate[-1].items()): + if v in translate_dict: + # remarkably, no current ORM tests (May 2013) + # hit this condition, only test_join_rewriting + # does. + column_translate[-1][k] = translate_dict[v] + column_translate[-1].update(translate_dict) newelem.right = selectable diff --git a/test/sql/test_join_rewriting.py b/test/sql/test_join_rewriting.py index 2cf98be3ff..4034484925 100644 --- a/test/sql/test_join_rewriting.py +++ b/test/sql/test_join_rewriting.py @@ -2,6 +2,8 @@ from sqlalchemy import Table, Column, Integer, MetaData, ForeignKey, select from sqlalchemy.testing import fixtures, AssertsCompiledSQL from sqlalchemy import util from sqlalchemy.engine import default +from sqlalchemy import testing + m = MetaData() @@ -28,16 +30,17 @@ e = Table('e', m, Column('id', Integer, primary_key=True) ) -class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL): - @util.classproperty - def __dialect__(cls): - dialect = default.DefaultDialect() - dialect.supports_right_nested_joins = False - return dialect +class _JoinRewriteTestBase(AssertsCompiledSQL): + def _test(self, s, assert_): + self.assert_compile( + s, + assert_ + ) - def test_one(self): + def test_a_bc(self): j1 = b.join(c) j2 = a.join(j1) + # TODO: if we remove 'b' or 'c', shouldn't we get just # the subset of cols from anon_1 ? @@ -49,20 +52,26 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL): where(b.c.id == 2).\ where(c.c.id == 3).order_by(a.c.id, b.c.id, c.c.id) - self.assert_compile( + self._test(s, self._a_bc) + + def test_a__b_dc(self): + j1 = c.join(d) + j2 = b.join(j1) + j3 = a.join(j2) + + s = select([a, b, c, d], use_labels=True).\ + select_from(j3).\ + where(b.c.id == 2).\ + where(c.c.id == 3).\ + where(d.c.id == 4).\ + order_by(a.c.id, b.c.id, c.c.id, d.c.id) + + self._test( s, - "SELECT a.id AS a_id, anon_1.b_id AS b_id, " - "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, " - "anon_1.c_b_id AS c_b_id FROM a JOIN " - "(SELECT b.id AS b_id, b.a_id AS b_a_id, " - "c.id AS c_id, c.b_id AS c_b_id " - "FROM b JOIN c ON b.id = c.b_id) AS anon_1 " - "ON a.id = anon_1.b_a_id " - "WHERE anon_1.b_id = :id_1 AND anon_1.c_id = :id_2 " - "ORDER BY a.id, anon_1.b_id, anon_1.c_id" + self._a__b_dc ) - def test_two_froms_overlapping_joins(self): + def test_a_bc_comma_a1_selbc(self): # test here we're emulating is # test.orm.inheritance.test_polymorphic_rel:PolymorphicJoinsTest.test_multi_join j1 = b.join(c) @@ -74,27 +83,51 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL): s = select([a, a_a, b, c, j2], use_labels=True).\ select_from(j3).select_from(j4).order_by(j2.c.b_id) - # this is the non-converted version - """ - SELECT a.id AS a_id, a_1.id AS a_1_id, b.id AS b_id, b.a_id AS b_a_id, - c.id AS c_id, c.b_id AS c_b_id, anon_1.b_id AS anon_1_b_id, - anon_1.b_a_id AS anon_1_b_a_id, - anon_1.c_id AS anon_1_c_id, anon_1.c_b_id AS anon_1_c_b_id - - FROM - a JOIN (b JOIN c ON b.id = c.b_id) ON a.id = b.a_id, + self._test( + s, + self._a_bc_comma_a1_selbc + ) - a AS a_1 JOIN (SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id, - c.b_id AS c_b_id - FROM b JOIN c ON b.id = c.b_id - ) AS anon_1 ON a_1.id = anon_1.b_a_id +class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase): + """test rendering of each join with right-nested rewritten as + aliased SELECT statements..""" - ORDER BY anon_1.b_id - """ + @util.classproperty + def __dialect__(cls): + dialect = default.DefaultDialect() + dialect.supports_right_nested_joins = False + return dialect + _a__b_dc = ( + "SELECT a.id AS a_id, anon_1.b_id AS b_id, " + "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, " + "anon_1.c_b_id AS c_b_id, anon_1.d_id AS d_id, " + "anon_1.d_c_id AS d_c_id " + "FROM a JOIN (SELECT b.id AS b_id, b.a_id AS b_a_id, " + "anon_2.c_id AS c_id, anon_2.c_b_id AS c_b_id, " + "anon_2.d_id AS d_id, anon_2.d_c_id AS d_c_id " + "FROM b JOIN (SELECT c.id AS c_id, c.b_id AS c_b_id, " + "d.id AS d_id, d.c_id AS d_c_id " + "FROM c JOIN d ON c.id = d.c_id) AS anon_2 " + "ON b.id = anon_2.c_b_id) AS anon_1 ON a.id = anon_1.b_a_id " + "WHERE anon_1.b_id = :id_1 AND anon_1.c_id = :id_2 AND " + "anon_1.d_id = :id_3 " + "ORDER BY a.id, anon_1.b_id, anon_1.c_id, anon_1.d_id" + ) + + _a_bc = ( + "SELECT a.id AS a_id, anon_1.b_id AS b_id, " + "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, " + "anon_1.c_b_id AS c_b_id FROM a JOIN " + "(SELECT b.id AS b_id, b.a_id AS b_a_id, " + "c.id AS c_id, c.b_id AS c_b_id " + "FROM b JOIN c ON b.id = c.b_id) AS anon_1 " + "ON a.id = anon_1.b_a_id " + "WHERE anon_1.b_id = :id_1 AND anon_1.c_id = :id_2 " + "ORDER BY a.id, anon_1.b_id, anon_1.c_id" + ) - self.assert_compile( - s, + _a_bc_comma_a1_selbc = ( "SELECT a.id AS a_id, a_1.id AS a_1_id, anon_1.b_id AS b_id, " "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, " "anon_1.c_b_id AS c_b_id, anon_2.b_id AS anon_2_b_id, " @@ -110,6 +143,69 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL): "ON a_1.id = anon_2.b_a_id ORDER BY anon_2.b_id" ) +class JoinPlainTest(_JoinRewriteTestBase, fixtures.TestBase): + """test rendering of each join with normal nesting.""" + @util.classproperty + def __dialect__(cls): + dialect = default.DefaultDialect() + return dialect + + _a__b_dc = ( + "SELECT a.id AS a_id, b.id AS b_id, " + "b.a_id AS b_a_id, c.id AS c_id, " + "c.b_id AS c_b_id, d.id AS d_id, " + "d.c_id AS d_c_id " + "FROM a JOIN (b JOIN (c JOIN d ON c.id = d.c_id) " + "ON b.id = c.b_id) ON a.id = b.a_id " + "WHERE b.id = :id_1 AND c.id = :id_2 AND " + "d.id = :id_3 " + "ORDER BY a.id, b.id, c.id, d.id" + ) + + + _a_bc = ( + "SELECT a.id AS a_id, b.id AS b_id, " + "b.a_id AS b_a_id, c.id AS c_id, " + "c.b_id AS c_b_id FROM a JOIN " + "(b JOIN c ON b.id = c.b_id) " + "ON a.id = b.a_id " + "WHERE b.id = :id_1 AND c.id = :id_2 " + "ORDER BY a.id, b.id, c.id" + ) + + _a_bc_comma_a1_selbc = ( + "SELECT a.id AS a_id, a_1.id AS a_1_id, b.id AS b_id, " + "b.a_id AS b_a_id, c.id AS c_id, " + "c.b_id AS c_b_id, anon_1.b_id AS anon_1_b_id, " + "anon_1.b_a_id AS anon_1_b_a_id, anon_1.c_id AS anon_1_c_id, " + "anon_1.c_b_id AS anon_1_c_b_id FROM a " + "JOIN (b JOIN c ON b.id = c.b_id) " + "ON a.id = b.a_id, " + "a AS a_1 JOIN " + "(SELECT b.id AS b_id, b.a_id AS b_a_id, " + "c.id AS c_id, c.b_id AS c_b_id " + "FROM b JOIN c ON b.id = c.b_id) AS anon_1 " + "ON a_1.id = anon_1.b_a_id ORDER BY anon_1.b_id" + ) + +class JoinExecTest(_JoinRewriteTestBase, fixtures.TestBase): + """invoke the SQL on the current backend to ensure compatibility""" + + _a_bc = _a_bc_comma_a1_selbc = _a__b_dc = None + + @classmethod + def setup_class(cls): + m.create_all(testing.db) + + @classmethod + def teardown_class(cls): + m.drop_all(testing.db) + + def _test(self, selectable, assert_): + testing.db.execute(selectable) + + +class DialectFlagTest(fixtures.TestBase, AssertsCompiledSQL): def test_dialect_flag(self): d1 = default.DefaultDialect(supports_right_nested_joins=True) d2 = default.DefaultDialect(supports_right_nested_joins=False) -- 2.47.3