]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support independent nesting ctes on insert and compound
authorEric Masseran <eric.masseran@gmail.com>
Mon, 6 Sep 2021 15:20:21 +0000 (17:20 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Mon, 6 Sep 2021 15:20:21 +0000 (17:20 +0200)
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

index 6dd33ec46d56462d0a4d5d281633314486808595..ad53eaa9db0cda787ea8623f7517df6341ab8bcb 100644 (file)
@@ -1833,9 +1833,13 @@ class SQLCompiler(Compiled):
             text += self._row_limit_clause(cs, **kwargs)
 
         if self.ctes:
-            # Nesting CTEs from deeper select
-            nesting_level = (len(self.stack) + 1) if not toplevel else None
-            text = self._render_cte_clause(nesting_level=nesting_level) + text
+            nesting_level = len(self.stack) if not toplevel else None
+            text = (
+                self._render_cte_clause(
+                    nesting_level=nesting_level, include_following_stack=True
+                )
+                + text
+            )
 
         self.stack.pop(-1)
         return text
@@ -2499,6 +2503,12 @@ class SQLCompiler(Compiled):
         else:
             return self.bindtemplate % {"name": name}
 
+    def get_name(self, name):
+        if isinstance(name, elements._truncated_label):
+            return self._truncated_identifier("alias", name)
+        else:
+            return name
+
     def visit_cte(
         self,
         cte,
@@ -2514,10 +2524,8 @@ class SQLCompiler(Compiled):
         cte_level = len(self.stack) if cte.nesting else 1
 
         kwargs["visiting_cte"] = cte
-        if isinstance(cte.name, elements._truncated_label):
-            cte_name = self._truncated_identifier("alias", cte.name)
-        else:
-            cte_name = cte.name
+
+        cte_name = self.get_name(cte.name)
 
         is_new_cte = True
         embedded_in_current_named_cte = False
@@ -3125,6 +3133,8 @@ class SQLCompiler(Compiled):
         if toplevel and not self.compile_state:
             self.compile_state = compile_state
 
+        is_embedded_select = compound_index is not None or insert_into
+
         # translate step for Oracle, SQL Server which often need to
         # restructure the SELECT to allow for LIMIT/OFFSET and possibly
         # other conditions
@@ -3273,10 +3283,13 @@ class SQLCompiler(Compiled):
             if per_dialect:
                 text += " " + self.get_statement_hint_text(per_dialect)
 
-        # In compound query, CTEs are shared at the compound level
-        if self.ctes and compound_index is None and not insert_into:
-            nesting_level = len(self.stack) if not toplevel else None
-            text = self._render_cte_clause(nesting_level=nesting_level) + text
+        if self.ctes:
+            # In compound query, CTEs are shared at the compound level
+            if not is_embedded_select:
+                nesting_level = len(self.stack) if not toplevel else None
+                text = (
+                    self._render_cte_clause(nesting_level=nesting_level) + text
+                )
 
         if select_stmt._suffixes:
             text += " " + self._generate_prefixes(
@@ -3451,7 +3464,14 @@ class SQLCompiler(Compiled):
     def _render_cte_clause(
         self,
         nesting_level=None,
+        include_following_stack=False,
     ):
+        """
+        include_following_stack
+            Also render the nesting CTEs on the next stack. Useful for
+            SQL structures like UNION or INSERT that can wrap SELECT
+            statements containing nesting CTEs.
+        """
         if not self.ctes:
             return ""
 
@@ -3459,13 +3479,18 @@ class SQLCompiler(Compiled):
             ctes = {}
             for cte in list(self.ctes.keys()):
                 cte_level = self.level_by_ctes[cte]
-                if not (cte.nesting and cte_level == nesting_level):
+                is_rendered_level = cte_level == nesting_level or (
+                    include_following_stack and cte_level == nesting_level + 1
+                )
+                if not (cte.nesting and is_rendered_level):
                     continue
 
                 ctes[cte] = self.ctes[cte]
 
                 del self.ctes[cte]
                 del self.level_by_ctes[cte]
+                cte_name = self.get_name(cte.name)
+                del self.ctes_by_name[(cte_level, cte_name)]
         else:
             ctes = self.ctes
 
@@ -3732,9 +3757,12 @@ class SQLCompiler(Compiled):
             )
 
             if self.ctes and self.dialect.cte_follows_insert:
-                nesting_level = (len(self.stack) + 1) if not toplevel else None
+                nesting_level = len(self.stack) if not toplevel else None
                 text += " %s%s" % (
-                    self._render_cte_clause(nesting_level=nesting_level),
+                    self._render_cte_clause(
+                        nesting_level=nesting_level,
+                        include_following_stack=True,
+                    ),
                     select_text,
                 )
             else:
@@ -3775,8 +3803,13 @@ class SQLCompiler(Compiled):
             text += " " + returning_clause
 
         if self.ctes and not self.dialect.cte_follows_insert:
-            nesting_level = (len(self.stack) + 1) if not toplevel else None
-            text = self._render_cte_clause(nesting_level=nesting_level) + text
+            nesting_level = len(self.stack) if not toplevel else None
+            text = (
+                self._render_cte_clause(
+                    nesting_level=nesting_level, include_following_stack=True
+                )
+                + text
+            )
 
         self.stack.pop(-1)
 
index 86ba7a9e47dff3da3912db465642faa5c809fa36..b760018da1e18ce8a1b202abc5ad8601a2ff08f4 100644 (file)
@@ -2056,3 +2056,84 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "table_1.price) SELECT delete_cte.id, delete_cte.price "
             "FROM delete_cte",
         )
+
+    def test_compound_select_with_nesting_cte_in_custom_order(self):
+        select_1_cte = select(literal(1).label("inner_cte")).cte(
+            "nesting_1", nesting=True
+        )
+        select_2_cte = select(literal(2).label("inner_cte")).cte(
+            "nesting_2", nesting=True
+        )
+
+        nesting_cte = (
+            select(select_1_cte)
+            .union(select(select_2_cte))
+            # Generate "select_2_cte" first
+            .add_cte(select_2_cte)
+            .subquery()
+        )
+
+        stmt = select(
+            select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS ("
+            "SELECT anon_1.inner_cte AS outer_cte FROM ("
+            "WITH nesting_2 AS (SELECT %(param_1)s AS inner_cte)"
+            ", nesting_1 AS (SELECT %(param_2)s AS inner_cte)"
+            " SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1"
+            " UNION"
+            " SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2"
+            ") AS anon_1"
+            ") SELECT cte.outer_cte FROM cte",
+        )
+
+    def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self):
+        rec_root = select(literal(1).label("the_value")).cte(
+            "recursive_cte", recursive=True
+        )
+
+        # Allow to reference the recursive CTE more than once
+        rec_root_ref = rec_root.select().cte(
+            "allow_multiple_ref", nesting=True
+        )
+        should_continue = select(
+            exists(
+                select(rec_root_ref.c.the_value)
+                .where(rec_root_ref.c.the_value < 10)
+                .limit(1)
+            ).label("val")
+        ).cte("should_continue", nesting=True)
+
+        rec_part_1 = select(rec_root_ref.c.the_value * 2).where(
+            should_continue.c.val != True
+        )
+        rec_part_2 = select(rec_root_ref.c.the_value * 3).where(
+            should_continue.c.val != True
+        )
+
+        rec_part = rec_part_1.add_cte(rec_root_ref).union_all(rec_part_2)
+
+        rec_cte = rec_root.union_all(rec_part)
+
+        stmt = rec_cte.select()
+
+        self.assert_compile(
+            stmt,
+            "WITH RECURSIVE recursive_cte(the_value) AS ("
+            "SELECT %(param_1)s AS the_value UNION ALL ("
+            "WITH allow_multiple_ref AS ("
+            "SELECT recursive_cte.the_value AS the_value FROM recursive_cte)"
+            ", should_continue AS (SELECT EXISTS ("
+            "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref"
+            " WHERE allow_multiple_ref.the_value < %(the_value_2)s"
+            "  LIMIT %(param_2)s) AS val) "
+            "SELECT allow_multiple_ref.the_value * %(the_value_1)s AS anon_1"
+            " FROM allow_multiple_ref, should_continue WHERE should_continue.val != true"
+            " UNION ALL SELECT allow_multiple_ref.the_value * %(the_value_3)s"
+            " AS anon_2 FROM allow_multiple_ref, should_continue"
+            " WHERE should_continue.val != true))"
+            " SELECT recursive_cte.the_value FROM recursive_cte",
+        )