]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix recursive CTE to support nesting
authorEric Masseran <eric.masseran@gmail.com>
Fri, 8 Oct 2021 14:02:58 +0000 (10:02 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Tue, 12 Oct 2021 22:46:57 +0000 (22:46 +0000)
Repaired issue in new :paramref:`_sql.HasCTE.cte.nesting` parameter
introduced with :ticket:`4123` where a recursive :class:`_sql.CTE` using
:paramref:`_sql.HasCTE.cte.recursive` in typical conjunction with UNION
would not compile correctly.  Additionally makes some adjustments so that
the :class:`_sql.CTE` construct creates a correct cache key.
Pull request courtesy Eric Masseran.

Fixes: #4123
> This has not been caught by the tests because the nesting recursive
queries there did not union against itself, eg there was only the i
root clause...

- Now tests are real recursive queries
- Add tests on aliased nested CTEs (recursive or not)
- Adapt the `_restates` attribute to use it as a reference
- Add some docs around to explain some variables usage

Closes: #7133
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7133
Pull-request-sha: 2633f34f7f5336a4a85bd3f71d07bca33ce27a2c

Change-Id: I15512c94e1bc1f52afc619d82057ca647d274e92

doc/build/changelog/unreleased_14/4123.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/sql/test_compare.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_14/4123.rst b/doc/build/changelog/unreleased_14/4123.rst
new file mode 100644 (file)
index 0000000..df1f9c1
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 4123
+
+    Repaired issue in new :paramref:`_sql.HasCTE.cte.nesting` parameter
+    introduced with :ticket:`4123` where a recursive :class:`_sql.CTE` using
+    :paramref:`_sql.HasCTE.cte.recursive` in typical conjunction with UNION
+    would not compile correctly.  Additionally makes some adjustments so that
+    the :class:`_sql.CTE` construct creates a correct cache key.
+    Pull request courtesy Eric Masseran.
index 333ed36f41f56355f5e50814316ac18c27be9151..efcfe0e51c0791ad79aed260ceb1870ada813a08 100644 (file)
@@ -840,10 +840,17 @@ class SQLCompiler(Compiled):
 
         """
         # collect CTEs to tack on top of a SELECT
+        # To store the query to print - Dict[cte, text_query]
         self.ctes = util.OrderedDict()
-        # Detect same CTE references
-        self.ctes_by_name = {}
-        self.level_by_ctes = {}
+
+        # Detect same CTE references - Dict[(level, name), cte]
+        # Level is required for supporting nesting
+        self.ctes_by_level_name = {}
+
+        # To retrieve key/level in ctes_by_level_name -
+        # Dict[cte_reference, (level, cte_name)]
+        self.level_name_by_cte = {}
+
         self.ctes_recursive = False
         if self.positional:
             self.cte_positional = {}
@@ -2515,8 +2522,6 @@ class SQLCompiler(Compiled):
     ):
         self._init_cte_state()
 
-        cte_level = len(self.stack) if cte.nesting else 1
-
         kwargs["visiting_cte"] = cte
 
         cte_name = cte.name
@@ -2527,44 +2532,60 @@ class SQLCompiler(Compiled):
         is_new_cte = True
         embedded_in_current_named_cte = False
 
-        if cte in self.level_by_ctes:
-            cte_level = self.level_by_ctes[cte]
+        _reference_cte = cte._get_reference_cte()
+
+        if _reference_cte in self.level_name_by_cte:
+            cte_level, _ = self.level_name_by_cte[_reference_cte]
+            assert _ == cte_name
+        else:
+            cte_level = len(self.stack) if cte.nesting else 1
 
         cte_level_name = (cte_level, cte_name)
-        if cte_level_name in self.ctes_by_name:
-            existing_cte = self.ctes_by_name[cte_level_name]
+        if cte_level_name in self.ctes_by_level_name:
+            existing_cte = self.ctes_by_level_name[cte_level_name]
             embedded_in_current_named_cte = visiting_cte is existing_cte
 
             # we've generated a same-named CTE that we are enclosed in,
             # or this is the same CTE.  just return the name.
-            if cte in existing_cte._restates or cte is existing_cte:
+            if cte is existing_cte._restates or cte is existing_cte:
                 is_new_cte = False
-            elif existing_cte in cte._restates:
+            elif existing_cte is cte._restates:
                 # we've generated a same-named CTE that is
                 # enclosed in us - we take precedence, so
                 # discard the text for the "inner".
                 del self.ctes[existing_cte]
-                del self.level_by_ctes[existing_cte]
+
+                existing_cte_reference_cte = existing_cte._get_reference_cte()
+
+                # TODO: determine if these assertions are correct.  they
+                # pass for current test cases
+                # assert existing_cte_reference_cte is _reference_cte
+                # assert existing_cte_reference_cte is existing_cte
+
+                del self.level_name_by_cte[existing_cte_reference_cte]
             else:
                 raise exc.CompileError(
                     "Multiple, unrelated CTEs found with "
                     "the same name: %r" % cte_name
                 )
 
-        if asfrom or is_new_cte:
-            if cte._cte_alias is not None:
-                pre_alias_cte = cte._cte_alias
-                cte_pre_alias_name = cte._cte_alias.name
-                if isinstance(cte_pre_alias_name, elements._truncated_label):
-                    cte_pre_alias_name = self._truncated_identifier(
-                        "alias", cte_pre_alias_name
-                    )
-            else:
-                pre_alias_cte = cte
-                cte_pre_alias_name = None
+        if not asfrom and not is_new_cte:
+            return None
+
+        if cte._cte_alias is not None:
+            pre_alias_cte = cte._cte_alias
+            cte_pre_alias_name = cte._cte_alias.name
+            if isinstance(cte_pre_alias_name, elements._truncated_label):
+                cte_pre_alias_name = self._truncated_identifier(
+                    "alias", cte_pre_alias_name
+                )
+        else:
+            pre_alias_cte = cte
+            cte_pre_alias_name = None
 
         if is_new_cte:
-            self.ctes_by_name[cte_level_name] = cte
+            self.ctes_by_level_name[cte_level_name] = cte
+            self.level_name_by_cte[_reference_cte] = cte_level_name
 
             if (
                 "autocommit" in cte.element._execution_options
@@ -2649,7 +2670,6 @@ class SQLCompiler(Compiled):
                     )
 
                 self.ctes[cte] = text
-                self.level_by_ctes[cte] = cte_level
 
         if asfrom:
             if from_linter:
@@ -3475,7 +3495,9 @@ class SQLCompiler(Compiled):
         if nesting_level and nesting_level > 1:
             ctes = util.OrderedDict()
             for cte in list(self.ctes.keys()):
-                cte_level = self.level_by_ctes[cte]
+                cte_level, cte_name = self.level_name_by_cte[
+                    cte._get_reference_cte()
+                ]
                 is_rendered_level = cte_level == nesting_level or (
                     include_following_stack and cte_level == nesting_level + 1
                 )
@@ -3484,14 +3506,6 @@ class SQLCompiler(Compiled):
 
                 ctes[cte] = self.ctes[cte]
 
-                del self.ctes[cte]
-                del self.level_by_ctes[cte]
-
-                cte_name = cte.name
-                if isinstance(cte_name, elements._truncated_label):
-                    cte_name = self._truncated_identifier("alias", cte_name)
-
-                del self.ctes_by_name[(cte_level, cte_name)]
         else:
             ctes = self.ctes
 
@@ -3508,6 +3522,16 @@ class SQLCompiler(Compiled):
         cte_text = self.get_cte_preamble(ctes_recursive) + " "
         cte_text += ", \n".join([txt for txt in ctes.values()])
         cte_text += "\n "
+
+        if nesting_level and nesting_level > 1:
+            for cte in list(ctes.keys()):
+                cte_level, cte_name = self.level_name_by_cte[
+                    cte._get_reference_cte()
+                ]
+                del self.ctes[cte]
+                del self.ctes_by_level_name[(cte_level, cte_name)]
+                del self.level_name_by_cte[cte._get_reference_cte()]
+
         return cte_text
 
     def get_cte_preamble(self, recursive):
index 8e71dfb97faf1659f5e6fc9635a9b6e3c1ecd5de..616df0d05b565ffa9e6b79ac8ea226507eed959d 100644 (file)
@@ -2049,8 +2049,9 @@ class CTE(
         AliasedReturnsRows._traverse_internals
         + [
             ("_cte_alias", InternalTraversal.dp_clauseelement),
-            ("_restates", InternalTraversal.dp_clauseelement_list),
+            ("_restates", InternalTraversal.dp_clauseelement),
             ("recursive", InternalTraversal.dp_boolean),
+            ("nesting", InternalTraversal.dp_boolean),
         ]
         + HasPrefixes._has_prefixes_traverse_internals
         + HasSuffixes._has_suffixes_traverse_internals
@@ -2075,13 +2076,14 @@ class CTE(
         recursive=False,
         nesting=False,
         _cte_alias=None,
-        _restates=(),
+        _restates=None,
         _prefixes=None,
         _suffixes=None,
     ):
         self.recursive = recursive
         self.nesting = nesting
         self._cte_alias = _cte_alias
+        # Keep recursivity reference with union/union_all
         self._restates = _restates
         if _prefixes:
             self._prefixes = _prefixes
@@ -2125,7 +2127,7 @@ class CTE(
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
-            _restates=self._restates + (self,),
+            _restates=self,
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
         )
@@ -2136,11 +2138,19 @@ class CTE(
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
-            _restates=self._restates + (self,),
+            _restates=self,
             _prefixes=self._prefixes,
             _suffixes=self._suffixes,
         )
 
+    def _get_reference_cte(self):
+        """
+        A recursive CTE is updated to attach the recursive part.
+        Updated CTEs should still refer to the original CTE.
+        This function returns this reference identifier.
+        """
+        return self._restates if self._restates is not None else self
+
 
 class HasCTE(roles.HasCTERole):
     """Mixin that declares a class to include CTE support.
index eaae0c448dc53e9358b58f56428d43fef20d67cf..2db7a5744648b377abce5de833f7429ee308b5aa 100644 (file)
@@ -502,6 +502,7 @@ class CoreFixtures(object):
         ),
         lambda: (
             select(table_a.c.a).cte(),
+            select(table_a.c.a).cte(nesting=True),
             select(table_a.c.a).cte(recursive=True),
             select(table_a.c.a).cte(name="some_cte", recursive=True),
             select(table_a.c.a).cte(name="some_cte"),
@@ -830,7 +831,17 @@ class CoreFixtures(object):
             )
             return stmt
 
-        return [one(), one_diff(), two(), three()]
+        def four():
+            stmt = select(table_a.c.a).cte(recursive=True)
+            stmt = stmt.union(select(stmt.c.a + 1).where(stmt.c.a < 10))
+            return stmt
+
+        def five():
+            stmt = select(table_a.c.a).cte(recursive=True, nesting=True)
+            stmt = stmt.union(select(stmt.c.a + 1).where(stmt.c.a < 10))
+            return stmt
+
+        return [one(), one_diff(), two(), three(), four(), five()]
 
     fixtures.append(_complex_fixtures)
 
index 5d24adff9302548fbd37c59441765ec2db68ae10..22107eeee51d9bbdb9aae6ea61204443f509ea93 100644 (file)
@@ -444,6 +444,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         cte = s1.cte(name="cte", recursive=True)
 
         bar = select(cte).cte("bar").alias("cs1")
+
         cte = cte.union_all(select(cte.c.x + 1).where(cte.c.x < 10)).alias(
             "cs2"
         )
@@ -1740,6 +1741,24 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT cte.outer_cte FROM cte",
         )
 
+    def test_select_with_aliased_nesting_cte_in_cte(self):
+        nesting_cte = (
+            select(literal(1).label("inner_cte"))
+            .cte("nesting", nesting=True)
+            .alias("aliased_nested")
+        )
+        stmt = select(
+            select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
+            "SELECT aliased_nested.inner_cte AS outer_cte "
+            "FROM nesting AS aliased_nested) "
+            "SELECT cte.outer_cte FROM cte",
+        )
+
     def test_nesting_cte_in_cte_with_same_name(self):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "some_cte", nesting=True
@@ -1901,24 +1920,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "nesting", nesting=True
         )
-        stmt = select(
-            select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
-                "cte", recursive=True
-            )
+
+        rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+            "rec_cte", recursive=True
         )
+        rec_part = select(rec_cte.c.outer_cte).where(
+            rec_cte.c.outer_cte == literal(1)
+        )
+        rec_cte = rec_cte.union(rec_part)
+
+        stmt = select(rec_cte)
 
         self.assert_compile(
             stmt,
-            "WITH RECURSIVE cte(outer_cte) AS (WITH nesting AS "
+            "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
             "(SELECT :param_1 AS inner_cte) "
-            "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
-            "SELECT cte.outer_cte FROM cte",
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = :param_2) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
         )
 
     def test_recursive_nesting_cte_in_cte(self):
-        nesting_cte = select(literal(1).label("inner_cte")).cte(
-            "nesting", nesting=True, recursive=True
+        rec_root = select(literal(1).label("inner_cte")).cte(
+            "nesting", recursive=True, nesting=True
+        )
+        rec_part = select(rec_root.c.inner_cte).where(
+            rec_root.c.inner_cte == literal(1)
         )
+        nesting_cte = rec_root.union(rec_part)
+
         stmt = select(
             select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
         )
@@ -1926,11 +1957,89 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             stmt,
             "WITH cte AS (WITH RECURSIVE nesting(inner_cte) AS "
-            "(SELECT :param_1 AS inner_cte) "
+            "(SELECT :param_1 AS inner_cte UNION "
+            "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+            "WHERE nesting.inner_cte = :param_2) "
             "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
             "SELECT cte.outer_cte FROM cte",
         )
 
+    def test_anon_recursive_nesting_cte_in_cte(self):
+        rec_root = (
+            select(literal(1).label("inner_cte"))
+            .cte("nesting", recursive=True, nesting=True)
+            .alias()
+        )
+        rec_part = select(rec_root.c.inner_cte).where(
+            rec_root.c.inner_cte == literal(1)
+        )
+        nesting_cte = rec_root.union(rec_part)
+
+        stmt = select(
+            select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS (WITH RECURSIVE anon_1(inner_cte) AS "
+            "(SELECT :param_1 AS inner_cte UNION "
+            "SELECT anon_1.inner_cte AS inner_cte FROM anon_1 "
+            "WHERE anon_1.inner_cte = :param_2) "
+            "SELECT anon_1.inner_cte AS outer_cte FROM anon_1) "
+            "SELECT cte.outer_cte FROM cte",
+        )
+
+    def test_fully_aliased_recursive_nesting_cte_in_cte(self):
+        rec_root = (
+            select(literal(1).label("inner_cte"))
+            .cte("nesting", recursive=True, nesting=True)
+            .alias("aliased_nesting")
+        )
+        rec_part = select(rec_root.c.inner_cte).where(
+            rec_root.c.inner_cte == literal(1)
+        )
+        nesting_cte = rec_root.union(rec_part)
+
+        stmt = select(
+            select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS (WITH RECURSIVE aliased_nesting(inner_cte) AS "
+            "(SELECT :param_1 AS inner_cte UNION "
+            "SELECT aliased_nesting.inner_cte AS inner_cte "
+            "FROM aliased_nesting "
+            "WHERE aliased_nesting.inner_cte = :param_2) "
+            "SELECT aliased_nesting.inner_cte AS outer_cte "
+            "FROM aliased_nesting) "
+            "SELECT cte.outer_cte FROM cte",
+        )
+
+    def test_aliased_recursive_nesting_cte_in_cte(self):
+        rec_root = select(literal(1).label("inner_cte")).cte(
+            "nesting", recursive=True, nesting=True
+        )
+        rec_part = select(rec_root.c.inner_cte).where(
+            rec_root.c.inner_cte == literal(1)
+        )
+        nesting_cte = rec_root.union(rec_part).alias("aliased_nesting")
+
+        stmt = select(
+            select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS (WITH RECURSIVE nesting(inner_cte) AS "
+            "(SELECT :param_1 AS inner_cte UNION "
+            "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+            "WHERE nesting.inner_cte = :param_2) "
+            "SELECT aliased_nesting.inner_cte AS outer_cte "
+            "FROM nesting AS aliased_nesting) "
+            "SELECT cte.outer_cte FROM cte",
+        )
+
     def test_same_nested_cte_is_not_generated_twice(self):
         # Same = name and query
         nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
@@ -1976,19 +2085,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "nesting", nesting=True, recursive=True
         )
-        stmt = select(
-            select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
-                "cte", recursive=True
-            )
+        nesting_rec_part = select(nesting_cte.c.inner_cte).where(
+            nesting_cte.c.inner_cte == literal(1)
+        )
+        nesting_cte = nesting_cte.union(nesting_rec_part)
+
+        rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+            "rec_cte", recursive=True
         )
+        rec_part = select(rec_cte.c.outer_cte).where(
+            rec_cte.c.outer_cte == literal(1)
+        )
+        rec_cte = rec_cte.union(rec_part)
+
+        stmt = select(rec_cte)
 
         self.assert_compile(
             stmt,
-            "WITH RECURSIVE cte(outer_cte) AS "
-            "(WITH RECURSIVE nesting(inner_cte) "
-            "AS (SELECT :param_1 AS inner_cte) "
-            "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
-            "SELECT cte.outer_cte FROM cte",
+            "WITH RECURSIVE rec_cte(outer_cte) AS ("
+            "WITH RECURSIVE nesting(inner_cte) AS "
+            "(SELECT :param_1 AS inner_cte UNION "
+            "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+            "WHERE nesting.inner_cte = :param_2) "
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = :param_3) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
         )
 
     def test_select_from_insert_cte_with_nesting(self):