From 489e8cd05d269554c8205e8bdce6147acfa83015 Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Sat, 10 Jul 2021 23:13:10 +0200 Subject: [PATCH] Allow nesting cte to override name --- lib/sqlalchemy/sql/compiler.py | 11 +++++++---- test/sql/test_cte.py | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2a2c280279..847bc12c56 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -842,6 +842,7 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT self.ctes = util.OrderedDict() + # Detect same CTE references self.ctes_by_name = {} self.ctes_recursive = False if self.positional: @@ -2498,8 +2499,9 @@ class SQLCompiler(Compiled): ): self._init_cte_state() + cte_level = len(self.stack) if cte.nesting: - cte.nesting_level = len(self.stack) + cte.nesting_level = cte_level kwargs["visiting_cte"] = cte if isinstance(cte.name, elements._truncated_label): @@ -2510,8 +2512,9 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False - if cte_name in self.ctes_by_name: - existing_cte = self.ctes_by_name[cte_name] + cte_level_name = (cte_level, cte_name) + if cte_level_name in self.ctes_by_name: + existing_cte = self.ctes_by_name[cte_level_name] embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, @@ -2542,7 +2545,7 @@ class SQLCompiler(Compiled): cte_pre_alias_name = None if is_new_cte: - self.ctes_by_name[cte_name] = cte + self.ctes_by_name[cte_level_name] = cte if ( "autocommit" in cte.element._execution_options diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index a2e5a2726e..9a739cea54 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1401,6 +1401,22 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): dialect="postgresql", ) + def test_nesting_cte_in_cte_with_same_name(self): + nesting_cte = select([literal(1).label("inner")]).cte( + "some_cte", nesting=True + ) + stmt = select( + [select([nesting_cte.c.inner.label("outer")]).cte("some_cte")] + ) + + self.assert_compile( + stmt, + 'WITH some_cte AS (WITH some_cte AS (SELECT %(param_1)s AS "inner") ' + 'SELECT some_cte."inner" AS "outer" FROM some_cte) ' + 'SELECT some_cte."outer" FROM some_cte', + dialect="postgresql", + ) + def test_nesting_cte_at_top_level(self): nesting_cte = select([literal(1).label("val")]).cte( "nesting_cte", nesting=True -- 2.47.3