From 5c18d3f493ea2be26790cc5bb6126c593759271b Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Mon, 4 Oct 2021 15:10:26 +0200 Subject: [PATCH] Documentation improvement --- lib/sqlalchemy/sql/compiler.py | 46 +++++++++++++------------------- lib/sqlalchemy/sql/selectable.py | 8 +++++- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index be25132840..411b5c00e2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -839,20 +839,15 @@ class SQLCompiler(Compiled): these collections otherwise. """ - # Use as a unique id to identify a CTE part of a compilation - self.next_id_counter = 0 # collect CTEs to tack on top of a SELECT - # Dict[cte, text_query] - # To remember the query to print + # To store the query to print - Dict[cte, text_query] self.ctes = util.OrderedDict() - # Detect same CTE references - # Dict[(level, name), cte_instance] + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting self.ctes_by_level_name = {} - # Dict[cte_id, cte_name] - # To retrieve key in self.ctes_by_level_name + # To retrieve key in ctes_by_level_name - Dict[cte_reference, cte_name] self.cte_names_by_id = {} - # Dict[cte_id, level] - # Remember level for nesting usage + # Store CTE level for nesting usage - Dict[cte_reference, level] self.level_by_ctes = {} self.ctes_recursive = False if self.positional: @@ -2537,8 +2532,8 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False - if cte._get_unique_id() in self.level_by_ctes: - cte_level = self.level_by_ctes[cte._get_unique_id()] + if cte._get_reference_cte() in self.level_by_ctes: + cte_level = self.level_by_ctes[cte._get_reference_cte()] cte_level_name = (cte_level, cte_name) if cte_level_name in self.ctes_by_level_name: @@ -2554,7 +2549,7 @@ class SQLCompiler(Compiled): # enclosed in us - we take precedence, so # discard the text for the "inner". del self.ctes[existing_cte] - del self.level_by_ctes[existing_cte._get_unique_id()] + del self.level_by_ctes[existing_cte._get_reference_cte()] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " @@ -2575,8 +2570,8 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_level_name[cte_level_name] = cte - self.cte_names_by_id[cte._get_unique_id()] = cte_name - self.level_by_ctes[cte._get_unique_id()] = cte_level + self.cte_names_by_id[cte._get_reference_cte()] = cte_name + self.level_by_ctes[cte._get_reference_cte()] = cte_level if ( "autocommit" in cte.element._execution_options @@ -2593,10 +2588,7 @@ class SQLCompiler(Compiled): if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) - if ( - not cte_pre_alias_name - and cte not in self.ctes - ): + if not cte_pre_alias_name and cte not in self.ctes: if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) @@ -2664,7 +2656,7 @@ class SQLCompiler(Compiled): ) self.ctes[cte] = text - self.level_by_ctes[cte._get_unique_id()] = cte_level + self.level_by_ctes[cte._get_reference_cte()] = cte_level if asfrom: if from_linter: @@ -3490,9 +3482,9 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): - cte._get_unique_id() - cte_level = self.level_by_ctes[cte._get_unique_id()] - cte_name = self.cte_names_by_id[cte._get_unique_id()] + cte._get_reference_cte() + cte_level = self.level_by_ctes[cte._get_reference_cte()] + cte_name = self.cte_names_by_id[cte._get_reference_cte()] is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) @@ -3520,11 +3512,11 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: for cte in list(ctes.keys()): - cte_level = self.level_by_ctes[cte._get_unique_id()] - cte_name = self.cte_names_by_id[cte._get_unique_id()] + cte_level = self.level_by_ctes[cte._get_reference_cte()] + cte_name = self.cte_names_by_id[cte._get_reference_cte()] del self.ctes[cte] - del self.level_by_ctes[cte._get_unique_id()] - del self.cte_names_by_id[cte._get_unique_id()] + del self.level_by_ctes[cte._get_reference_cte()] + del self.cte_names_by_id[cte._get_reference_cte()] del self.ctes_by_level_name[(cte_level, cte_name)] return cte_text diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 7de3e95be1..ef8b7686e1 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2082,6 +2082,7 @@ class CTE( self.recursive = recursive self.nesting = nesting self._cte_alias = _cte_alias + # Keep recursivity reference with union/union_all self._restates = _restates if _prefixes: self._prefixes = _prefixes @@ -2141,7 +2142,12 @@ class CTE( _suffixes=self._suffixes, ) - def _get_unique_id(self): + def _get_reference_cte(self): + """ + A recursive CTE is updated to attach the recursive part. + Updated CTEs should still refer to the original CTE. + This function returns this reference identifier. + """ return self._restates if self._restates is not None else self -- 2.47.3