]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Documentation improvement
authorEric Masseran <eric.masseran@gmail.com>
Mon, 4 Oct 2021 13:10:26 +0000 (15:10 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Mon, 4 Oct 2021 13:10:26 +0000 (15:10 +0200)
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py

index be251328403574f220299bb6938a30c60f98f2f7..411b5c00e23715091dd165c753c682f4dbf1ec3c 100644 (file)
@@ -839,20 +839,15 @@ class SQLCompiler(Compiled):
         these collections otherwise.
 
         """
-        # Use as a unique id to identify a CTE part of a compilation
-        self.next_id_counter = 0
         # collect CTEs to tack on top of a SELECT
-        # Dict[cte, text_query]
-        # To remember the query to print
+        # To store the query to print - Dict[cte, text_query]
         self.ctes = util.OrderedDict()
-        # Detect same CTE references
-        # Dict[(level, name), cte_instance]
+        # Detect same CTE references - Dict[(level, name), cte]
+        # Level is required for supporting nesting
         self.ctes_by_level_name = {}
-        # Dict[cte_id, cte_name]
-        # To retrieve key in self.ctes_by_level_name
+        # To retrieve key in ctes_by_level_name - Dict[cte_reference, cte_name]
         self.cte_names_by_id = {}
-        # Dict[cte_id, level]
-        # Remember level for nesting usage
+        # Store CTE level for nesting usage - Dict[cte_reference, level]
         self.level_by_ctes = {}
         self.ctes_recursive = False
         if self.positional:
@@ -2537,8 +2532,8 @@ class SQLCompiler(Compiled):
         is_new_cte = True
         embedded_in_current_named_cte = False
 
-        if cte._get_unique_id() in self.level_by_ctes:
-            cte_level = self.level_by_ctes[cte._get_unique_id()]
+        if cte._get_reference_cte() in self.level_by_ctes:
+            cte_level = self.level_by_ctes[cte._get_reference_cte()]
 
         cte_level_name = (cte_level, cte_name)
         if cte_level_name in self.ctes_by_level_name:
@@ -2554,7 +2549,7 @@ class SQLCompiler(Compiled):
                 # enclosed in us - we take precedence, so
                 # discard the text for the "inner".
                 del self.ctes[existing_cte]
-                del self.level_by_ctes[existing_cte._get_unique_id()]
+                del self.level_by_ctes[existing_cte._get_reference_cte()]
             else:
                 raise exc.CompileError(
                     "Multiple, unrelated CTEs found with "
@@ -2575,8 +2570,8 @@ class SQLCompiler(Compiled):
 
         if is_new_cte:
             self.ctes_by_level_name[cte_level_name] = cte
-            self.cte_names_by_id[cte._get_unique_id()] = cte_name
-            self.level_by_ctes[cte._get_unique_id()] = cte_level
+            self.cte_names_by_id[cte._get_reference_cte()] = cte_name
+            self.level_by_ctes[cte._get_reference_cte()] = cte_level
 
             if (
                 "autocommit" in cte.element._execution_options
@@ -2593,10 +2588,7 @@ class SQLCompiler(Compiled):
             if pre_alias_cte not in self.ctes:
                 self.visit_cte(pre_alias_cte, **kwargs)
 
-            if (
-                not cte_pre_alias_name
-                and cte not in self.ctes
-            ):
+            if not cte_pre_alias_name and cte not in self.ctes:
                 if cte.recursive:
                     self.ctes_recursive = True
                 text = self.preparer.format_alias(cte, cte_name)
@@ -2664,7 +2656,7 @@ class SQLCompiler(Compiled):
                     )
 
                 self.ctes[cte] = text
-                self.level_by_ctes[cte._get_unique_id()] = cte_level
+                self.level_by_ctes[cte._get_reference_cte()] = cte_level
 
         if asfrom:
             if from_linter:
@@ -3490,9 +3482,9 @@ class SQLCompiler(Compiled):
         if nesting_level and nesting_level > 1:
             ctes = util.OrderedDict()
             for cte in list(self.ctes.keys()):
-                cte._get_unique_id()
-                cte_level = self.level_by_ctes[cte._get_unique_id()]
-                cte_name = self.cte_names_by_id[cte._get_unique_id()]
+                cte._get_reference_cte()
+                cte_level = self.level_by_ctes[cte._get_reference_cte()]
+                cte_name = self.cte_names_by_id[cte._get_reference_cte()]
                 is_rendered_level = cte_level == nesting_level or (
                     include_following_stack and cte_level == nesting_level + 1
                 )
@@ -3520,11 +3512,11 @@ class SQLCompiler(Compiled):
 
         if nesting_level and nesting_level > 1:
             for cte in list(ctes.keys()):
-                cte_level = self.level_by_ctes[cte._get_unique_id()]
-                cte_name = self.cte_names_by_id[cte._get_unique_id()]
+                cte_level = self.level_by_ctes[cte._get_reference_cte()]
+                cte_name = self.cte_names_by_id[cte._get_reference_cte()]
                 del self.ctes[cte]
-                del self.level_by_ctes[cte._get_unique_id()]
-                del self.cte_names_by_id[cte._get_unique_id()]
+                del self.level_by_ctes[cte._get_reference_cte()]
+                del self.cte_names_by_id[cte._get_reference_cte()]
                 del self.ctes_by_level_name[(cte_level, cte_name)]
 
         return cte_text
index 7de3e95be15f6eb4e716b55aa6c778b3ceb604ae..ef8b7686e19cf0f314bea59dea07edbfb5d7bc6d 100644 (file)
@@ -2082,6 +2082,7 @@ class CTE(
         self.recursive = recursive
         self.nesting = nesting
         self._cte_alias = _cte_alias
+        # Keep recursivity reference with union/union_all
         self._restates = _restates
         if _prefixes:
             self._prefixes = _prefixes
@@ -2141,7 +2142,12 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def _get_unique_id(self):
+    def _get_reference_cte(self):
+        """
+        A recursive CTE is updated to attach the recursive part.
+        Updated CTEs should still refer to the original CTE.
+        This function returns this reference identifier.
+        """
         return self._restates if self._restates is not None else self