]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Control CTE level in compiler class
authorEric Masseran <eric.masseran@gmail.com>
Sat, 10 Jul 2021 21:56:18 +0000 (23:56 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Sat, 10 Jul 2021 21:56:18 +0000 (23:56 +0200)
lib/sqlalchemy/sql/compiler.py

index 847bc12c562e3fd08a34c08bc1bb7c4c1f6cce40..66a8db1eab1c5920e1298a6eb6470509be7728bd 100644 (file)
@@ -844,6 +844,7 @@ class SQLCompiler(Compiled):
         self.ctes = util.OrderedDict()
         # Detect same CTE references
         self.ctes_by_name = {}
+        self.level_by_ctes = {}
         self.ctes_recursive = False
         if self.positional:
             self.cte_positional = {}
@@ -1825,10 +1826,8 @@ class SQLCompiler(Compiled):
 
         if self.ctes:
             # Nesting CTEs from deeper select
-            text = (
-                self._render_cte_clause(nesting_level=len(self.stack) + 1)
-                + text
-            )
+            nesting_level = (len(self.stack) + 1) if not toplevel else None
+            text = self._render_cte_clause(nesting_level=nesting_level) + text
 
         self.stack.pop(-1)
         return text
@@ -2499,9 +2498,7 @@ class SQLCompiler(Compiled):
     ):
         self._init_cte_state()
 
-        cte_level = len(self.stack)
-        if cte.nesting:
-            cte.nesting_level = cte_level
+        cte_level = len(self.stack) if cte.nesting else 1
 
         kwargs["visiting_cte"] = cte
         if isinstance(cte.name, elements._truncated_label):
@@ -2526,6 +2523,7 @@ class SQLCompiler(Compiled):
                 # enclosed in us - we take precedence, so
                 # discard the text for the "inner".
                 del self.ctes[existing_cte]
+                del self.level_by_ctes[existing_cte]
             else:
                 raise exc.CompileError(
                     "Multiple, unrelated CTEs found with "
@@ -2618,6 +2616,7 @@ class SQLCompiler(Compiled):
                     )
 
                 self.ctes[cte] = text
+                self.level_by_ctes[cte] = cte_level
 
         if asfrom:
             if from_linter:
@@ -3162,9 +3161,8 @@ class SQLCompiler(Compiled):
 
         # In compound query, CTEs are shared at the compound level
         if self.ctes and compound_index is None:
-            text = (
-                self._render_cte_clause(nesting_level=len(self.stack)) + text
-            )
+            nesting_level = len(self.stack) if not toplevel else None
+            text = self._render_cte_clause(nesting_level=nesting_level) + text
 
         if select_stmt._suffixes:
             text += " " + self._generate_prefixes(
@@ -3340,19 +3338,26 @@ class SQLCompiler(Compiled):
         self,
         nesting_level=None,
     ):
+        if not self.ctes:
+            return ""
+
         ctes = self.ctes
 
         if nesting_level and nesting_level > 1:
             ctes = {
                 cte: ctes[cte]
                 for cte in ctes
-                if cte.nesting_level == nesting_level
+                if cte.nesting and self.level_by_ctes[cte] == nesting_level
             }
+
+            if not ctes:
+                return ""
+
             # Remove them from the visible CTEs
             self.ctes = {
                 cte: self.ctes[cte]
                 for cte in self.ctes
-                if not cte.nesting_level == nesting_level
+                if not cte.nesting and self.level_by_ctes[cte] == nesting_level
             }
 
             if ctes and not self.dialect.supports_nesting_cte:
@@ -3363,9 +3368,6 @@ class SQLCompiler(Compiled):
 
         ctes_recursive = any([cte.recursive for cte in ctes])
 
-        if not ctes:
-            return ""
-
         if self.positional:
             self.positiontup = (
                 sum([self.cte_positional[cte] for cte in ctes], [])