From: Eric Masseran Date: Fri, 1 Oct 2021 16:36:40 +0000 (+0200) Subject: WIP works but also accepts same name CTEs X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e7d4f0c97583926639d39c1e253ecf2f07deec3e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git WIP works but also accepts same name CTEs --- diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 333ed36f41..89c6ddca5c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -840,9 +840,12 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT + # Dict[cte_id, text_query] self.ctes = util.OrderedDict() # Detect same CTE references + # Dict[cte_id, cte_instance] self.ctes_by_name = {} + # Dict[cte_id, level] self.level_by_ctes = {} self.ctes_recursive = False if self.positional: @@ -2527,24 +2530,28 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False - if cte in self.level_by_ctes: - cte_level = self.level_by_ctes[cte] + # if cte.unique_id in self.level_by_ctes: + # cte_level = self.level_by_ctes[cte.unique_id] - cte_level_name = (cte_level, cte_name) + cte_level_name = cte.unique_id if cte_level_name in self.ctes_by_name: existing_cte = self.ctes_by_name[cte_level_name] embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, # or this is the same CTE. just return the name. - if cte in existing_cte._restates or cte is existing_cte: + if ( + cte in existing_cte._restates + or cte is existing_cte + # or cte.unique_id == existing_cte.unique_id + ): is_new_cte = False elif existing_cte in cte._restates: # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". - del self.ctes[existing_cte] - del self.level_by_ctes[existing_cte] + del self.ctes[existing_cte.unique_id] + del self.level_by_ctes[existing_cte.unique_id] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " @@ -2565,6 +2572,8 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_name[cte_level_name] = cte + # TODO: + # self.level_by_ctes[cte.unique_id] = cte_level if ( "autocommit" in cte.element._execution_options @@ -2578,10 +2587,10 @@ class SQLCompiler(Compiled): } ) - if pre_alias_cte not in self.ctes: + if pre_alias_cte.unique_id not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) - if not cte_pre_alias_name and cte not in self.ctes: + if not cte_pre_alias_name and cte.unique_id not in self.ctes: if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) @@ -2622,7 +2631,9 @@ class SQLCompiler(Compiled): ) if self.positional: - kwargs["positional_names"] = self.cte_positional[cte] = [] + kwargs["positional_names"] = self.cte_positional[ + cte.unique_id + ] = [] assert kwargs.get("subquery", False) is False @@ -2648,8 +2659,8 @@ class SQLCompiler(Compiled): cte, cte._suffixes, **kwargs ) - self.ctes[cte] = text - self.level_by_ctes[cte] = cte_level + self.ctes[cte.unique_id] = text + self.level_by_ctes[cte.unique_id] = cte_level if asfrom: if from_linter: @@ -3474,40 +3485,41 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() - for cte in list(self.ctes.keys()): - cte_level = self.level_by_ctes[cte] + for cte_id in list(self.ctes.keys()): + cte = self.ctes_by_name[cte_id] + cte_level = self.level_by_ctes[cte_id] is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) if not (cte.nesting and is_rendered_level): continue - ctes[cte] = self.ctes[cte] - - del self.ctes[cte] - del self.level_by_ctes[cte] - - cte_name = cte.name - if isinstance(cte_name, elements._truncated_label): - cte_name = self._truncated_identifier("alias", cte_name) + ctes[cte_id] = self.ctes[cte_id] - del self.ctes_by_name[(cte_level, cte_name)] else: ctes = self.ctes if not ctes: return "" - ctes_recursive = any([cte.recursive for cte in ctes]) + ctes_recursive = any( + [self.ctes_by_name[cte_id].recursive for cte_id in ctes] + ) if self.positional: self.positiontup = ( - sum([self.cte_positional[cte] for cte in ctes], []) + sum([self.cte_positional[cte_id] for cte_id in ctes], []) + self.positiontup ) cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " + + for cte_id in list(ctes.keys()): + del self.ctes[cte_id] + del self.level_by_ctes[cte_id] + del self.ctes_by_name[cte_id] + return cte_text def get_cte_preamble(self, recursive): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 970c7a0c56..db64d63347 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2074,6 +2074,7 @@ class CTE( name=None, recursive=False, nesting=False, + _unique_id=None, _cte_alias=None, _restates=(), _prefixes=None, @@ -2083,6 +2084,9 @@ class CTE( self.nesting = nesting self._cte_alias = _cte_alias self._restates = _restates + import uuid + + self.unique_id = _unique_id if _unique_id else uuid.uuid4() if _prefixes: self._prefixes = _prefixes if _suffixes: @@ -2115,6 +2119,7 @@ class CTE( recursive=self.recursive, nesting=self.nesting, _cte_alias=self, + # _unique_id=self.unique_id, _prefixes=self._prefixes, _suffixes=self._suffixes, ) @@ -2125,6 +2130,7 @@ class CTE( name=self.name, recursive=self.recursive, nesting=self.nesting, + _unique_id=self.unique_id, _restates=self._restates + (self,), _prefixes=self._prefixes, _suffixes=self._suffixes, @@ -2136,6 +2142,7 @@ class CTE( name=self.name, recursive=self.recursive, nesting=self.nesting, + _unique_id=self.unique_id, _restates=self._restates + (self,), _prefixes=self._prefixes, _suffixes=self._suffixes,