From a0953bb7095dde805de8c13699b122767ed001b9 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 3 Jul 2021 19:48:55 -0400 Subject: [PATCH] Adjust CTE recrusive col list to accommodate dupe col names Fixed issue in CTE constructs where a recursive CTE that referred to a SELECT that has duplicate column names, which are typically deduplicated using labeling logic in 1.4, would fail to refer to the deduplicated label name correctly within the WITH clause. As part of this change we are also attempting to remove the behavior of SelectStatementGrouping forcing off the "asfrom" contextual flag, which will have the result of additional labeling being applied to some UNION and similar statements when they are interpreted as subqueries. To maintain compatibility with "grouping", the Grouping/SelectStatementGrouping are now broken out into two separate compiler cases, as the "asfrom" logic appears to be tailored towards table valued SELECTS as column expressions. Fixes: #6710 Change-Id: I8af07a5c670dbe5736cd9f16084ef82f5e4c8642 --- doc/build/changelog/unreleased_14/6710.rst | 8 ++ lib/sqlalchemy/sql/compiler.py | 41 +++++- lib/sqlalchemy/sql/selectable.py | 5 +- test/orm/test_core_compilation.py | 11 +- test/sql/test_compiler.py | 3 + test/sql/test_cte.py | 148 +++++++++++++++++++++ test/sql/test_selectable.py | 6 +- 7 files changed, 209 insertions(+), 13 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6710.rst diff --git a/doc/build/changelog/unreleased_14/6710.rst b/doc/build/changelog/unreleased_14/6710.rst new file mode 100644 index 0000000000..32784e889c --- /dev/null +++ b/doc/build/changelog/unreleased_14/6710.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql + :tickets: 6710 + + Fixed issue in CTE constructs where a recursive CTE that referred to a + SELECT that has duplicate column names, which are typically deduplicated + using labeling logic in 1.4, would fail to refer to the deduplicated label + name correctly within the WITH clause. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 7007c2e869..360a53ac85 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1311,6 +1311,9 @@ class SQLCompiler(Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + def visit_select_statement_grouping(self, grouping, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + def visit_label_reference( self, element, within_columns_clause=False, **kwargs ): @@ -2562,17 +2565,29 @@ class SQLCompiler(Compiled): col_source = cte.element.selects[0] else: assert False, "cte should only be against SelectBase" + + # TODO: can we get at the .columns_plus_names collection + # that is already (or will be?) generated for the SELECT + # rather than calling twice? recur_cols = [ - c - for c in util.unique_list( - col_source._all_selected_columns - ) - if c is not None + # TODO: proxy_name is not technically safe, + # see test_cte-> + # test_with_recursive_no_name_currently_buggy. not + # clear what should be done with such a case + fallback_label_name or proxy_name + for ( + _, + proxy_name, + fallback_label_name, + c, + repeated, + ) in (col_source._generate_columns_plus_names(True)) + if not repeated ] text += "(%s)" % ( ", ".join( - self.preparer.format_column( + self.preparer.format_label_name( ident, anon_map=self.anon_map ) for ident in recur_cols @@ -5103,6 +5118,20 @@ class IdentifierPreparer(object): return self.quote(name) + def format_label_name( + self, + name, + anon_map=None, + ): + """Prepare a quoted column name.""" + + if anon_map is not None and isinstance( + name, elements._truncated_label + ): + name = name.apply_map(anon_map) + + return self.quote(name) + def format_column( self, column, diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 30a613089e..b6cf7f55e8 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3139,7 +3139,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ - __visit_name__ = "grouping" + __visit_name__ = "select_statement_grouping" _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] _is_select_container = True @@ -3173,6 +3173,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def self_group(self, against=None): return self + def _generate_columns_plus_names(self, anon_for_dupe_key): + return self.element._generate_columns_plus_names(anon_for_dupe_key) + def _generate_fromclause_column_proxies(self, subquery): self.element._generate_fromclause_column_proxies(subquery) diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 5f25b56e88..e730d90975 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -627,8 +627,10 @@ class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select(u_alias), - "SELECT anon_1.id FROM ((SELECT users.name, users.id FROM users " - "WHERE users.id = :id_1 UNION SELECT users.name, users.id " + "SELECT anon_1.id FROM ((SELECT users.name AS name, " + "users.id AS id FROM users " + "WHERE users.id = :id_1 UNION SELECT users.name AS name, " + "users.id AS id " "FROM users WHERE users.id = :id_2) " "UNION SELECT users.name AS name, users.id AS id " "FROM users WHERE users.id = :id_3) AS anon_1", @@ -656,8 +658,9 @@ class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select(u_alias).options(undefer(u_alias.name)), "SELECT anon_1.name, anon_1.id FROM " - "((SELECT users.name, users.id FROM users " - "WHERE users.id = :id_1 UNION SELECT users.name, users.id " + "((SELECT users.name AS name, users.id AS id FROM users " + "WHERE users.id = :id_1 UNION SELECT users.name AS name, " + "users.id AS id " "FROM users WHERE users.id = :id_2) " "UNION SELECT users.name AS name, users.id AS id " "FROM users WHERE users.id = :id_3) AS anon_1", diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index f2c1e004d8..40faab4867 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -892,6 +892,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "WITH RECURSIVE (colnames)" part. This test shows that this isn't correct when keys are present. + See also test_cte -> + test_wrecur_ovlp_lbls_plus_dupes_separate_keys_use_labels + """ m = MetaData() foo = Table( diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index e8a8a3150c..f1d27aa8f1 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,4 +1,9 @@ +from sqlalchemy import Column from sqlalchemy import delete +from sqlalchemy import Integer +from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL +from sqlalchemy import MetaData +from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import update @@ -495,6 +500,149 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): s.compile, ) + def test_with_recursive_no_name_currently_buggy(self): + s1 = select(1) + c1 = s1.cte(name="cte1", recursive=True) + + # this is nonsensical at the moment + self.assert_compile( + select(c1), + 'WITH RECURSIVE cte1("1") AS (SELECT 1) SELECT cte1.1 FROM cte1', + ) + + # however, so is subquery, which is worse as it isn't even trying + # to quote "1" as a label + self.assert_compile( + select(s1.subquery()), "SELECT anon_1.1 FROM (SELECT 1) AS anon_1" + ) + + def test_wrecur_dupe_col_names(self): + """test #6710""" + + manager = table("manager", column("id")) + employee = table("employee", column("id"), column("manager_id")) + + top_q = select(employee, manager).join_from( + employee, manager, employee.c.manager_id == manager.c.id + ) + + top_q = top_q.cte("cte", recursive=True) + + bottom_q = ( + select(employee, manager) + .join_from( + employee, manager, employee.c.manager_id == manager.c.id + ) + .join(top_q, top_q.c.id == employee.c.id) + ) + + rec_cte = select(top_q.union_all(bottom_q)) + self.assert_compile( + rec_cte, + "WITH RECURSIVE cte(id, manager_id, id_1) AS " + "(SELECT employee.id AS id, employee.manager_id AS manager_id, " + "manager.id AS id_1 FROM employee JOIN manager " + "ON employee.manager_id = manager.id UNION ALL " + "SELECT employee.id AS id, employee.manager_id AS manager_id, " + "manager.id AS id_1 FROM employee JOIN manager ON " + "employee.manager_id = manager.id " + "JOIN cte ON cte.id = employee.id) " + "SELECT cte.id, cte.manager_id, cte.id_1 FROM cte", + ) + + def test_wrecur_dupe_col_names_w_grouping(self): + """test #6710 + + by adding order_by() to the top query, the CTE will have + a compound select with the first element a SelectStatementGrouping + object, which we can test has the correct methods for the compiler + to call upon. + + """ + + manager = table("manager", column("id")) + employee = table("employee", column("id"), column("manager_id")) + + top_q = ( + select(employee, manager) + .join_from( + employee, manager, employee.c.manager_id == manager.c.id + ) + .order_by(employee.c.id) + .cte("cte", recursive=True) + ) + + bottom_q = ( + select(employee, manager) + .join_from( + employee, manager, employee.c.manager_id == manager.c.id + ) + .join(top_q, top_q.c.id == employee.c.id) + ) + + rec_cte = select(top_q.union_all(bottom_q)) + + self.assert_compile( + rec_cte, + "WITH RECURSIVE cte(id, manager_id, id_1) AS " + "((SELECT employee.id AS id, employee.manager_id AS manager_id, " + "manager.id AS id_1 FROM employee JOIN manager " + "ON employee.manager_id = manager.id ORDER BY employee.id) " + "UNION ALL " + "SELECT employee.id AS id, employee.manager_id AS manager_id, " + "manager.id AS id_1 FROM employee JOIN manager ON " + "employee.manager_id = manager.id " + "JOIN cte ON cte.id = employee.id) " + "SELECT cte.id, cte.manager_id, cte.id_1 FROM cte", + ) + + def test_wrecur_ovlp_lbls_plus_dupes_separate_keys_use_labels(self): + """test a condition related to #6710. + + also see test_compiler-> + test_overlapping_labels_plus_dupes_separate_keys_use_labels + + for a non cte form of this test. + + """ + + m = MetaData() + foo = Table( + "foo", + m, + Column("id", Integer), + Column("bar_id", Integer, key="bb"), + ) + foo_bar = Table("foo_bar", m, Column("id", Integer, key="bb")) + + stmt = select( + foo.c.id, + foo.c.bb, + foo_bar.c.bb, + foo.c.bb, + foo.c.id, + foo.c.bb, + foo_bar.c.bb, + foo_bar.c.bb, + ).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + + cte = stmt.cte(recursive=True) + + self.assert_compile( + select(cte), + "WITH RECURSIVE anon_1(foo_id, foo_bar_id, foo_bar_id_1) AS " + "(SELECT foo.id AS foo_id, foo.bar_id AS foo_bar_id, " + "foo_bar.id AS foo_bar_id_1, foo.bar_id AS foo_bar_id__1, " + "foo.id AS foo_id__1, foo.bar_id AS foo_bar_id__1, " + "foo_bar.id AS foo_bar_id__2, foo_bar.id AS foo_bar_id__2 " + "FROM foo, foo_bar) " + "SELECT anon_1.foo_id, anon_1.foo_bar_id, anon_1.foo_bar_id_1, " + "anon_1.foo_bar_id AS foo_bar_id_2, anon_1.foo_id AS foo_id_1, " + "anon_1.foo_bar_id AS foo_bar_id_3, " + "anon_1.foo_bar_id_1 AS foo_bar_id_1_1, " + "anon_1.foo_bar_id_1 AS foo_bar_id_1_2 FROM anon_1", + ) + def test_union(self): orders = table("orders", column("region"), column("amount")) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index be894d239e..cfdf4ad02e 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -998,9 +998,11 @@ class SelectableTest( self.assert_compile( stmt, "SELECT anon_1.col1, anon_1.col2, anon_1.col1_1 FROM " - "((SELECT table1.col1, table1.col2, table2.col1 AS col1_1 " + "((SELECT table1.col1 AS col1, table1.col2 AS col2, table2.col1 " + "AS col1_1 " "FROM table1, table2 LIMIT :param_1) UNION " - "(SELECT table2.col1, table2.col2, table2.col3 FROM table2 " + "(SELECT table2.col1 AS col1, table2.col2 AS col2, " + "table2.col3 AS col3 FROM table2 " "LIMIT :param_2)) AS anon_1", ) -- 2.47.2