From bc9ee074d2549358da83535036877b8690ebb841 Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Mon, 6 Sep 2021 17:20:21 +0200 Subject: [PATCH] Support independent nesting ctes on insert and compound --- lib/sqlalchemy/sql/compiler.py | 65 ++++++++++++++++++++------- test/sql/test_cte.py | 81 ++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 16 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6dd33ec46d..ad53eaa9db 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1833,9 +1833,13 @@ class SQLCompiler(Compiled): text += self._row_limit_clause(cs, **kwargs) if self.ctes: - # Nesting CTEs from deeper select - nesting_level = (len(self.stack) + 1) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, include_following_stack=True + ) + + text + ) self.stack.pop(-1) return text @@ -2499,6 +2503,12 @@ class SQLCompiler(Compiled): else: return self.bindtemplate % {"name": name} + def get_name(self, name): + if isinstance(name, elements._truncated_label): + return self._truncated_identifier("alias", name) + else: + return name + def visit_cte( self, cte, @@ -2514,10 +2524,8 @@ class SQLCompiler(Compiled): cte_level = len(self.stack) if cte.nesting else 1 kwargs["visiting_cte"] = cte - if isinstance(cte.name, elements._truncated_label): - cte_name = self._truncated_identifier("alias", cte.name) - else: - cte_name = cte.name + + cte_name = self.get_name(cte.name) is_new_cte = True embedded_in_current_named_cte = False @@ -3125,6 +3133,8 @@ class SQLCompiler(Compiled): if toplevel and not self.compile_state: self.compile_state = compile_state + is_embedded_select = compound_index is not None or insert_into + # translate step for Oracle, SQL Server which often need to # restructure the SELECT to allow for LIMIT/OFFSET and possibly # other conditions @@ -3273,10 +3283,13 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - # In compound query, CTEs are shared at the compound level - if self.ctes and compound_index is None and not insert_into: - nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + if self.ctes: + # In compound query, CTEs are shared at the compound level + if not is_embedded_select: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause(nesting_level=nesting_level) + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -3451,7 +3464,14 @@ class SQLCompiler(Compiled): def _render_cte_clause( self, nesting_level=None, + include_following_stack=False, ): + """ + include_following_stack + Also render the nesting CTEs on the next stack. Useful for + SQL structures like UNION or INSERT that can wrap SELECT + statements containing nesting CTEs. + """ if not self.ctes: return "" @@ -3459,13 +3479,18 @@ class SQLCompiler(Compiled): ctes = {} for cte in list(self.ctes.keys()): cte_level = self.level_by_ctes[cte] - if not (cte.nesting and cte_level == nesting_level): + is_rendered_level = cte_level == nesting_level or ( + include_following_stack and cte_level == nesting_level + 1 + ) + if not (cte.nesting and is_rendered_level): continue ctes[cte] = self.ctes[cte] del self.ctes[cte] del self.level_by_ctes[cte] + cte_name = self.get_name(cte.name) + del self.ctes_by_name[(cte_level, cte_name)] else: ctes = self.ctes @@ -3732,9 +3757,12 @@ class SQLCompiler(Compiled): ) if self.ctes and self.dialect.cte_follows_insert: - nesting_level = (len(self.stack) + 1) if not toplevel else None + nesting_level = len(self.stack) if not toplevel else None text += " %s%s" % ( - self._render_cte_clause(nesting_level=nesting_level), + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ), select_text, ) else: @@ -3775,8 +3803,13 @@ class SQLCompiler(Compiled): text += " " + returning_clause if self.ctes and not self.dialect.cte_follows_insert: - nesting_level = (len(self.stack) + 1) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, include_following_stack=True + ) + + text + ) self.stack.pop(-1) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 86ba7a9e47..b760018da1 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -2056,3 +2056,84 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "table_1.price) SELECT delete_cte.id, delete_cte.price " "FROM delete_cte", ) + + def test_compound_select_with_nesting_cte_in_custom_order(self): + select_1_cte = select(literal(1).label("inner_cte")).cte( + "nesting_1", nesting=True + ) + select_2_cte = select(literal(2).label("inner_cte")).cte( + "nesting_2", nesting=True + ) + + nesting_cte = ( + select(select_1_cte) + .union(select(select_2_cte)) + # Generate "select_2_cte" first + .add_cte(select_2_cte) + .subquery() + ) + + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (" + "SELECT anon_1.inner_cte AS outer_cte FROM (" + "WITH nesting_2 AS (SELECT %(param_1)s AS inner_cte)" + ", nesting_1 AS (SELECT %(param_2)s AS inner_cte)" + " SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1" + " UNION" + " SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2" + ") AS anon_1" + ") SELECT cte.outer_cte FROM cte", + ) + + def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self): + rec_root = select(literal(1).label("the_value")).cte( + "recursive_cte", recursive=True + ) + + # Allow to reference the recursive CTE more than once + rec_root_ref = rec_root.select().cte( + "allow_multiple_ref", nesting=True + ) + should_continue = select( + exists( + select(rec_root_ref.c.the_value) + .where(rec_root_ref.c.the_value < 10) + .limit(1) + ).label("val") + ).cte("should_continue", nesting=True) + + rec_part_1 = select(rec_root_ref.c.the_value * 2).where( + should_continue.c.val != True + ) + rec_part_2 = select(rec_root_ref.c.the_value * 3).where( + should_continue.c.val != True + ) + + rec_part = rec_part_1.add_cte(rec_root_ref).union_all(rec_part_2) + + rec_cte = rec_root.union_all(rec_part) + + stmt = rec_cte.select() + + self.assert_compile( + stmt, + "WITH RECURSIVE recursive_cte(the_value) AS (" + "SELECT %(param_1)s AS the_value UNION ALL (" + "WITH allow_multiple_ref AS (" + "SELECT recursive_cte.the_value AS the_value FROM recursive_cte)" + ", should_continue AS (SELECT EXISTS (" + "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref" + " WHERE allow_multiple_ref.the_value < %(the_value_2)s" + " LIMIT %(param_2)s) AS val) " + "SELECT allow_multiple_ref.the_value * %(the_value_1)s AS anon_1" + " FROM allow_multiple_ref, should_continue WHERE should_continue.val != true" + " UNION ALL SELECT allow_multiple_ref.the_value * %(the_value_3)s" + " AS anon_2 FROM allow_multiple_ref, should_continue" + " WHERE should_continue.val != true))" + " SELECT recursive_cte.the_value FROM recursive_cte", + ) -- 2.47.3