From 9253da2869c171f1c99d743bc49a5dfbf6953ea2 Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Mon, 4 Oct 2021 14:45:22 +0200 Subject: [PATCH] Use _restates attribute and the CTE itself as link id --- lib/sqlalchemy/sql/compiler.py | 40 ++++++++++++++++++-------------- lib/sqlalchemy/sql/selectable.py | 16 +++++-------- test/sql/test_cte.py | 6 ++--- 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3d35293902..be25132840 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -839,8 +839,10 @@ class SQLCompiler(Compiled): these collections otherwise. """ + # Use as a unique id to identify a CTE part of a compilation + self.next_id_counter = 0 # collect CTEs to tack on top of a SELECT - # Dict[cte_id, text_query] + # Dict[cte, text_query] # To remember the query to print self.ctes = util.OrderedDict() # Detect same CTE references @@ -2535,8 +2537,8 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False - if cte.unique_id in self.level_by_ctes: - cte_level = self.level_by_ctes[cte.unique_id] + if cte._get_unique_id() in self.level_by_ctes: + cte_level = self.level_by_ctes[cte._get_unique_id()] cte_level_name = (cte_level, cte_name) if cte_level_name in self.ctes_by_level_name: @@ -2545,14 +2547,14 @@ class SQLCompiler(Compiled): # we've generated a same-named CTE that we are enclosed in, # or this is the same CTE. just return the name. - if cte in existing_cte._restates or cte is existing_cte: + if cte is existing_cte._restates or cte is existing_cte: is_new_cte = False - elif existing_cte in cte._restates: + elif existing_cte is cte._restates: # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". del self.ctes[existing_cte] - del self.level_by_ctes[existing_cte.unique_id] + del self.level_by_ctes[existing_cte._get_unique_id()] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " @@ -2573,8 +2575,8 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_level_name[cte_level_name] = cte - self.cte_names_by_id[cte.unique_id] = cte_name - self.level_by_ctes[cte.unique_id] = cte_level + self.cte_names_by_id[cte._get_unique_id()] = cte_name + self.level_by_ctes[cte._get_unique_id()] = cte_level if ( "autocommit" in cte.element._execution_options @@ -2588,10 +2590,13 @@ class SQLCompiler(Compiled): } ) - if pre_alias_cte.unique_id not in self.ctes: + if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) - if not cte_pre_alias_name and cte.unique_id not in self.ctes: + if ( + not cte_pre_alias_name + and cte not in self.ctes + ): if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) @@ -2659,7 +2664,7 @@ class SQLCompiler(Compiled): ) self.ctes[cte] = text - self.level_by_ctes[cte.unique_id] = cte_level + self.level_by_ctes[cte._get_unique_id()] = cte_level if asfrom: if from_linter: @@ -3485,8 +3490,9 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): - cte_level = self.level_by_ctes[cte.unique_id] - cte_name = self.cte_names_by_id[cte.unique_id] + cte._get_unique_id() + cte_level = self.level_by_ctes[cte._get_unique_id()] + cte_name = self.cte_names_by_id[cte._get_unique_id()] is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) @@ -3514,11 +3520,11 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: for cte in list(ctes.keys()): - cte_level = self.level_by_ctes[cte.unique_id] - cte_name = self.cte_names_by_id[cte.unique_id] + cte_level = self.level_by_ctes[cte._get_unique_id()] + cte_name = self.cte_names_by_id[cte._get_unique_id()] del self.ctes[cte] - del self.level_by_ctes[cte.unique_id] - del self.cte_names_by_id[cte.unique_id] + del self.level_by_ctes[cte._get_unique_id()] + del self.cte_names_by_id[cte._get_unique_id()] del self.ctes_by_level_name[(cte_level, cte_name)] return cte_text diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index bc4c29d408..7de3e95be1 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2074,9 +2074,8 @@ class CTE( name=None, recursive=False, nesting=False, - _unique_id=None, _cte_alias=None, - _restates=(), + _restates=None, _prefixes=None, _suffixes=None, ): @@ -2084,9 +2083,6 @@ class CTE( self.nesting = nesting self._cte_alias = _cte_alias self._restates = _restates - import uuid - - self.unique_id = _unique_id if _unique_id else uuid.uuid4() if _prefixes: self._prefixes = _prefixes if _suffixes: @@ -2118,7 +2114,6 @@ class CTE( name=name, recursive=self.recursive, nesting=self.nesting, - # _unique_id is not need as _cte_alias is doing the link _cte_alias=self, _prefixes=self._prefixes, _suffixes=self._suffixes, @@ -2130,8 +2125,7 @@ class CTE( name=self.name, recursive=self.recursive, nesting=self.nesting, - _unique_id=self.unique_id, - _restates=self._restates + (self,), + _restates=self, _prefixes=self._prefixes, _suffixes=self._suffixes, ) @@ -2142,12 +2136,14 @@ class CTE( name=self.name, recursive=self.recursive, nesting=self.nesting, - _unique_id=self.unique_id, - _restates=self._restates + (self,), + _restates=self, _prefixes=self._prefixes, _suffixes=self._suffixes, ) + def _get_unique_id(self): + return self._restates if self._restates is not None else self + class HasCTE(roles.HasCTERole): """Mixin that declares a class to include CTE support. diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index f6070239e4..4d6f9ffa87 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -444,13 +444,13 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): cte = s1.cte(name="cte", recursive=True) bar = select(cte).cte("bar").alias("cs1") - cte = cte.union_all(select(cte.c.x + 1).where(cte.c.x < 10)).alias( + cte_rec = cte.union_all(select(cte.c.x + 1).where(cte.c.x < 10)).alias( "cs2" ) # outer cte rendered first, then bar, which # includes "inner" cte - s2 = select(cte, bar) + s2 = select(cte_rec, bar) self.assert_compile( s2, "WITH RECURSIVE cte(x) AS " @@ -474,7 +474,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): # bar rendered, but then the "outer" # cte is rendered. - s2 = select(bar, cte) + s2 = select(bar, cte_rec) self.assert_compile( s2, "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), " -- 2.47.3