]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed
authorEric Masseran <eric.masseran@gmail.com>
Fri, 1 Oct 2021 17:36:39 +0000 (19:36 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Fri, 1 Oct 2021 17:36:39 +0000 (19:36 +0200)
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

index 89c6ddca5c5a084eec7171a2d145645bcf042f3d..63ae3a7bcd56bf606e0b4dd8895cfc8337f46950 100644 (file)
@@ -841,11 +841,19 @@ class SQLCompiler(Compiled):
         """
         # collect CTEs to tack on top of a SELECT
         # Dict[cte_id, text_query]
+        # To remember the query to print
         self.ctes = util.OrderedDict()
         # Detect same CTE references
-        # Dict[cte_id, cte_instance]
+        # Dict[(level, name), cte_instance]
         self.ctes_by_name = {}
+        # Dict[cte_id, cte_instance]
+        # Useful
+        self.ctes_by_id = {}
+        # Dict[cte_id, cte_instance]
+        # To retrieve ctes_by_name key
+        self.names_by_id = {}
         # Dict[cte_id, level]
+        # Remember level for nesting usage
         self.level_by_ctes = {}
         self.ctes_recursive = False
         if self.positional:
@@ -2530,10 +2538,10 @@ class SQLCompiler(Compiled):
         is_new_cte = True
         embedded_in_current_named_cte = False
 
-        if cte.unique_id in self.level_by_ctes:
-            cte_level = self.level_by_ctes[cte.unique_id]
+        if cte.unique_id in self.level_by_ctes:
+            cte_level = self.level_by_ctes[cte.unique_id]
 
-        cte_level_name = cte.unique_id
+        cte_level_name = (cte_level, cte_name)
         if cte_level_name in self.ctes_by_name:
             existing_cte = self.ctes_by_name[cte_level_name]
             embedded_in_current_named_cte = visiting_cte is existing_cte
@@ -2550,7 +2558,7 @@ class SQLCompiler(Compiled):
                 # we've generated a same-named CTE that is
                 # enclosed in us - we take precedence, so
                 # discard the text for the "inner".
-                del self.ctes[existing_cte.unique_id]
+                del self.ctes[existing_cte]
                 del self.level_by_ctes[existing_cte.unique_id]
             else:
                 raise exc.CompileError(
@@ -2572,8 +2580,8 @@ class SQLCompiler(Compiled):
 
         if is_new_cte:
             self.ctes_by_name[cte_level_name] = cte
-            # TODO:
-            self.level_by_ctes[cte.unique_id] = cte_level
+            self.names_by_id[cte.unique_id] = cte_name
+            self.level_by_ctes[cte.unique_id] = cte_level
 
             if (
                 "autocommit" in cte.element._execution_options
@@ -2631,9 +2639,7 @@ class SQLCompiler(Compiled):
                     )
 
                 if self.positional:
-                    kwargs["positional_names"] = self.cte_positional[
-                        cte.unique_id
-                    ] = []
+                    kwargs["positional_names"] = self.cte_positional[cte] = []
 
                 assert kwargs.get("subquery", False) is False
 
@@ -2659,7 +2665,7 @@ class SQLCompiler(Compiled):
                         cte, cte._suffixes, **kwargs
                     )
 
-                self.ctes[cte.unique_id] = text
+                self.ctes[cte] = text
                 self.level_by_ctes[cte.unique_id] = cte_level
 
         if asfrom:
@@ -3485,16 +3491,16 @@ class SQLCompiler(Compiled):
 
         if nesting_level and nesting_level > 1:
             ctes = util.OrderedDict()
-            for cte_id in list(self.ctes.keys()):
-                cte = self.ctes_by_name[cte_id]
-                cte_level = self.level_by_ctes[cte_id]
+            for cte in list(self.ctes.keys()):
+                cte_level = self.level_by_ctes[cte.unique_id]
+                cte_name = self.names_by_id[cte.unique_id]
                 is_rendered_level = cte_level == nesting_level or (
                     include_following_stack and cte_level == nesting_level + 1
                 )
                 if not (cte.nesting and is_rendered_level):
                     continue
 
-                ctes[cte_id] = self.ctes[cte_id]
+                ctes[cte] = self.ctes[cte]
 
         else:
             ctes = self.ctes
@@ -3502,23 +3508,25 @@ class SQLCompiler(Compiled):
         if not ctes:
             return ""
 
-        ctes_recursive = any(
-            [self.ctes_by_name[cte_id].recursive for cte_id in ctes]
-        )
+        ctes_recursive = any([cte.recursive for cte in ctes])
 
         if self.positional:
             self.positiontup = (
-                sum([self.cte_positional[cte_id] for cte_id in ctes], [])
+                sum([self.cte_positional[cte] for cte in ctes], [])
                 + self.positiontup
             )
         cte_text = self.get_cte_preamble(ctes_recursive) + " "
         cte_text += ", \n".join([txt for txt in ctes.values()])
         cte_text += "\n "
 
-        for cte_id in list(ctes.keys()):
-            del self.ctes[cte_id]
-            del self.level_by_ctes[cte_id]
-            del self.ctes_by_name[cte_id]
+        if nesting_level and nesting_level > 1:
+            for cte in list(ctes.keys()):
+                cte_level = self.level_by_ctes[cte.unique_id]
+                cte_name = self.names_by_id[cte.unique_id]
+                del self.ctes[cte]
+                del self.level_by_ctes[cte.unique_id]
+                del self.names_by_id[cte.unique_id]
+                del self.ctes_by_name[(cte_level, cte_name)]
 
         return cte_text
 
index 57878c0a4d99803d6ed3217381782637f4711092..2f847279aceffbd20821d0dd3bdc1ced28fada74 100644 (file)
@@ -2166,26 +2166,26 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             " SELECT recursive_cte.the_value FROM recursive_cte",
         )
 
-    def test_recursive_cte_w_union_aliased(self):
-        nesting_cte = select(literal(1).label("inner_cte")).cte(
-            "nesting", recursive=True, nesting=True
-        )
-        nesting_cte_a = nesting_cte.alias()
-        nesting_cte = nesting_cte.union(
-            select(nesting_cte_a.c.inner_cte).where(
-                nesting_cte_a.c.inner_cte == literal(1)
-            )
-        )
-
-        stmt = select(nesting_cte.c.inner_cte)
-        self.assert_compile(
-            stmt,
-            "WITH RECURSIVE nesting(inner_cte) AS "
-            "(SELECT :param_1 AS inner_cte UNION "
-            "SELECT anon_1.inner_cte AS inner_cte FROM nesting AS anon_1 "
-            "WHERE anon_1.inner_cte = :param_2) "
-            "SELECT nesting.inner_cte FROM nesting",
-        )
+    def test_recursive_cte_w_union_aliased(self):
+        nesting_cte = select(literal(1).label("inner_cte")).cte(
+            "nesting", recursive=True, nesting=True
+        )
+        nesting_cte_a = nesting_cte.alias()
+        nesting_cte = nesting_cte.union(
+            select(nesting_cte_a.c.inner_cte).where(
+                nesting_cte_a.c.inner_cte == literal(1)
+            )
+        )
+
+        stmt = select(nesting_cte.c.inner_cte)
+        self.assert_compile(
+            stmt,
+            "WITH RECURSIVE nesting(inner_cte) AS "
+            "(SELECT :param_1 AS inner_cte UNION "
+            "SELECT anon_1.inner_cte AS inner_cte FROM nesting AS anon_1 "
+            "WHERE anon_1.inner_cte = :param_2) "
+            "SELECT nesting.inner_cte FROM nesting",
+        )
 
     def test_recursive_cte_w_union(self):
         nesting_cte = select(literal(1).label("inner_cte")).cte(