]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- support for a__b_dc, i.e. two levels of nesting
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2013 23:44:57 +0000 (19:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2013 23:44:57 +0000 (19:44 -0400)
lib/sqlalchemy/sql/compiler.py
test/sql/test_join_rewriting.py

index 2024666b632363eb9596915d699526cf825962f7..116fb397164ed60e1b6da08b5e3e3f4ecd865454 100644 (file)
@@ -1109,17 +1109,32 @@ class SQLCompiler(engine.Compiled):
                 newelem._reset_exported()
                 newelem.left = visit(newelem.left, **kw)
 
+                right = visit(newelem.right, **kw)
+
                 selectable = sql.select(
-                                    [newelem.right.element],
+                                    [right.element],
                                     use_labels=True).alias()
 
                 for c in selectable.c:
                     c._label = c._key_label = c.name
                 translate_dict = dict(
-                        zip(newelem.right.element.c, selectable.c)
+                        zip(right.element.c, selectable.c)
                     )
-                translate_dict[newelem.right.element.left] = selectable
-                translate_dict[newelem.right.element.right] = selectable
+                translate_dict[right.element.left] = selectable
+                translate_dict[right.element.right] = selectable
+
+                # propagate translations that we've gained
+                # from nested visit(newelem.right) outwards
+                # to the enclosing select here.  this happens
+                # only when we have more than one level of right
+                # join nesting, i.e. "a JOIN (b JOIN (c JOIN d))"
+                for k, v in list(column_translate[-1].items()):
+                    if v in translate_dict:
+                        # remarkably, no current ORM tests (May 2013)
+                        # hit this condition, only test_join_rewriting
+                        # does.
+                        column_translate[-1][k] = translate_dict[v]
+
                 column_translate[-1].update(translate_dict)
 
                 newelem.right = selectable
index 2cf98be3ff25ded4d21612eaf5e3347ce6e6881b..403448492563c8d2136e8c440f7860cd02e2340f 100644 (file)
@@ -2,6 +2,8 @@ from sqlalchemy import Table, Column, Integer, MetaData, ForeignKey, select
 from sqlalchemy.testing import fixtures, AssertsCompiledSQL
 from sqlalchemy import util
 from sqlalchemy.engine import default
+from sqlalchemy import testing
+
 
 m = MetaData()
 
@@ -28,16 +30,17 @@ e = Table('e', m,
         Column('id', Integer, primary_key=True)
     )
 
-class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL):
-    @util.classproperty
-    def __dialect__(cls):
-        dialect = default.DefaultDialect()
-        dialect.supports_right_nested_joins = False
-        return dialect
+class _JoinRewriteTestBase(AssertsCompiledSQL):
+    def _test(self, s, assert_):
+        self.assert_compile(
+            s,
+            assert_
+        )
 
-    def test_one(self):
+    def test_a_bc(self):
         j1 = b.join(c)
         j2 = a.join(j1)
+
         # TODO: if we remove 'b' or 'c', shouldn't we get just
         # the subset of cols from anon_1 ?
 
@@ -49,20 +52,26 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL):
             where(b.c.id == 2).\
             where(c.c.id == 3).order_by(a.c.id, b.c.id, c.c.id)
 
-        self.assert_compile(
+        self._test(s, self._a_bc)
+
+    def test_a__b_dc(self):
+        j1 = c.join(d)
+        j2 = b.join(j1)
+        j3 = a.join(j2)
+
+        s = select([a, b, c, d], use_labels=True).\
+            select_from(j3).\
+            where(b.c.id == 2).\
+            where(c.c.id == 3).\
+            where(d.c.id == 4).\
+                order_by(a.c.id, b.c.id, c.c.id, d.c.id)
+
+        self._test(
             s,
-            "SELECT a.id AS a_id, anon_1.b_id AS b_id, "
-            "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, "
-            "anon_1.c_b_id AS c_b_id FROM a JOIN "
-            "(SELECT b.id AS b_id, b.a_id AS b_a_id, "
-                "c.id AS c_id, c.b_id AS c_b_id "
-                "FROM b JOIN c ON b.id = c.b_id) AS anon_1 "
-            "ON a.id = anon_1.b_a_id "
-            "WHERE anon_1.b_id = :id_1 AND anon_1.c_id = :id_2 "
-            "ORDER BY a.id, anon_1.b_id, anon_1.c_id"
+            self._a__b_dc
         )
 
-    def test_two_froms_overlapping_joins(self):
+    def test_a_bc_comma_a1_selbc(self):
         # test here we're emulating is
         # test.orm.inheritance.test_polymorphic_rel:PolymorphicJoinsTest.test_multi_join
         j1 = b.join(c)
@@ -74,27 +83,51 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL):
         s = select([a, a_a, b, c, j2], use_labels=True).\
                 select_from(j3).select_from(j4).order_by(j2.c.b_id)
 
-        # this is the non-converted version
-        """
-        SELECT a.id AS a_id, a_1.id AS a_1_id, b.id AS b_id, b.a_id AS b_a_id,
-            c.id AS c_id, c.b_id AS c_b_id, anon_1.b_id AS anon_1_b_id,
-            anon_1.b_a_id AS anon_1_b_a_id,
-            anon_1.c_id AS anon_1_c_id, anon_1.c_b_id AS anon_1_c_b_id
-
-        FROM
-            a JOIN (b JOIN c ON b.id = c.b_id) ON a.id = b.a_id,
+        self._test(
+            s,
+            self._a_bc_comma_a1_selbc
+        )
 
-            a AS a_1 JOIN (SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id,
-                        c.b_id AS c_b_id
-                    FROM b JOIN c ON b.id = c.b_id
-            ) AS anon_1 ON a_1.id = anon_1.b_a_id
+class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase):
+    """test rendering of each join with right-nested rewritten as
+    aliased SELECT statements.."""
 
-        ORDER BY anon_1.b_id
-        """
+    @util.classproperty
+    def __dialect__(cls):
+        dialect = default.DefaultDialect()
+        dialect.supports_right_nested_joins = False
+        return dialect
 
+    _a__b_dc = (
+            "SELECT a.id AS a_id, anon_1.b_id AS b_id, "
+            "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, "
+            "anon_1.c_b_id AS c_b_id, anon_1.d_id AS d_id, "
+            "anon_1.d_c_id AS d_c_id "
+            "FROM a JOIN (SELECT b.id AS b_id, b.a_id AS b_a_id, "
+            "anon_2.c_id AS c_id, anon_2.c_b_id AS c_b_id, "
+            "anon_2.d_id AS d_id, anon_2.d_c_id AS d_c_id "
+            "FROM b JOIN (SELECT c.id AS c_id, c.b_id AS c_b_id, "
+            "d.id AS d_id, d.c_id AS d_c_id "
+            "FROM c JOIN d ON c.id = d.c_id) AS anon_2 "
+            "ON b.id = anon_2.c_b_id) AS anon_1 ON a.id = anon_1.b_a_id "
+            "WHERE anon_1.b_id = :id_1 AND anon_1.c_id = :id_2 AND "
+            "anon_1.d_id = :id_3 "
+            "ORDER BY a.id, anon_1.b_id, anon_1.c_id, anon_1.d_id"
+            )
+
+    _a_bc = (
+            "SELECT a.id AS a_id, anon_1.b_id AS b_id, "
+            "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, "
+            "anon_1.c_b_id AS c_b_id FROM a JOIN "
+            "(SELECT b.id AS b_id, b.a_id AS b_a_id, "
+                "c.id AS c_id, c.b_id AS c_b_id "
+                "FROM b JOIN c ON b.id = c.b_id) AS anon_1 "
+            "ON a.id = anon_1.b_a_id "
+            "WHERE anon_1.b_id = :id_1 AND anon_1.c_id = :id_2 "
+            "ORDER BY a.id, anon_1.b_id, anon_1.c_id"
+            )
 
-        self.assert_compile(
-            s,
+    _a_bc_comma_a1_selbc = (
             "SELECT a.id AS a_id, a_1.id AS a_1_id, anon_1.b_id AS b_id, "
             "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, "
             "anon_1.c_b_id AS c_b_id, anon_2.b_id AS anon_2_b_id, "
@@ -110,6 +143,69 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL):
             "ON a_1.id = anon_2.b_a_id ORDER BY anon_2.b_id"
         )
 
+class JoinPlainTest(_JoinRewriteTestBase, fixtures.TestBase):
+    """test rendering of each join with normal nesting."""
+    @util.classproperty
+    def __dialect__(cls):
+        dialect = default.DefaultDialect()
+        return dialect
+
+    _a__b_dc = (
+            "SELECT a.id AS a_id, b.id AS b_id, "
+            "b.a_id AS b_a_id, c.id AS c_id, "
+            "c.b_id AS c_b_id, d.id AS d_id, "
+            "d.c_id AS d_c_id "
+            "FROM a JOIN (b JOIN (c JOIN d ON c.id = d.c_id) "
+            "ON b.id = c.b_id) ON a.id = b.a_id "
+            "WHERE b.id = :id_1 AND c.id = :id_2 AND "
+            "d.id = :id_3 "
+            "ORDER BY a.id, b.id, c.id, d.id"
+            )
+
+
+    _a_bc = (
+            "SELECT a.id AS a_id, b.id AS b_id, "
+            "b.a_id AS b_a_id, c.id AS c_id, "
+            "c.b_id AS c_b_id FROM a JOIN "
+            "(b JOIN c ON b.id = c.b_id) "
+            "ON a.id = b.a_id "
+            "WHERE b.id = :id_1 AND c.id = :id_2 "
+            "ORDER BY a.id, b.id, c.id"
+            )
+
+    _a_bc_comma_a1_selbc = (
+            "SELECT a.id AS a_id, a_1.id AS a_1_id, b.id AS b_id, "
+            "b.a_id AS b_a_id, c.id AS c_id, "
+            "c.b_id AS c_b_id, anon_1.b_id AS anon_1_b_id, "
+            "anon_1.b_a_id AS anon_1_b_a_id, anon_1.c_id AS anon_1_c_id, "
+            "anon_1.c_b_id AS anon_1_c_b_id FROM a "
+            "JOIN (b JOIN c ON b.id = c.b_id) "
+            "ON a.id = b.a_id, "
+            "a AS a_1 JOIN "
+                "(SELECT b.id AS b_id, b.a_id AS b_a_id, "
+                "c.id AS c_id, c.b_id AS c_b_id "
+                "FROM b JOIN c ON b.id = c.b_id) AS anon_1 "
+            "ON a_1.id = anon_1.b_a_id ORDER BY anon_1.b_id"
+        )
+
+class JoinExecTest(_JoinRewriteTestBase, fixtures.TestBase):
+    """invoke the SQL on the current backend to ensure compatibility"""
+
+    _a_bc = _a_bc_comma_a1_selbc = _a__b_dc = None
+
+    @classmethod
+    def setup_class(cls):
+        m.create_all(testing.db)
+
+    @classmethod
+    def teardown_class(cls):
+        m.drop_all(testing.db)
+
+    def _test(self, selectable, assert_):
+        testing.db.execute(selectable)
+
+
+class DialectFlagTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_dialect_flag(self):
         d1 = default.DefaultDialect(supports_right_nested_joins=True)
         d2 = default.DefaultDialect(supports_right_nested_joins=False)