]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
WIP works but also accepts same name CTEs
authorEric Masseran <eric.masseran@gmail.com>
Fri, 1 Oct 2021 16:36:40 +0000 (18:36 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Fri, 1 Oct 2021 16:36:40 +0000 (18:36 +0200)
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py

index 333ed36f41f56355f5e50814316ac18c27be9151..89c6ddca5c5a084eec7171a2d145645bcf042f3d 100644 (file)
@@ -840,9 +840,12 @@ class SQLCompiler(Compiled):
 
         """
         # collect CTEs to tack on top of a SELECT
+        # Dict[cte_id, text_query]
         self.ctes = util.OrderedDict()
         # Detect same CTE references
+        # Dict[cte_id, cte_instance]
         self.ctes_by_name = {}
+        # Dict[cte_id, level]
         self.level_by_ctes = {}
         self.ctes_recursive = False
         if self.positional:
@@ -2527,24 +2530,28 @@ class SQLCompiler(Compiled):
         is_new_cte = True
         embedded_in_current_named_cte = False
 
-        if cte in self.level_by_ctes:
-            cte_level = self.level_by_ctes[cte]
+        # if cte.unique_id in self.level_by_ctes:
+        #     cte_level = self.level_by_ctes[cte.unique_id]
 
-        cte_level_name = (cte_level, cte_name)
+        cte_level_name = cte.unique_id
         if cte_level_name in self.ctes_by_name:
             existing_cte = self.ctes_by_name[cte_level_name]
             embedded_in_current_named_cte = visiting_cte is existing_cte
 
             # we've generated a same-named CTE that we are enclosed in,
             # or this is the same CTE.  just return the name.
-            if cte in existing_cte._restates or cte is existing_cte:
+            if (
+                cte in existing_cte._restates
+                or cte is existing_cte
+                # or cte.unique_id == existing_cte.unique_id
+            ):
                 is_new_cte = False
             elif existing_cte in cte._restates:
                 # we've generated a same-named CTE that is
                 # enclosed in us - we take precedence, so
                 # discard the text for the "inner".
-                del self.ctes[existing_cte]
-                del self.level_by_ctes[existing_cte]
+                del self.ctes[existing_cte.unique_id]
+                del self.level_by_ctes[existing_cte.unique_id]
             else:
                 raise exc.CompileError(
                     "Multiple, unrelated CTEs found with "
@@ -2565,6 +2572,8 @@ class SQLCompiler(Compiled):
 
         if is_new_cte:
             self.ctes_by_name[cte_level_name] = cte
+            # TODO:
+            # self.level_by_ctes[cte.unique_id] = cte_level
 
             if (
                 "autocommit" in cte.element._execution_options
@@ -2578,10 +2587,10 @@ class SQLCompiler(Compiled):
                     }
                 )
 
-            if pre_alias_cte not in self.ctes:
+            if pre_alias_cte.unique_id not in self.ctes:
                 self.visit_cte(pre_alias_cte, **kwargs)
 
-            if not cte_pre_alias_name and cte not in self.ctes:
+            if not cte_pre_alias_name and cte.unique_id not in self.ctes:
                 if cte.recursive:
                     self.ctes_recursive = True
                 text = self.preparer.format_alias(cte, cte_name)
@@ -2622,7 +2631,9 @@ class SQLCompiler(Compiled):
                     )
 
                 if self.positional:
-                    kwargs["positional_names"] = self.cte_positional[cte] = []
+                    kwargs["positional_names"] = self.cte_positional[
+                        cte.unique_id
+                    ] = []
 
                 assert kwargs.get("subquery", False) is False
 
@@ -2648,8 +2659,8 @@ class SQLCompiler(Compiled):
                         cte, cte._suffixes, **kwargs
                     )
 
-                self.ctes[cte] = text
-                self.level_by_ctes[cte] = cte_level
+                self.ctes[cte.unique_id] = text
+                self.level_by_ctes[cte.unique_id] = cte_level
 
         if asfrom:
             if from_linter:
@@ -3474,40 +3485,41 @@ class SQLCompiler(Compiled):
 
         if nesting_level and nesting_level > 1:
             ctes = util.OrderedDict()
-            for cte in list(self.ctes.keys()):
-                cte_level = self.level_by_ctes[cte]
+            for cte_id in list(self.ctes.keys()):
+                cte = self.ctes_by_name[cte_id]
+                cte_level = self.level_by_ctes[cte_id]
                 is_rendered_level = cte_level == nesting_level or (
                     include_following_stack and cte_level == nesting_level + 1
                 )
                 if not (cte.nesting and is_rendered_level):
                     continue
 
-                ctes[cte] = self.ctes[cte]
-
-                del self.ctes[cte]
-                del self.level_by_ctes[cte]
-
-                cte_name = cte.name
-                if isinstance(cte_name, elements._truncated_label):
-                    cte_name = self._truncated_identifier("alias", cte_name)
+                ctes[cte_id] = self.ctes[cte_id]
 
-                del self.ctes_by_name[(cte_level, cte_name)]
         else:
             ctes = self.ctes
 
         if not ctes:
             return ""
 
-        ctes_recursive = any([cte.recursive for cte in ctes])
+        ctes_recursive = any(
+            [self.ctes_by_name[cte_id].recursive for cte_id in ctes]
+        )
 
         if self.positional:
             self.positiontup = (
-                sum([self.cte_positional[cte] for cte in ctes], [])
+                sum([self.cte_positional[cte_id] for cte_id in ctes], [])
                 + self.positiontup
             )
         cte_text = self.get_cte_preamble(ctes_recursive) + " "
         cte_text += ", \n".join([txt for txt in ctes.values()])
         cte_text += "\n "
+
+        for cte_id in list(ctes.keys()):
+            del self.ctes[cte_id]
+            del self.level_by_ctes[cte_id]
+            del self.ctes_by_name[cte_id]
+
         return cte_text
 
     def get_cte_preamble(self, recursive):
index 970c7a0c567b636fff24345346065dbd5988c4b4..db64d633474b6f278ef41874131e665cc8e54313 100644 (file)
@@ -2074,6 +2074,7 @@ class CTE(
         name=None,
         recursive=False,
         nesting=False,
+        _unique_id=None,
         _cte_alias=None,
         _restates=(),
         _prefixes=None,
@@ -2083,6 +2084,9 @@ class CTE(
         self.nesting = nesting
         self._cte_alias = _cte_alias
         self._restates = _restates
+        import uuid
+
+        self.unique_id = _unique_id if _unique_id else uuid.uuid4()
         if _prefixes:
             self._prefixes = _prefixes
         if _suffixes:
@@ -2115,6 +2119,7 @@ class CTE(
             recursive=self.recursive,
             nesting=self.nesting,
             _cte_alias=self,
+            # _unique_id=self.unique_id,
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
         )
@@ -2125,6 +2130,7 @@ class CTE(
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
+            _unique_id=self.unique_id,
             _restates=self._restates + (self,),
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
@@ -2136,6 +2142,7 @@ class CTE(
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
+            _unique_id=self.unique_id,
             _restates=self._restates + (self,),
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,