From: Eric Masseran Date: Sat, 10 Jul 2021 16:08:11 +0000 (+0200) Subject: Respect nesting level X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fccb69cd4b1ac665b61e1c8f2e204c315575df48;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Respect nesting level --- diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 02b697480b..e7b769ca43 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2494,6 +2494,9 @@ class SQLCompiler(Compiled): ): self._init_cte_state() + if cte.nesting: + cte.nesting_level = len(self.stack) + kwargs["visiting_cte"] = cte if isinstance(cte.name, elements._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) @@ -3330,10 +3333,16 @@ class SQLCompiler(Compiled): ctes = self.ctes if nesting_only: - ctes = {cte: ctes[cte] for cte in ctes if cte.nesting} + ctes = { + cte: ctes[cte] + for cte in ctes + if cte.nesting and cte.nesting_level == len(self.stack) + } # Remove them from the visible CTEs self.ctes = { - cte: self.ctes[cte] for cte in self.ctes if not cte.nesting + cte: self.ctes[cte] + for cte in self.ctes + if not (cte.nesting and cte.nesting_level == len(self.stack)) } if ctes and not self.dialect.supports_nesting_cte: diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 6da5ded99b..46472f45fd 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1380,7 +1380,12 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "foo", ) - def test_nesting_cte_in_cte(self): + +class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): + + __dialect__ = "default_enhanced" + + def test_select_with_nesting_cte_in_cte(self): nesting_cte = select([literal(1).label("inner")]).cte( "nesting", nesting=True ) @@ -1396,6 +1401,60 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): dialect="postgresql", ) + def test_double_nesting_cte_in_cte(self): + select_1_cte = select([literal(1).label("inner")]).cte( + "nesting_1", nesting=True + ) + select_2_cte = select([literal(2).label("inner")]).cte( + "nesting_2", nesting=True + ) + + stmt = select( + [ + select( + [ + select_1_cte.c.inner.label("outer_1"), + select_2_cte.c.inner.label("outer_2"), + ] + ).cte("cte") + ] + ) + + self.assert_compile( + stmt, + "WITH cte AS (" + 'WITH nesting_1 AS (SELECT %(param_1)s AS "inner")' + ', nesting_2 AS (SELECT %(param_2)s AS "inner")' + ' SELECT nesting_1."inner" AS outer_1' + ', nesting_2."inner" AS outer_2' + " FROM nesting_1, nesting_2" + ") SELECT cte.outer_1, cte.outer_2 FROM cte", + dialect="postgresql", + ) + + def test_nesting_cte_in_nesting_cte_in_cte(self): + select_1_cte = select([literal(1).label("inner")]).cte( + "nesting_1", nesting=True + ) + select_2_cte = select([select_1_cte.c.inner.label("inner_2")]).cte( + "nesting_2", nesting=True + ) + + stmt = select( + [select([select_2_cte.c.inner_2.label("outer")]).cte("cte")] + ) + + self.assert_compile( + stmt, + "WITH cte AS (" + "WITH nesting_2 AS (" + 'WITH nesting_1 AS (SELECT %(param_1)s AS "inner")' + ' SELECT nesting_1."inner" AS inner_2 FROM nesting_1' + ') SELECT nesting_2.inner_2 AS "outer" FROM nesting_2' + ') SELECT cte."outer" FROM cte', + dialect="postgresql", + ) + def test_nesting_cte_in_recursive_cte(self): nesting_cte = select([literal(1).label("inner")]).cte( "nesting", nesting=True