]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support compound select
authorEric Masseran <eric.masseran@gmail.com>
Sat, 10 Jul 2021 16:23:17 +0000 (18:23 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Sat, 10 Jul 2021 16:23:17 +0000 (18:23 +0200)
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

index e7b769ca43bb466d731a4c0c9fa0d887f17490ab..2a2c280279a09fc148a01def9f92ad3c9fd3c596 100644 (file)
@@ -1822,8 +1822,12 @@ class SQLCompiler(Compiled):
         if cs._has_row_limiting_clause:
             text += self._row_limit_clause(cs, **kwargs)
 
-        if self.ctes and toplevel:
-            text = self._render_cte_clause() + text
+        if self.ctes:
+            # Nesting CTEs from deeper select
+            text = (
+                self._render_cte_clause(nesting_level=len(self.stack) + 1)
+                + text
+            )
 
         self.stack.pop(-1)
         return text
@@ -3153,8 +3157,11 @@ class SQLCompiler(Compiled):
             if per_dialect:
                 text += " " + self.get_statement_hint_text(per_dialect)
 
-        if self.ctes:
-            text = self._render_cte_clause(nesting_only=(not toplevel)) + text
+        # In compound query, CTEs are shared at the compound level
+        if self.ctes and compound_index is None:
+            text = (
+                self._render_cte_clause(nesting_level=len(self.stack)) + text
+            )
 
         if select_stmt._suffixes:
             text += " " + self._generate_prefixes(
@@ -3328,21 +3335,21 @@ class SQLCompiler(Compiled):
 
     def _render_cte_clause(
         self,
-        nesting_only=False,
+        nesting_level=None,
     ):
         ctes = self.ctes
 
-        if nesting_only:
+        if nesting_level and nesting_level > 1:
             ctes = {
                 cte: ctes[cte]
                 for cte in ctes
-                if cte.nesting and cte.nesting_level == len(self.stack)
+                if cte.nesting_level == nesting_level
             }
             # Remove them from the visible CTEs
             self.ctes = {
                 cte: self.ctes[cte]
                 for cte in self.ctes
-                if not (cte.nesting and cte.nesting_level == len(self.stack))
+                if not cte.nesting_level == nesting_level
             }
 
             if ctes and not self.dialect.supports_nesting_cte:
index 46472f45fddf3f9d7294cc51329a9eaf8dbac3ca..8643fa3f0ae54ee18f6b56b703d1c211467ef673 100644 (file)
@@ -1401,6 +1401,19 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             dialect="postgresql",
         )
 
+    def test_nesting_cte_at_top_level(self):
+        nesting_cte = select([literal(1).label("val")]).cte(
+            "nesting_cte", nesting=True
+        )
+        stmt = select([nesting_cte.c.val])
+
+        self.assert_compile(
+            stmt,
+            "WITH nesting_cte AS (SELECT %(param_1)s AS val) "
+            "SELECT nesting_cte.val FROM nesting_cte",
+            dialect="postgresql",
+        )
+
     def test_double_nesting_cte_in_cte(self):
         select_1_cte = select([literal(1).label("inner")]).cte(
             "nesting_1", nesting=True
@@ -1455,6 +1468,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             dialect="postgresql",
         )
 
+    def test_compound_select_with_nesting_cte_in_cte(self):
+        select_1_cte = select([literal(1).label("inner")]).cte(
+            "nesting_1", nesting=True
+        )
+        select_2_cte = select([literal(2).label("inner")]).cte(
+            "nesting_2", nesting=True
+        )
+
+        nesting_cte = (
+            select([select_1_cte]).union(select([select_2_cte])).subquery()
+        )
+
+        stmt = select(
+            [select([nesting_cte.c.inner.label("outer")]).cte("cte")]
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS ("
+            'SELECT anon_1."inner" AS "outer" FROM ('
+            'WITH nesting_1 AS (SELECT %(param_1)s AS "inner")'
+            ', nesting_2 AS (SELECT %(param_2)s AS "inner")'
+            ' SELECT nesting_1."inner" AS "inner" FROM nesting_1'
+            " UNION"
+            ' SELECT nesting_2."inner" AS "inner" FROM nesting_2'
+            ") AS anon_1"
+            ') SELECT cte."outer" FROM cte',
+            dialect="postgresql",
+        )
+
     def test_nesting_cte_in_recursive_cte(self):
         nesting_cte = select([literal(1).label("inner")]).cte(
             "nesting", nesting=True