From ae013d1d1f4fa6d26829198e4e54957f79227068 Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Fri, 1 Oct 2021 19:36:39 +0200 Subject: [PATCH] fixed --- lib/sqlalchemy/sql/compiler.py | 54 +++++++++++++++++++--------------- test/sql/test_cte.py | 40 ++++++++++++------------- 2 files changed, 51 insertions(+), 43 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 89c6ddca5c..63ae3a7bcd 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -841,11 +841,19 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT # Dict[cte_id, text_query] + # To remember the query to print self.ctes = util.OrderedDict() # Detect same CTE references - # Dict[cte_id, cte_instance] + # Dict[(level, name), cte_instance] self.ctes_by_name = {} + # Dict[cte_id, cte_instance] + # Useful + self.ctes_by_id = {} + # Dict[cte_id, cte_instance] + # To retrieve ctes_by_name key + self.names_by_id = {} # Dict[cte_id, level] + # Remember level for nesting usage self.level_by_ctes = {} self.ctes_recursive = False if self.positional: @@ -2530,10 +2538,10 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False - # if cte.unique_id in self.level_by_ctes: - # cte_level = self.level_by_ctes[cte.unique_id] + if cte.unique_id in self.level_by_ctes: + cte_level = self.level_by_ctes[cte.unique_id] - cte_level_name = cte.unique_id + cte_level_name = (cte_level, cte_name) if cte_level_name in self.ctes_by_name: existing_cte = self.ctes_by_name[cte_level_name] embedded_in_current_named_cte = visiting_cte is existing_cte @@ -2550,7 +2558,7 @@ class SQLCompiler(Compiled): # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". - del self.ctes[existing_cte.unique_id] + del self.ctes[existing_cte] del self.level_by_ctes[existing_cte.unique_id] else: raise exc.CompileError( @@ -2572,8 +2580,8 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_name[cte_level_name] = cte - # TODO: - # self.level_by_ctes[cte.unique_id] = cte_level + self.names_by_id[cte.unique_id] = cte_name + self.level_by_ctes[cte.unique_id] = cte_level if ( "autocommit" in cte.element._execution_options @@ -2631,9 +2639,7 @@ class SQLCompiler(Compiled): ) if self.positional: - kwargs["positional_names"] = self.cte_positional[ - cte.unique_id - ] = [] + kwargs["positional_names"] = self.cte_positional[cte] = [] assert kwargs.get("subquery", False) is False @@ -2659,7 +2665,7 @@ class SQLCompiler(Compiled): cte, cte._suffixes, **kwargs ) - self.ctes[cte.unique_id] = text + self.ctes[cte] = text self.level_by_ctes[cte.unique_id] = cte_level if asfrom: @@ -3485,16 +3491,16 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() - for cte_id in list(self.ctes.keys()): - cte = self.ctes_by_name[cte_id] - cte_level = self.level_by_ctes[cte_id] + for cte in list(self.ctes.keys()): + cte_level = self.level_by_ctes[cte.unique_id] + cte_name = self.names_by_id[cte.unique_id] is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) if not (cte.nesting and is_rendered_level): continue - ctes[cte_id] = self.ctes[cte_id] + ctes[cte] = self.ctes[cte] else: ctes = self.ctes @@ -3502,23 +3508,25 @@ class SQLCompiler(Compiled): if not ctes: return "" - ctes_recursive = any( - [self.ctes_by_name[cte_id].recursive for cte_id in ctes] - ) + ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: self.positiontup = ( - sum([self.cte_positional[cte_id] for cte_id in ctes], []) + sum([self.cte_positional[cte] for cte in ctes], []) + self.positiontup ) cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " - for cte_id in list(ctes.keys()): - del self.ctes[cte_id] - del self.level_by_ctes[cte_id] - del self.ctes_by_name[cte_id] + if nesting_level and nesting_level > 1: + for cte in list(ctes.keys()): + cte_level = self.level_by_ctes[cte.unique_id] + cte_name = self.names_by_id[cte.unique_id] + del self.ctes[cte] + del self.level_by_ctes[cte.unique_id] + del self.names_by_id[cte.unique_id] + del self.ctes_by_name[(cte_level, cte_name)] return cte_text diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 57878c0a4d..2f847279ac 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -2166,26 +2166,26 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): " SELECT recursive_cte.the_value FROM recursive_cte", ) - # def test_recursive_cte_w_union_aliased(self): - # nesting_cte = select(literal(1).label("inner_cte")).cte( - # "nesting", recursive=True, nesting=True - # ) - # nesting_cte_a = nesting_cte.alias() - # nesting_cte = nesting_cte.union( - # select(nesting_cte_a.c.inner_cte).where( - # nesting_cte_a.c.inner_cte == literal(1) - # ) - # ) - - # stmt = select(nesting_cte.c.inner_cte) - # self.assert_compile( - # stmt, - # "WITH RECURSIVE nesting(inner_cte) AS " - # "(SELECT :param_1 AS inner_cte UNION " - # "SELECT anon_1.inner_cte AS inner_cte FROM nesting AS anon_1 " - # "WHERE anon_1.inner_cte = :param_2) " - # "SELECT nesting.inner_cte FROM nesting", - # ) + def test_recursive_cte_w_union_aliased(self): + nesting_cte = select(literal(1).label("inner_cte")).cte( + "nesting", recursive=True, nesting=True + ) + nesting_cte_a = nesting_cte.alias() + nesting_cte = nesting_cte.union( + select(nesting_cte_a.c.inner_cte).where( + nesting_cte_a.c.inner_cte == literal(1) + ) + ) + + stmt = select(nesting_cte.c.inner_cte) + self.assert_compile( + stmt, + "WITH RECURSIVE nesting(inner_cte) AS " + "(SELECT :param_1 AS inner_cte UNION " + "SELECT anon_1.inner_cte AS inner_cte FROM nesting AS anon_1 " + "WHERE anon_1.inner_cte = :param_2) " + "SELECT nesting.inner_cte FROM nesting", + ) def test_recursive_cte_w_union(self): nesting_cte = select(literal(1).label("inner_cte")).cte( -- 2.47.3