]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Initial implementation on select only
authorEric Masseran <eric.masseran@gmail.com>
Fri, 2 Jul 2021 16:24:47 +0000 (18:24 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Fri, 2 Jul 2021 16:24:47 +0000 (18:24 +0200)
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/sql/test_cte.py

index 4c654a643b04cbde16a03f47e32781cb356dac57..baa0efe4b2f74c586db727c2d6d0ee65bb76b18d 100644 (file)
@@ -3091,6 +3091,7 @@ class PGDialect(default.DefaultDialect):
 
     supports_comments = True
     supports_default_values = True
+    supports_nesting_cte = True
 
     supports_default_metavalue = True
 
index 66a556ae0d5cb3497275159bd993fff37e599663..4d0cb56adef8e7259ff59f736c0a5ac9477833d5 100644 (file)
@@ -1807,6 +1807,7 @@ class SQLiteDialect(default.DefaultDialect):
     supports_multivalues_insert = True
     tuple_in_values = True
     supports_statement_cache = True
+    supports_nesting_cte = True
 
     default_paramstyle = "qmark"
     execution_ctx_cls = SQLiteExecutionContext
index 7b3fa091fd430dddd703869cf79a6f7a403aee3a..8cdde558b601ea05336a0cbd212656a73dd86c15 100644 (file)
@@ -81,6 +81,7 @@ class DefaultDialect(interfaces.Dialect):
     insert_executemany_returning = False
 
     cte_follows_insert = False
+    supports_nesting_cte = False
 
     supports_native_enum = False
     supports_native_boolean = False
index a89b86c7935f91255684cb385c95220d6ed88e40..f240b330e6a10fcda76aa43c2e2e5cc4351d9859 100644 (file)
@@ -500,7 +500,7 @@ class Query(
             q = q.reduce_columns()
         return q.alias(name=name)
 
-    def cte(self, name=None, recursive=False):
+    def cte(self, name=None, recursive=False, nesting=False):
         r"""Return the full SELECT statement represented by this
         :class:`_query.Query` represented as a common table expression (CTE).
 
@@ -556,7 +556,7 @@ class Query(
 
         """
         return self.enable_eagerloads(False).statement.cte(
-            name=name, recursive=recursive
+            name=name, recursive=recursive, nesting=nesting
         )
 
     def label(self, name):
index 67da036830406095462ec3efd7b195eacc26516d..58d01088aca08f55f400758847b30da1fb14903f 100644 (file)
@@ -3150,8 +3150,8 @@ class SQLCompiler(Compiled):
             if per_dialect:
                 text += " " + self.get_statement_hint_text(per_dialect)
 
-        if self.ctes and toplevel:
-            text = self._render_cte_clause() + text
+        if self.ctes:
+            text = self._render_cte_clause(nesting_only=(not toplevel)) + text
 
         if select_stmt._suffixes:
             text += " " + self._generate_prefixes(
@@ -3323,14 +3323,37 @@ class SQLCompiler(Compiled):
             clause += " "
         return clause
 
-    def _render_cte_clause(self):
+    def _render_cte_clause(
+        self,
+        nesting_only: bool = False,
+    ):
+        ctes = self.ctes
+
+        if nesting_only:
+            ctes = {cte: ctes[cte] for cte in ctes if cte.nesting}
+            # Remove them from the visible CTEs
+            self.ctes = {
+                cte: self.ctes[cte] for cte in self.ctes if not cte.nesting
+            }
+
+            if ctes and not self.dialect.supports_nesting_cte:
+                raise exc.CompileError(
+                    "Nesting CTE is not supported by this "
+                    "dialect's statement compiler."
+                )
+
+        ctes_recursive = any([cte.recursive for cte in ctes])
+
+        if not ctes:
+            return ""
+
         if self.positional:
             self.positiontup = (
-                sum([self.cte_positional[cte] for cte in self.ctes], [])
+                sum([self.cte_positional[cte] for cte in ctes], [])
                 + self.positiontup
             )
-        cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
-        cte_text += ", \n".join([txt for txt in self.ctes.values()])
+        cte_text = self.get_cte_preamble(ctes_recursive) + " "
+        cte_text += ", \n".join([txt for txt in ctes.values()])
         cte_text += "\n "
         return cte_text
 
index 557c443bf7cfdbd2fe9271df887d3ed5d1b0bf2b..740ccc0ae70d9805adc40a4552f1fdd39f48ae7b 100644 (file)
@@ -2039,12 +2039,14 @@ class CTE(
         selectable,
         name=None,
         recursive=False,
+        nesting=False,
         _cte_alias=None,
         _restates=(),
         _prefixes=None,
         _suffixes=None,
     ):
         self.recursive = recursive
+        self.nesting = nesting
         self._cte_alias = _cte_alias
         self._restates = _restates
         if _prefixes:
@@ -2077,6 +2079,7 @@ class CTE(
             self.element,
             name=name,
             recursive=self.recursive,
+            nesting=self.nesting,
             _cte_alias=self,
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
@@ -2087,6 +2090,7 @@ class CTE(
             self.element.union(other),
             name=self.name,
             recursive=self.recursive,
+            nesting=self.nesting,
             _restates=self._restates + (self,),
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
@@ -2097,6 +2101,7 @@ class CTE(
             self.element.union_all(other),
             name=self.name,
             recursive=self.recursive,
+            nesting=self.nesting,
             _restates=self._restates + (self,),
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
@@ -2110,7 +2115,7 @@ class HasCTE(roles.HasCTERole):
 
     """
 
-    def cte(self, name=None, recursive=False):
+    def cte(self, name=None, recursive=False, nesting=False):
         r"""Return a new :class:`_expression.CTE`,
         or Common Table Expression instance.
 
@@ -2276,7 +2281,9 @@ class HasCTE(roles.HasCTERole):
             :meth:`_expression.HasCTE.cte`.
 
         """
-        return CTE._construct(self, name=name, recursive=recursive)
+        return CTE._construct(
+            self, name=name, recursive=recursive, nesting=nesting
+        )
 
 
 class Subquery(AliasedReturnsRows):
index 01186c340cc297529389ab12ae724e8fa94bc99d..6da5ded99b8bf14eaa87d0ea242dd13350e29d5e 100644 (file)
@@ -1,8 +1,10 @@
+import functools
+import pytest
 from sqlalchemy import delete
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import update
-from sqlalchemy.dialects import mssql
+from sqlalchemy.dialects import mssql, mysql
 from sqlalchemy.engine import default
 from sqlalchemy.exc import CompileError
 from sqlalchemy.sql import and_
@@ -1377,3 +1379,98 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             a_stmt,
             "foo",
         )
+
+    def test_nesting_cte_in_cte(self):
+        nesting_cte = select([literal(1).label("inner")]).cte(
+            "nesting", nesting=True
+        )
+        stmt = select(
+            [select([nesting_cte.c.inner.label("outer")]).cte("cte")]
+        )
+
+        self.assert_compile(
+            stmt,
+            'WITH cte AS (WITH nesting AS (SELECT %(param_1)s AS "inner") '
+            'SELECT nesting."inner" AS "outer" FROM nesting) '
+            'SELECT cte."outer" FROM cte',
+            dialect="postgresql",
+        )
+
+    def test_nesting_cte_in_recursive_cte(self):
+        nesting_cte = select([literal(1).label("inner")]).cte(
+            "nesting", nesting=True
+        )
+        stmt = select(
+            [
+                select([nesting_cte.c.inner.label("outer")]).cte(
+                    "cte", recursive=True
+                )
+            ]
+        )
+
+        self.assert_compile(
+            stmt,
+            'WITH RECURSIVE cte("outer") AS (WITH nesting AS '
+            '(SELECT %(param_1)s AS "inner") '
+            'SELECT nesting."inner" AS "outer" FROM nesting) '
+            'SELECT cte."outer" FROM cte',
+            dialect="postgresql",
+        )
+
+    def test_recursive_nesting_cte_in_cte(self):
+        nesting_cte = select([literal(1).label("inner")]).cte(
+            "nesting", nesting=True, recursive=True
+        )
+        stmt = select(
+            [select([nesting_cte.c.inner.label("outer")]).cte("cte")]
+        )
+
+        self.assert_compile(
+            stmt,
+            'WITH cte AS (WITH RECURSIVE nesting("inner") AS '
+            '(SELECT %(param_1)s AS "inner") '
+            'SELECT nesting."inner" AS "outer" FROM nesting) '
+            'SELECT cte."outer" FROM cte',
+            dialect="postgresql",
+        )
+
+    def test_recursive_nesting_cte_in_recursive_cte(self):
+        nesting_cte = select([literal(1).label("inner")]).cte(
+            "nesting", nesting=True, recursive=True
+        )
+        stmt = select(
+            [
+                select([nesting_cte.c.inner.label("outer")]).cte(
+                    "cte", recursive=True
+                )
+            ]
+        )
+
+        self.assert_compile(
+            stmt,
+            'WITH RECURSIVE cte("outer") AS (WITH RECURSIVE nesting("inner") '
+            'AS (SELECT %(param_1)s AS "inner") '
+            'SELECT nesting."inner" AS "outer" FROM nesting) '
+            'SELECT cte."outer" FROM cte',
+            dialect="postgresql",
+        )
+
+    @pytest.mark.parametrize(
+        "dialect",
+        [mysql],
+    )
+    def test_nesting_cte_unsupported_backend_raise(self, dialect):
+        stmt = select(
+            [
+                select(
+                    [select([literal(1).label("one")]).cte("t2", nesting=True)]
+                ).cte("t")
+            ]
+        )
+
+        assert_raises_message(
+            CompileError,
+            "Nesting CTE is not supported by this "
+            "dialect's statement compiler.",
+            functools.partial(stmt.compile, dialect=dialect.dialect()),
+        )