]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use _restates attribute and the CTE itself as link id
authorEric Masseran <eric.masseran@gmail.com>
Mon, 4 Oct 2021 12:45:22 +0000 (14:45 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Mon, 4 Oct 2021 12:45:22 +0000 (14:45 +0200)
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/sql/test_cte.py

index 3d35293902dfac2d0e520e85faec716847e47e64..be251328403574f220299bb6938a30c60f98f2f7 100644 (file)
@@ -839,8 +839,10 @@ class SQLCompiler(Compiled):
         these collections otherwise.
 
         """
+        # Use as a unique id to identify a CTE part of a compilation
+        self.next_id_counter = 0
         # collect CTEs to tack on top of a SELECT
-        # Dict[cte_id, text_query]
+        # Dict[cte, text_query]
         # To remember the query to print
         self.ctes = util.OrderedDict()
         # Detect same CTE references
@@ -2535,8 +2537,8 @@ class SQLCompiler(Compiled):
         is_new_cte = True
         embedded_in_current_named_cte = False
 
-        if cte.unique_id in self.level_by_ctes:
-            cte_level = self.level_by_ctes[cte.unique_id]
+        if cte._get_unique_id() in self.level_by_ctes:
+            cte_level = self.level_by_ctes[cte._get_unique_id()]
 
         cte_level_name = (cte_level, cte_name)
         if cte_level_name in self.ctes_by_level_name:
@@ -2545,14 +2547,14 @@ class SQLCompiler(Compiled):
 
             # we've generated a same-named CTE that we are enclosed in,
             # or this is the same CTE.  just return the name.
-            if cte in existing_cte._restates or cte is existing_cte:
+            if cte is existing_cte._restates or cte is existing_cte:
                 is_new_cte = False
-            elif existing_cte in cte._restates:
+            elif existing_cte is cte._restates:
                 # we've generated a same-named CTE that is
                 # enclosed in us - we take precedence, so
                 # discard the text for the "inner".
                 del self.ctes[existing_cte]
-                del self.level_by_ctes[existing_cte.unique_id]
+                del self.level_by_ctes[existing_cte._get_unique_id()]
             else:
                 raise exc.CompileError(
                     "Multiple, unrelated CTEs found with "
@@ -2573,8 +2575,8 @@ class SQLCompiler(Compiled):
 
         if is_new_cte:
             self.ctes_by_level_name[cte_level_name] = cte
-            self.cte_names_by_id[cte.unique_id] = cte_name
-            self.level_by_ctes[cte.unique_id] = cte_level
+            self.cte_names_by_id[cte._get_unique_id()] = cte_name
+            self.level_by_ctes[cte._get_unique_id()] = cte_level
 
             if (
                 "autocommit" in cte.element._execution_options
@@ -2588,10 +2590,13 @@ class SQLCompiler(Compiled):
                     }
                 )
 
-            if pre_alias_cte.unique_id not in self.ctes:
+            if pre_alias_cte not in self.ctes:
                 self.visit_cte(pre_alias_cte, **kwargs)
 
-            if not cte_pre_alias_name and cte.unique_id not in self.ctes:
+            if (
+                not cte_pre_alias_name
+                and cte not in self.ctes
+            ):
                 if cte.recursive:
                     self.ctes_recursive = True
                 text = self.preparer.format_alias(cte, cte_name)
@@ -2659,7 +2664,7 @@ class SQLCompiler(Compiled):
                     )
 
                 self.ctes[cte] = text
-                self.level_by_ctes[cte.unique_id] = cte_level
+                self.level_by_ctes[cte._get_unique_id()] = cte_level
 
         if asfrom:
             if from_linter:
@@ -3485,8 +3490,9 @@ class SQLCompiler(Compiled):
         if nesting_level and nesting_level > 1:
             ctes = util.OrderedDict()
             for cte in list(self.ctes.keys()):
-                cte_level = self.level_by_ctes[cte.unique_id]
-                cte_name = self.cte_names_by_id[cte.unique_id]
+                cte._get_unique_id()
+                cte_level = self.level_by_ctes[cte._get_unique_id()]
+                cte_name = self.cte_names_by_id[cte._get_unique_id()]
                 is_rendered_level = cte_level == nesting_level or (
                     include_following_stack and cte_level == nesting_level + 1
                 )
@@ -3514,11 +3520,11 @@ class SQLCompiler(Compiled):
 
         if nesting_level and nesting_level > 1:
             for cte in list(ctes.keys()):
-                cte_level = self.level_by_ctes[cte.unique_id]
-                cte_name = self.cte_names_by_id[cte.unique_id]
+                cte_level = self.level_by_ctes[cte._get_unique_id()]
+                cte_name = self.cte_names_by_id[cte._get_unique_id()]
                 del self.ctes[cte]
-                del self.level_by_ctes[cte.unique_id]
-                del self.cte_names_by_id[cte.unique_id]
+                del self.level_by_ctes[cte._get_unique_id()]
+                del self.cte_names_by_id[cte._get_unique_id()]
                 del self.ctes_by_level_name[(cte_level, cte_name)]
 
         return cte_text
index bc4c29d408a4af7147201c105ef01d1f00b399d4..7de3e95be15f6eb4e716b55aa6c778b3ceb604ae 100644 (file)
@@ -2074,9 +2074,8 @@ class CTE(
         name=None,
         recursive=False,
         nesting=False,
-        _unique_id=None,
         _cte_alias=None,
-        _restates=(),
+        _restates=None,
         _prefixes=None,
         _suffixes=None,
     ):
@@ -2084,9 +2083,6 @@ class CTE(
         self.nesting = nesting
         self._cte_alias = _cte_alias
         self._restates = _restates
-        import uuid
-
-        self.unique_id = _unique_id if _unique_id else uuid.uuid4()
         if _prefixes:
             self._prefixes = _prefixes
         if _suffixes:
@@ -2118,7 +2114,6 @@ class CTE(
             name=name,
             recursive=self.recursive,
             nesting=self.nesting,
-            # _unique_id is not need as _cte_alias is doing the link
             _cte_alias=self,
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
@@ -2130,8 +2125,7 @@ class CTE(
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
-            _unique_id=self.unique_id,
-            _restates=self._restates + (self,),
+            _restates=self,
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
         )
@@ -2142,12 +2136,14 @@ class CTE(
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
-            _unique_id=self.unique_id,
-            _restates=self._restates + (self,),
+            _restates=self,
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
         )
 
+    def _get_unique_id(self):
+        return self._restates if self._restates is not None else self
+
 
 class HasCTE(roles.HasCTERole):
     """Mixin that declares a class to include CTE support.
index f6070239e43af4c81c6d0f14c8f28aec45692704..4d6f9ffa87bb118178184a1c45fa68d6d7bc27d4 100644 (file)
@@ -444,13 +444,13 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         cte = s1.cte(name="cte", recursive=True)
 
         bar = select(cte).cte("bar").alias("cs1")
-        cte = cte.union_all(select(cte.c.x + 1).where(cte.c.x < 10)).alias(
+        cte_rec = cte.union_all(select(cte.c.x + 1).where(cte.c.x < 10)).alias(
             "cs2"
         )
 
         # outer cte rendered first, then bar, which
         # includes "inner" cte
-        s2 = select(cte, bar)
+        s2 = select(cte_rec, bar)
         self.assert_compile(
             s2,
             "WITH RECURSIVE cte(x) AS "
@@ -474,7 +474,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
 
         # bar rendered, but then the "outer"
         # cte is rendered.
-        s2 = select(bar, cte)
+        s2 = select(bar, cte_rec)
         self.assert_compile(
             s2,
             "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), "