From: Eric Masseran Date: Fri, 2 Jul 2021 16:24:47 +0000 (+0200) Subject: Initial implementation on select only X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6be7744a6c93fbf5f2dcf36c5a8262034669197c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Initial implementation on select only --- diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 4c654a643b..baa0efe4b2 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3091,6 +3091,7 @@ class PGDialect(default.DefaultDialect): supports_comments = True supports_default_values = True + supports_nesting_cte = True supports_default_metavalue = True diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 66a556ae0d..4d0cb56ade 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1807,6 +1807,7 @@ class SQLiteDialect(default.DefaultDialect): supports_multivalues_insert = True tuple_in_values = True supports_statement_cache = True + supports_nesting_cte = True default_paramstyle = "qmark" execution_ctx_cls = SQLiteExecutionContext diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 7b3fa091fd..8cdde558b6 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -81,6 +81,7 @@ class DefaultDialect(interfaces.Dialect): insert_executemany_returning = False cte_follows_insert = False + supports_nesting_cte = False supports_native_enum = False supports_native_boolean = False diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a89b86c793..f240b330e6 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -500,7 +500,7 @@ class Query( q = q.reduce_columns() return q.alias(name=name) - def cte(self, name=None, recursive=False): + def cte(self, name=None, recursive=False, nesting=False): r"""Return the full SELECT statement represented by this :class:`_query.Query` represented as a common table expression (CTE). @@ -556,7 +556,7 @@ class Query( """ return self.enable_eagerloads(False).statement.cte( - name=name, recursive=recursive + name=name, recursive=recursive, nesting=nesting ) def label(self, name): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 67da036830..58d01088ac 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3150,8 +3150,8 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + text = self._render_cte_clause(nesting_only=(not toplevel)) + text if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -3323,14 +3323,37 @@ class SQLCompiler(Compiled): clause += " " return clause - def _render_cte_clause(self): + def _render_cte_clause( + self, + nesting_only: bool = False, + ): + ctes = self.ctes + + if nesting_only: + ctes = {cte: ctes[cte] for cte in ctes if cte.nesting} + # Remove them from the visible CTEs + self.ctes = { + cte: self.ctes[cte] for cte in self.ctes if not cte.nesting + } + + if ctes and not self.dialect.supports_nesting_cte: + raise exc.CompileError( + "Nesting CTE is not supported by this " + "dialect's statement compiler." + ) + + ctes_recursive = any([cte.recursive for cte in ctes]) + + if not ctes: + return "" + if self.positional: self.positiontup = ( - sum([self.cte_positional[cte] for cte in self.ctes], []) + sum([self.cte_positional[cte] for cte in ctes], []) + self.positiontup ) - cte_text = self.get_cte_preamble(self.ctes_recursive) + " " - cte_text += ", \n".join([txt for txt in self.ctes.values()]) + cte_text = self.get_cte_preamble(ctes_recursive) + " " + cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " return cte_text diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 557c443bf7..740ccc0ae7 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2039,12 +2039,14 @@ class CTE( selectable, name=None, recursive=False, + nesting=False, _cte_alias=None, _restates=(), _prefixes=None, _suffixes=None, ): self.recursive = recursive + self.nesting = nesting self._cte_alias = _cte_alias self._restates = _restates if _prefixes: @@ -2077,6 +2079,7 @@ class CTE( self.element, name=name, recursive=self.recursive, + nesting=self.nesting, _cte_alias=self, _prefixes=self._prefixes, _suffixes=self._suffixes, @@ -2087,6 +2090,7 @@ class CTE( self.element.union(other), name=self.name, recursive=self.recursive, + nesting=self.nesting, _restates=self._restates + (self,), _prefixes=self._prefixes, _suffixes=self._suffixes, @@ -2097,6 +2101,7 @@ class CTE( self.element.union_all(other), name=self.name, recursive=self.recursive, + nesting=self.nesting, _restates=self._restates + (self,), _prefixes=self._prefixes, _suffixes=self._suffixes, @@ -2110,7 +2115,7 @@ class HasCTE(roles.HasCTERole): """ - def cte(self, name=None, recursive=False): + def cte(self, name=None, recursive=False, nesting=False): r"""Return a new :class:`_expression.CTE`, or Common Table Expression instance. @@ -2276,7 +2281,9 @@ class HasCTE(roles.HasCTERole): :meth:`_expression.HasCTE.cte`. """ - return CTE._construct(self, name=name, recursive=recursive) + return CTE._construct( + self, name=name, recursive=recursive, nesting=nesting + ) class Subquery(AliasedReturnsRows): diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 01186c340c..6da5ded99b 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,8 +1,10 @@ +import functools +import pytest from sqlalchemy import delete from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import update -from sqlalchemy.dialects import mssql +from sqlalchemy.dialects import mssql, mysql from sqlalchemy.engine import default from sqlalchemy.exc import CompileError from sqlalchemy.sql import and_ @@ -1377,3 +1379,98 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): a_stmt, "foo", ) + + def test_nesting_cte_in_cte(self): + nesting_cte = select([literal(1).label("inner")]).cte( + "nesting", nesting=True + ) + stmt = select( + [select([nesting_cte.c.inner.label("outer")]).cte("cte")] + ) + + self.assert_compile( + stmt, + 'WITH cte AS (WITH nesting AS (SELECT %(param_1)s AS "inner") ' + 'SELECT nesting."inner" AS "outer" FROM nesting) ' + 'SELECT cte."outer" FROM cte', + dialect="postgresql", + ) + + def test_nesting_cte_in_recursive_cte(self): + nesting_cte = select([literal(1).label("inner")]).cte( + "nesting", nesting=True + ) + stmt = select( + [ + select([nesting_cte.c.inner.label("outer")]).cte( + "cte", recursive=True + ) + ] + ) + + self.assert_compile( + stmt, + 'WITH RECURSIVE cte("outer") AS (WITH nesting AS ' + '(SELECT %(param_1)s AS "inner") ' + 'SELECT nesting."inner" AS "outer" FROM nesting) ' + 'SELECT cte."outer" FROM cte', + dialect="postgresql", + ) + + def test_recursive_nesting_cte_in_cte(self): + nesting_cte = select([literal(1).label("inner")]).cte( + "nesting", nesting=True, recursive=True + ) + stmt = select( + [select([nesting_cte.c.inner.label("outer")]).cte("cte")] + ) + + self.assert_compile( + stmt, + 'WITH cte AS (WITH RECURSIVE nesting("inner") AS ' + '(SELECT %(param_1)s AS "inner") ' + 'SELECT nesting."inner" AS "outer" FROM nesting) ' + 'SELECT cte."outer" FROM cte', + dialect="postgresql", + ) + + def test_recursive_nesting_cte_in_recursive_cte(self): + nesting_cte = select([literal(1).label("inner")]).cte( + "nesting", nesting=True, recursive=True + ) + stmt = select( + [ + select([nesting_cte.c.inner.label("outer")]).cte( + "cte", recursive=True + ) + ] + ) + + self.assert_compile( + stmt, + 'WITH RECURSIVE cte("outer") AS (WITH RECURSIVE nesting("inner") ' + 'AS (SELECT %(param_1)s AS "inner") ' + 'SELECT nesting."inner" AS "outer" FROM nesting) ' + 'SELECT cte."outer" FROM cte', + dialect="postgresql", + ) + + @pytest.mark.parametrize( + "dialect", + [mysql], + ) + def test_nesting_cte_unsupported_backend_raise(self, dialect): + stmt = select( + [ + select( + [select([literal(1).label("one")]).cte("t2", nesting=True)] + ).cte("t") + ] + ) + + assert_raises_message( + CompileError, + "Nesting CTE is not supported by this " + "dialect's statement compiler.", + functools.partial(stmt.compile, dialect=dialect.dialect()), + )