From 70a5f3fe895a38a16e2a3387400ca2ee4cbc76a4 Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Fri, 30 Jul 2021 19:48:48 +0200 Subject: [PATCH] Keep already defined cte --- lib/sqlalchemy/sql/compiler.py | 3 +++ test/sql/test_cte.py | 43 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c191a5049d..6a55eea5ff 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2510,6 +2510,9 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False + if cte in self.level_by_ctes: + cte_level = self.level_by_ctes[cte] + cte_level_name = (cte_level, cte_name) if cte_level_name in self.ctes_by_name: existing_cte = self.ctes_by_name[cte_level_name] diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 45cbea3a88..f74efa13a1 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1888,6 +1888,49 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.outer_cte FROM cte", ) + def test_same_nested_cte_is_not_generated_twice(self): + # Same = name and query + nesting_cte_used_twice = select([literal(1).label("inner_cte_1")]).cte( + "nesting_cte", nesting=True + ) + select_add_cte = select( + [(nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")] + ).cte("nesting_2", nesting=True) + + union_cte = ( + select( + [ + (nesting_cte_used_twice.c.inner_cte_1 - 1).label( + "next_value" + ) + ] + ) + .union(select([select_add_cte])) + .cte("wrapper", nesting=True) + ) + + stmt = ( + select([union_cte]) + .add_cte(nesting_cte_used_twice) + .union(select([nesting_cte_used_twice])) + ) + + self.assert_compile( + stmt, + "WITH nesting_cte AS " + "(SELECT %(param_1)s AS inner_cte_1)" + ", wrapper AS " + "(WITH nesting_2 AS " + "(SELECT nesting_cte.inner_cte_1 + %(inner_cte_1_2)s AS next_value " + "FROM nesting_cte)" + " SELECT nesting_cte.inner_cte_1 - %(inner_cte_1_1)s AS next_value " + "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value " + "FROM nesting_2)" + " SELECT wrapper.next_value " + "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " + "FROM nesting_cte", + ) + def test_recursive_nesting_cte_in_recursive_cte(self): nesting_cte = select([literal(1).label("inner_cte")]).cte( "nesting", nesting=True, recursive=True -- 2.47.3