]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
rewriting scheme now works.
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2013 18:30:29 +0000 (14:30 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2013 18:30:29 +0000 (14:30 -0400)
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_join_rewriting.py

index 502b54f297a4b1506b40fcbc9e068d723c8cc2ce..8ab81cb130e1f382c3a2d6a3aa40668683db9c70 100644 (file)
@@ -1889,12 +1889,9 @@ class Query(object):
         aliased_entity = right_mapper and \
                             not right_is_aliased and \
                             (
-                                # TODO: this produces queries that fail the
-                                # compiler transformation in test_polymorphic_rel
-                                isinstance(right_mapper._with_polymorphic_selectable, expression.Alias)
-
-                                # current
-                                # right_mapper.with_polymorphic
+                                isinstance(
+                                    right_mapper._with_polymorphic_selectable,
+                                    expression.Alias)
                                 or
                                 overlap # test for overlap:
                                         # orm/inheritance/relationships.py
index 3e159b11266650b2ceadd8876aa17a3a263bb35a..d245c781a3490ad17caaf3c886a595bf2e8ee045 100644 (file)
@@ -1078,86 +1078,58 @@ class SQLCompiler(engine.Compiled):
         return None
 
     def _transform_select_for_nested_joins(self, select):
-        adapters = []
-        stop_on = []
-
-        # test for "unconditional" - any statement with
-        # no_replacement_traverse setup, i.e. query.statement, from_self(), etc.
-        #traverse_options = {"cloned": {}, "unconditional": True}
-        traverse_options = {"unconditional": True}
+        """Rewrite any "a JOIN (b JOIN c)" expression as
+        "a JOIN (select * from b JOIN c) AS anon", to support
+        databases that can't parse a parenthesized join correctly
+        (i.e. sqlite the main one).
 
+        """
         cloned = {}
-        def thing(element, **kw):
-            if element in cloned:
-                return cloned[element]
-
-            newelem = cloned[element] = element._clone()
-
-            if newelem.__visit_name__ == 'join' and \
-                isinstance(newelem.right, sql.FromGrouping):
-                selectable = sql.select([newelem.right.element], use_labels=True)
-                selectable = selectable.alias()
-                newelem.right = selectable
-                stop_on.append(selectable)
-                for c in selectable.c:
-                    c._label = c._key_label = c.name
-                adapter = sql_util.ClauseAdapter(selectable,
-                                        traverse_options=traverse_options)
-                adapter.magic_flag = True
-                adapters.append(adapter)
-            else:
-                newelem._copy_internals(clone=thing, **kw)
-
-            return newelem
+        column_translate = [{}]
 
-        elem = thing(select)
-        while adapters:
-            adapt = adapters.pop(-1)
-            adapt.__traverse_options__['stop_on'].extend(stop_on)
-            elem = adapt.traverse(elem)
-        return elem
+        join_name = sql.Join.__visit_name__
+        select_name = sql.Select.__visit_name__
 
+        def visit(element, **kw):
+            if element in column_translate[-1]:
+                return column_translate[-1][element]
 
-    def _transform_select_for_nested_joins(self, select):
-        adapters = []
-        stop_on = []
-
-        # test for "unconditional" - any statement with
-        # no_replacement_traverse setup, i.e. query.statement, from_self(), etc.
-        #traverse_options = {"cloned": {}, "unconditional": True}
-        traverse_options = {"unconditional": True}
+            elif element in cloned:
+                return cloned[element]
 
-        def visit_join(elem):
-            if isinstance(elem.right, sql.FromGrouping):
-                selectable = sql.select([elem.right.element], use_labels=True)
-                selectable = selectable.alias()
+            newelem = cloned[element] = element._clone()
 
-                while adapters:
-                    adapt = adapters.pop(-1)
-                    selectable = adapt.traverse(selectable)
-                #stop_on.append(selectable)
+            if newelem.__visit_name__ is join_name and \
+                isinstance(newelem.right, sql.FromGrouping):
 
-                # test: see test_subquery_relations:
-                # CyclicalInheritingEagerTestTwo.test_integrate
-                stop_on.append(elem.left)
+                newelem._reset_exported()
+                newelem.left = visit(newelem.left, **kw)
 
+                selectable = sql.select(
+                                    [newelem.right.element],
+                                    use_labels=True).alias()
 
                 for c in selectable.c:
                     c._label = c._key_label = c.name
+                translate_dict = dict(
+                        zip(newelem.right.element.c, selectable.c)
+                    )
+                translate_dict[newelem.right.element.left] = selectable
+                translate_dict[newelem.right.element.right] = selectable
+                column_translate[-1].update(translate_dict)
 
-                elem.right = selectable
-                adapter = sql_util.ClauseAdapter(selectable,
-                                        traverse_options=traverse_options)
-                adapter.__traverse_options__['stop_on'].extend(stop_on)
-                adapters.append(adapter)
-
+                newelem.right = selectable
+                newelem.onclause = visit(newelem.onclause, **kw)
+            elif newelem.__visit_name__ is select_name:
+                column_translate.append({})
+                newelem._copy_internals(clone=visit, **kw)
+                del column_translate[-1]
+            else:
+                newelem._copy_internals(clone=visit, **kw)
 
-        select = visitors.cloned_traverse(select,
-                                    traverse_options, {"join": visit_join})
+            return newelem
 
-        for adap in reversed(adapters):
-            select = adap.traverse(select)
-        return select
+        return visit(select)
 
     def _transform_result_map_for_nested_joins(self, select, transformed_select):
         d = dict(zip(transformed_select.inner_columns, select.inner_columns))
@@ -1172,10 +1144,12 @@ class SQLCompiler(engine.Compiled):
                             positional_names=None,
                             nested_join_translation=False, **kwargs):
 
+        needs_nested_translation = \
+                            not nested_join_translation and \
+                            not self.stack and \
+                            not self.dialect.supports_right_nested_joins
 
-        if self.dialect.supports_right_nested_joins:
-            nested_join_translation = True
-        if not nested_join_translation:
+        if needs_nested_translation:
             transformed_select = self._transform_select_for_nested_joins(select)
             text = self.visit_select(
                             transformed_select, asfrom=asfrom, parens=parens,
@@ -1186,8 +1160,6 @@ class SQLCompiler(engine.Compiled):
                             nested_join_translation=True, **kwargs
                         )
 
-
-
         entry = self.stack and self.stack[-1] or {}
 
         populate_result_map = force_result_map or (
@@ -1197,7 +1169,7 @@ class SQLCompiler(engine.Compiled):
                                         )
                                     )
 
-        if not nested_join_translation:
+        if needs_nested_translation:
             if populate_result_map:
                 self._transform_result_map_for_nested_joins(
                                                 select, transformed_select)
index 30cc109d59ee15d53a9e4643bf027934508c2efa..ba54acb0549169e80a1408b980b4a41341cbbb50 100644 (file)
@@ -39,6 +39,12 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_one(self):
         j1 = b.join(c)
         j2 = a.join(j1)
+        # TODO: if we remove 'b' or 'c', shouldn't we get just
+        # the subset of cols from anon_1 ?
+
+        # TODO: do this test also with individual cols, things change
+        # lots based on how you go with this
+
         s = select([a, b, c], use_labels=True).\
             select_from(j2).\
             where(b.c.id == 2).\
@@ -58,6 +64,8 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_two_froms_overlapping_joins(self):
+        # test here we're emulating is
+        # test.orm.inheritance.test_polymorphic_rel:PolymorphicJoinsTest.test_multi_join
         j1 = b.join(c)
         j2 = b.join(c).select(use_labels=True).alias()
         j3 = a.join(j1)
@@ -85,35 +93,20 @@ class JoinRewriteTest(fixtures.TestBase, AssertsCompiledSQL):
         ORDER BY anon_1.b_id
         """
 
-        """
-        SELECT a.id AS a_id, a_1.id AS a_1_id, anon_1.b_id AS b_id,
-        anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, anon_1.c_b_id AS c_b_id,
-        anon_2.b_id AS anon_2_b_id, anon_2.b_a_id AS anon_2_b_a_id,
-        anon_2.c_id AS anon_2_c_id, anon_2.c_b_id AS anon_2_c_b_id
-
-        FROM
-
-            a JOIN (
-                    SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id,
-                    c.b_id AS c_b_id
-                    FROM b JOIN c ON b.id = c.b_id) AS anon_2 ON a.id = anon_2.b_a_id,
-
-            a AS a_1 JOIN (
-                        SELECT anon_2.b_id AS anon_2_b_id, anon_2.b_a_id AS anon_2_b_a_id,
-                        anon_2.c_id AS anon_2_c_id, anon_2.c_b_id AS anon_2_c_b_id
-                    FROM (
-                        SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id,
-                        c.b_id AS c_b_id FROM b JOIN c ON b.id = c.b_id)
-                        AS anon_2 JOIN (
-                            SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id, c.b_id AS c_b_id
-                            FROM b JOIN c ON b.id = c.b_id) AS anon_2
-                        ON anon_2.b_id = anon_2.c_b_id) AS anon_1 ON a_1.id = anon_1.b_a_id
-
-            ORDER BY anon_1.b_id
-
-        """
 
         self.assert_compile(
             s,
-            ""
+            "SELECT a.id AS a_id, a_1.id AS a_1_id, anon_1.b_id AS b_id, "
+            "anon_1.b_a_id AS b_a_id, anon_1.c_id AS c_id, "
+            "anon_1.c_b_id AS c_b_id, anon_2.b_id AS anon_2_b_id, "
+            "anon_2.b_a_id AS anon_2_b_a_id, anon_2.c_id AS anon_2_c_id, "
+            "anon_2.c_b_id AS anon_2_c_b_id FROM a "
+            "JOIN (SELECT b.id AS b_id, b.a_id AS b_a_id, c.id AS c_id, "
+            "c.b_id AS c_b_id FROM b JOIN c ON b.id = c.b_id) AS anon_1 "
+            "ON a.id = anon_1.b_a_id, "
+            "a AS a_1 JOIN "
+                "(SELECT b.id AS b_id, b.a_id AS b_a_id, "
+                "c.id AS c_id, c.b_id AS c_b_id "
+                "FROM b JOIN c ON b.id = c.b_id) AS anon_2 "
+            "ON a_1.id = anon_2.b_a_id ORDER BY anon_2.b_id"
         )