From 5f021df967b5a74911358891001b38996840bb46 Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Sat, 10 Jul 2021 23:56:18 +0200 Subject: [PATCH] Control CTE level in compiler class --- lib/sqlalchemy/sql/compiler.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 847bc12c56..66a8db1eab 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -844,6 +844,7 @@ class SQLCompiler(Compiled): self.ctes = util.OrderedDict() # Detect same CTE references self.ctes_by_name = {} + self.level_by_ctes = {} self.ctes_recursive = False if self.positional: self.cte_positional = {} @@ -1825,10 +1826,8 @@ class SQLCompiler(Compiled): if self.ctes: # Nesting CTEs from deeper select - text = ( - self._render_cte_clause(nesting_level=len(self.stack) + 1) - + text - ) + nesting_level = (len(self.stack) + 1) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) return text @@ -2499,9 +2498,7 @@ class SQLCompiler(Compiled): ): self._init_cte_state() - cte_level = len(self.stack) - if cte.nesting: - cte.nesting_level = cte_level + cte_level = len(self.stack) if cte.nesting else 1 kwargs["visiting_cte"] = cte if isinstance(cte.name, elements._truncated_label): @@ -2526,6 +2523,7 @@ class SQLCompiler(Compiled): # enclosed in us - we take precedence, so # discard the text for the "inner". del self.ctes[existing_cte] + del self.level_by_ctes[existing_cte] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " @@ -2618,6 +2616,7 @@ class SQLCompiler(Compiled): ) self.ctes[cte] = text + self.level_by_ctes[cte] = cte_level if asfrom: if from_linter: @@ -3162,9 +3161,8 @@ class SQLCompiler(Compiled): # In compound query, CTEs are shared at the compound level if self.ctes and compound_index is None: - text = ( - self._render_cte_clause(nesting_level=len(self.stack)) + text - ) + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -3340,19 +3338,26 @@ class SQLCompiler(Compiled): self, nesting_level=None, ): + if not self.ctes: + return "" + ctes = self.ctes if nesting_level and nesting_level > 1: ctes = { cte: ctes[cte] for cte in ctes - if cte.nesting_level == nesting_level + if cte.nesting and self.level_by_ctes[cte] == nesting_level } + + if not ctes: + return "" + # Remove them from the visible CTEs self.ctes = { cte: self.ctes[cte] for cte in self.ctes - if not cte.nesting_level == nesting_level + if not cte.nesting and self.level_by_ctes[cte] == nesting_level } if ctes and not self.dialect.supports_nesting_cte: @@ -3363,9 +3368,6 @@ class SQLCompiler(Compiled): ctes_recursive = any([cte.recursive for cte in ctes]) - if not ctes: - return "" - if self.positional: self.positiontup = ( sum([self.cte_positional[cte] for cte in ctes], []) -- 2.47.3