From: Eric Masseran Date: Sat, 10 Jul 2021 16:23:17 +0000 (+0200) Subject: Support compound select X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6467a6531ceb51b06ef15071f47a4b4db98606a1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support compound select --- diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index e7b769ca43..2a2c280279 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1822,8 +1822,12 @@ class SQLCompiler(Compiled): if cs._has_row_limiting_clause: text += self._row_limit_clause(cs, **kwargs) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + # Nesting CTEs from deeper select + text = ( + self._render_cte_clause(nesting_level=len(self.stack) + 1) + + text + ) self.stack.pop(-1) return text @@ -3153,8 +3157,11 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes: - text = self._render_cte_clause(nesting_only=(not toplevel)) + text + # In compound query, CTEs are shared at the compound level + if self.ctes and compound_index is None: + text = ( + self._render_cte_clause(nesting_level=len(self.stack)) + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -3328,21 +3335,21 @@ class SQLCompiler(Compiled): def _render_cte_clause( self, - nesting_only=False, + nesting_level=None, ): ctes = self.ctes - if nesting_only: + if nesting_level and nesting_level > 1: ctes = { cte: ctes[cte] for cte in ctes - if cte.nesting and cte.nesting_level == len(self.stack) + if cte.nesting_level == nesting_level } # Remove them from the visible CTEs self.ctes = { cte: self.ctes[cte] for cte in self.ctes - if not (cte.nesting and cte.nesting_level == len(self.stack)) + if not cte.nesting_level == nesting_level } if ctes and not self.dialect.supports_nesting_cte: diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 46472f45fd..8643fa3f0a 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1401,6 +1401,19 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): dialect="postgresql", ) + def test_nesting_cte_at_top_level(self): + nesting_cte = select([literal(1).label("val")]).cte( + "nesting_cte", nesting=True + ) + stmt = select([nesting_cte.c.val]) + + self.assert_compile( + stmt, + "WITH nesting_cte AS (SELECT %(param_1)s AS val) " + "SELECT nesting_cte.val FROM nesting_cte", + dialect="postgresql", + ) + def test_double_nesting_cte_in_cte(self): select_1_cte = select([literal(1).label("inner")]).cte( "nesting_1", nesting=True @@ -1455,6 +1468,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): dialect="postgresql", ) + def test_compound_select_with_nesting_cte_in_cte(self): + select_1_cte = select([literal(1).label("inner")]).cte( + "nesting_1", nesting=True + ) + select_2_cte = select([literal(2).label("inner")]).cte( + "nesting_2", nesting=True + ) + + nesting_cte = ( + select([select_1_cte]).union(select([select_2_cte])).subquery() + ) + + stmt = select( + [select([nesting_cte.c.inner.label("outer")]).cte("cte")] + ) + + self.assert_compile( + stmt, + "WITH cte AS (" + 'SELECT anon_1."inner" AS "outer" FROM (' + 'WITH nesting_1 AS (SELECT %(param_1)s AS "inner")' + ', nesting_2 AS (SELECT %(param_2)s AS "inner")' + ' SELECT nesting_1."inner" AS "inner" FROM nesting_1' + " UNION" + ' SELECT nesting_2."inner" AS "inner" FROM nesting_2' + ") AS anon_1" + ') SELECT cte."outer" FROM cte', + dialect="postgresql", + ) + def test_nesting_cte_in_recursive_cte(self): nesting_cte = select([literal(1).label("inner")]).cte( "nesting", nesting=True