]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Adjust CTE recursive col list to accommodate dupe col names
author: Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Jul 2021 23:48:55 +0000 (19:48 -0400)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Jul 2021 14:25:52 +0000 (10:25 -0400)
Fixed issue in CTE constructs where a recursive CTE that referred to a
SELECT that has duplicate column names, which are typically deduplicated
using labeling logic in 1.4, would fail to refer to the deduplicated label
name correctly within the WITH clause.

As part of this change we are also attempting to remove the
behavior of SelectStatementGrouping forcing off the "asfrom"
contextual flag, which will have the result of additional labeling
being applied to some UNION and similar statements when they are
interpreted as subqueries.  To maintain compatibility with
"grouping", the Grouping/SelectStatementGrouping are now broken
out into two separate compiler cases, as the "asfrom" logic appears
to be tailored towards table-valued SELECTs as column expressions.

Fixes: #6710
Change-Id: I8af07a5c670dbe5736cd9f16084ef82f5e4c8642

doc/build/changelog/unreleased_14/6710.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/orm/test_core_compilation.py
test/sql/test_compiler.py
test/sql/test_cte.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_14/6710.rst b/doc/build/changelog/unreleased_14/6710.rst
new file mode 100644 (file)
index 0000000..32784e8
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 6710
+
+    Fixed issue in CTE constructs where a recursive CTE that referred to a
+    SELECT that has duplicate column names, which are typically deduplicated
+    using labeling logic in 1.4, would fail to refer to the deduplicated label
+    name correctly within the WITH clause.
index 7007c2e869ed7cca01f7123da404bcbf687a9f04..360a53ac854fde92af4a577fe1beea1c2ef21b52 100644 (file)
@@ -1311,6 +1311,9 @@ class SQLCompiler(Compiled):
     def visit_grouping(self, grouping, asfrom=False, **kwargs):
         return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
 
+    def visit_select_statement_grouping(self, grouping, **kwargs):
+        return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
+
     def visit_label_reference(
         self, element, within_columns_clause=False, **kwargs
     ):
@@ -2562,17 +2565,29 @@ class SQLCompiler(Compiled):
                         col_source = cte.element.selects[0]
                     else:
                         assert False, "cte should only be against SelectBase"
+
+                    # TODO: can we get at the .columns_plus_names collection
+                    # that is already (or will be?) generated for the SELECT
+                    # rather than calling twice?
                     recur_cols = [
-                        c
-                        for c in util.unique_list(
-                            col_source._all_selected_columns
-                        )
-                        if c is not None
+                        # TODO: proxy_name is not technically safe,
+                        # see test_cte->
+                        # test_with_recursive_no_name_currently_buggy.  not
+                        # clear what should be done with such a case
+                        fallback_label_name or proxy_name
+                        for (
+                            _,
+                            proxy_name,
+                            fallback_label_name,
+                            c,
+                            repeated,
+                        ) in (col_source._generate_columns_plus_names(True))
+                        if not repeated
                     ]
 
                     text += "(%s)" % (
                         ", ".join(
-                            self.preparer.format_column(
+                            self.preparer.format_label_name(
                                 ident, anon_map=self.anon_map
                             )
                             for ident in recur_cols
@@ -5103,6 +5118,20 @@ class IdentifierPreparer(object):
 
         return self.quote(name)
 
+    def format_label_name(
+        self,
+        name,
+        anon_map=None,
+    ):
+        """Prepare a quoted column name."""
+
+        if anon_map is not None and isinstance(
+            name, elements._truncated_label
+        ):
+            name = name.apply_map(anon_map)
+
+        return self.quote(name)
+
     def format_column(
         self,
         column,
index 30a613089e1df6a089bd9723714230ff3b5b52dc..b6cf7f55e85336145328f9d0675acd7a122f5864 100644 (file)
@@ -3139,7 +3139,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
 
     """
 
-    __visit_name__ = "grouping"
+    __visit_name__ = "select_statement_grouping"
     _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
 
     _is_select_container = True
@@ -3173,6 +3173,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase):
     def self_group(self, against=None):
         return self
 
+    def _generate_columns_plus_names(self, anon_for_dupe_key):
+        return self.element._generate_columns_plus_names(anon_for_dupe_key)
+
     def _generate_fromclause_column_proxies(self, subquery):
         self.element._generate_fromclause_column_proxies(subquery)
 
index 5f25b56e882b5147335a1fe88b02acb2baa93eda..e730d9097581e2b038f5db61318e2da023c46257 100644 (file)
@@ -627,8 +627,10 @@ class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL):
 
         self.assert_compile(
             select(u_alias),
-            "SELECT anon_1.id FROM ((SELECT users.name, users.id FROM users "
-            "WHERE users.id = :id_1 UNION SELECT users.name, users.id "
+            "SELECT anon_1.id FROM ((SELECT users.name AS name, "
+            "users.id AS id FROM users "
+            "WHERE users.id = :id_1 UNION SELECT users.name AS name, "
+            "users.id AS id "
             "FROM users WHERE users.id = :id_2) "
             "UNION SELECT users.name AS name, users.id AS id "
             "FROM users WHERE users.id = :id_3) AS anon_1",
@@ -656,8 +658,9 @@ class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL):
         self.assert_compile(
             select(u_alias).options(undefer(u_alias.name)),
             "SELECT anon_1.name, anon_1.id FROM "
-            "((SELECT users.name, users.id FROM users "
-            "WHERE users.id = :id_1 UNION SELECT users.name, users.id "
+            "((SELECT users.name AS name, users.id AS id FROM users "
+            "WHERE users.id = :id_1 UNION SELECT users.name AS name, "
+            "users.id AS id "
             "FROM users WHERE users.id = :id_2) "
             "UNION SELECT users.name AS name, users.id AS id "
             "FROM users WHERE users.id = :id_3) AS anon_1",
index f2c1e004d8714fed505dde1dd85986fa5d47ee70..40faab486791dd17bae037bcdd6caa5dbfb88d2c 100644 (file)
@@ -892,6 +892,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         "WITH RECURSIVE (colnames)" part.  This test shows that this isn't
         correct when keys are present.
 
+        See also test_cte ->
+        test_wrecur_ovlp_lbls_plus_dupes_separate_keys_use_labels
+
         """
         m = MetaData()
         foo = Table(
index e8a8a3150ce9deb6d2907c1d6e1bd0dddb5bb3ac..f1d27aa8f174c6d73d7cad8bd2144ae5fc4b1030 100644 (file)
@@ -1,4 +1,9 @@
+from sqlalchemy import Column
 from sqlalchemy import delete
+from sqlalchemy import Integer
+from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
+from sqlalchemy import MetaData
+from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import update
@@ -495,6 +500,149 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             s.compile,
         )
 
+    def test_with_recursive_no_name_currently_buggy(self):
+        s1 = select(1)
+        c1 = s1.cte(name="cte1", recursive=True)
+
+        # this is nonsensical at the moment
+        self.assert_compile(
+            select(c1),
+            'WITH RECURSIVE cte1("1") AS (SELECT 1) SELECT cte1.1 FROM cte1',
+        )
+
+        # however, so is subquery, which is worse as it isn't even trying
+        # to quote "1" as a label
+        self.assert_compile(
+            select(s1.subquery()), "SELECT anon_1.1 FROM (SELECT 1) AS anon_1"
+        )
+
+    def test_wrecur_dupe_col_names(self):
+        """test #6710"""
+
+        manager = table("manager", column("id"))
+        employee = table("employee", column("id"), column("manager_id"))
+
+        top_q = select(employee, manager).join_from(
+            employee, manager, employee.c.manager_id == manager.c.id
+        )
+
+        top_q = top_q.cte("cte", recursive=True)
+
+        bottom_q = (
+            select(employee, manager)
+            .join_from(
+                employee, manager, employee.c.manager_id == manager.c.id
+            )
+            .join(top_q, top_q.c.id == employee.c.id)
+        )
+
+        rec_cte = select(top_q.union_all(bottom_q))
+        self.assert_compile(
+            rec_cte,
+            "WITH RECURSIVE cte(id, manager_id, id_1) AS "
+            "(SELECT employee.id AS id, employee.manager_id AS manager_id, "
+            "manager.id AS id_1 FROM employee JOIN manager "
+            "ON employee.manager_id = manager.id UNION ALL "
+            "SELECT employee.id AS id, employee.manager_id AS manager_id, "
+            "manager.id AS id_1 FROM employee JOIN manager ON "
+            "employee.manager_id = manager.id "
+            "JOIN cte ON cte.id = employee.id) "
+            "SELECT cte.id, cte.manager_id, cte.id_1 FROM cte",
+        )
+
+    def test_wrecur_dupe_col_names_w_grouping(self):
+        """test #6710
+
+        by adding order_by() to the top query, the CTE will have
+        a compound select with the first element a SelectStatementGrouping
+        object, which we can test has the correct methods for the compiler
+        to call upon.
+
+        """
+
+        manager = table("manager", column("id"))
+        employee = table("employee", column("id"), column("manager_id"))
+
+        top_q = (
+            select(employee, manager)
+            .join_from(
+                employee, manager, employee.c.manager_id == manager.c.id
+            )
+            .order_by(employee.c.id)
+            .cte("cte", recursive=True)
+        )
+
+        bottom_q = (
+            select(employee, manager)
+            .join_from(
+                employee, manager, employee.c.manager_id == manager.c.id
+            )
+            .join(top_q, top_q.c.id == employee.c.id)
+        )
+
+        rec_cte = select(top_q.union_all(bottom_q))
+
+        self.assert_compile(
+            rec_cte,
+            "WITH RECURSIVE cte(id, manager_id, id_1) AS "
+            "((SELECT employee.id AS id, employee.manager_id AS manager_id, "
+            "manager.id AS id_1 FROM employee JOIN manager "
+            "ON employee.manager_id = manager.id ORDER BY employee.id) "
+            "UNION ALL "
+            "SELECT employee.id AS id, employee.manager_id AS manager_id, "
+            "manager.id AS id_1 FROM employee JOIN manager ON "
+            "employee.manager_id = manager.id "
+            "JOIN cte ON cte.id = employee.id) "
+            "SELECT cte.id, cte.manager_id, cte.id_1 FROM cte",
+        )
+
+    def test_wrecur_ovlp_lbls_plus_dupes_separate_keys_use_labels(self):
+        """test a condition related to #6710.
+
+        also see test_compiler->
+        test_overlapping_labels_plus_dupes_separate_keys_use_labels
+
+        for a non cte form of this test.
+
+        """
+
+        m = MetaData()
+        foo = Table(
+            "foo",
+            m,
+            Column("id", Integer),
+            Column("bar_id", Integer, key="bb"),
+        )
+        foo_bar = Table("foo_bar", m, Column("id", Integer, key="bb"))
+
+        stmt = select(
+            foo.c.id,
+            foo.c.bb,
+            foo_bar.c.bb,
+            foo.c.bb,
+            foo.c.id,
+            foo.c.bb,
+            foo_bar.c.bb,
+            foo_bar.c.bb,
+        ).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+        cte = stmt.cte(recursive=True)
+
+        self.assert_compile(
+            select(cte),
+            "WITH RECURSIVE anon_1(foo_id, foo_bar_id, foo_bar_id_1) AS "
+            "(SELECT foo.id AS foo_id, foo.bar_id AS foo_bar_id, "
+            "foo_bar.id AS foo_bar_id_1, foo.bar_id AS foo_bar_id__1, "
+            "foo.id AS foo_id__1, foo.bar_id AS foo_bar_id__1, "
+            "foo_bar.id AS foo_bar_id__2, foo_bar.id AS foo_bar_id__2 "
+            "FROM foo, foo_bar) "
+            "SELECT anon_1.foo_id, anon_1.foo_bar_id, anon_1.foo_bar_id_1, "
+            "anon_1.foo_bar_id AS foo_bar_id_2, anon_1.foo_id AS foo_id_1, "
+            "anon_1.foo_bar_id AS foo_bar_id_3, "
+            "anon_1.foo_bar_id_1 AS foo_bar_id_1_1, "
+            "anon_1.foo_bar_id_1 AS foo_bar_id_1_2 FROM anon_1",
+        )
+
     def test_union(self):
         orders = table("orders", column("region"), column("amount"))
 
index be894d239e10a9fbfc5aafe1cf619a8fabb5d031..cfdf4ad02ee884dea7514b601f48700ad2a66e5a 100644 (file)
@@ -998,9 +998,11 @@ class SelectableTest(
         self.assert_compile(
             stmt,
             "SELECT anon_1.col1, anon_1.col2, anon_1.col1_1 FROM "
-            "((SELECT table1.col1, table1.col2, table2.col1 AS col1_1 "
+            "((SELECT table1.col1 AS col1, table1.col2 AS col2, table2.col1 "
+            "AS col1_1 "
             "FROM table1, table2 LIMIT :param_1) UNION "
-            "(SELECT table2.col1, table2.col2, table2.col3 FROM table2 "
+            "(SELECT table2.col1 AS col1, table2.col2 AS col2, "
+            "table2.col3 AS col3 FROM table2 "
             "LIMIT :param_2)) AS anon_1",
         )