From: Eric Masseran Date: Fri, 8 Oct 2021 14:02:58 +0000 (-0400) Subject: Fix recursive CTE to support nesting X-Git-Tag: rel_1_4_26~21 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=ee9b8836a160484733baa556c5d3ade4810aa999;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix recursive CTE to support nesting Repaired issue in new :paramref:`_sql.HasCTE.cte.nesting` parameter introduced with :ticket:`4123` where a recursive :class:`_sql.CTE` using :paramref:`_sql.HasCTE.cte.recursive` in typical conjunction with UNION would not compile correctly. Additionally makes some adjustments so that the :class:`_sql.CTE` construct creates a correct cache key. Pull request courtesy Eric Masseran. Fixes: #4123 > This has not been caught by the tests because the nesting recursive queries there did not union against itself, eg there was only the i root clause... - Now tests are real recursive queries - Add tests on aliased nested CTEs (recursive or not) - Adapt the `_restates` attribute to use it as a reference - Add some docs around to explain some variables usage Closes: #7133 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7133 Pull-request-sha: 2633f34f7f5336a4a85bd3f71d07bca33ce27a2c Change-Id: I15512c94e1bc1f52afc619d82057ca647d274e92 --- diff --git a/doc/build/changelog/unreleased_14/4123.rst b/doc/build/changelog/unreleased_14/4123.rst new file mode 100644 index 0000000000..df1f9c1d36 --- /dev/null +++ b/doc/build/changelog/unreleased_14/4123.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 4123 + + Repaired issue in new :paramref:`_sql.HasCTE.cte.nesting` parameter + introduced with :ticket:`4123` where a recursive :class:`_sql.CTE` using + :paramref:`_sql.HasCTE.cte.recursive` in typical conjunction with UNION + would not compile correctly. Additionally makes some adjustments so that + the :class:`_sql.CTE` construct creates a correct cache key. + Pull request courtesy Eric Masseran. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 333ed36f41..efcfe0e51c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -840,10 +840,17 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT + # To store the query to print - Dict[cte, text_query] self.ctes = util.OrderedDict() - # Detect same CTE references - self.ctes_by_name = {} - self.level_by_ctes = {} + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + self.ctes_by_level_name = {} + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name)] + self.level_name_by_cte = {} + self.ctes_recursive = False if self.positional: self.cte_positional = {} @@ -2515,8 +2522,6 @@ class SQLCompiler(Compiled): ): self._init_cte_state() - cte_level = len(self.stack) if cte.nesting else 1 - kwargs["visiting_cte"] = cte cte_name = cte.name @@ -2527,44 +2532,60 @@ class SQLCompiler(Compiled): is_new_cte = True embedded_in_current_named_cte = False - if cte in self.level_by_ctes: - cte_level = self.level_by_ctes[cte] + _reference_cte = cte._get_reference_cte() + + if _reference_cte in self.level_name_by_cte: + cte_level, _ = self.level_name_by_cte[_reference_cte] + assert _ == cte_name + else: + cte_level = len(self.stack) if cte.nesting else 1 cte_level_name = (cte_level, cte_name) - if cte_level_name in self.ctes_by_name: - existing_cte = self.ctes_by_name[cte_level_name] + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, # or this is the same CTE. just return the name. - if cte in existing_cte._restates or cte is existing_cte: + if cte is existing_cte._restates or cte is existing_cte: is_new_cte = False - elif existing_cte in cte._restates: + elif existing_cte is cte._restates: # we've generated a same-named CTE that is # enclosed in us - we take precedence, so # discard the text for the "inner". del self.ctes[existing_cte] - del self.level_by_ctes[existing_cte] + + existing_cte_reference_cte = existing_cte._get_reference_cte() + + # TODO: determine if these assertions are correct. they + # pass for current test cases + # assert existing_cte_reference_cte is _reference_cte + # assert existing_cte_reference_cte is existing_cte + + del self.level_name_by_cte[existing_cte_reference_cte] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " "the same name: %r" % cte_name ) - if asfrom or is_new_cte: - if cte._cte_alias is not None: - pre_alias_cte = cte._cte_alias - cte_pre_alias_name = cte._cte_alias.name - if isinstance(cte_pre_alias_name, elements._truncated_label): - cte_pre_alias_name = self._truncated_identifier( - "alias", cte_pre_alias_name - ) - else: - pre_alias_cte = cte - cte_pre_alias_name = None + if not asfrom and not is_new_cte: + return None + + if cte._cte_alias is not None: + pre_alias_cte = cte._cte_alias + cte_pre_alias_name = cte._cte_alias.name + if isinstance(cte_pre_alias_name, elements._truncated_label): + cte_pre_alias_name = self._truncated_identifier( + "alias", cte_pre_alias_name + ) + else: + pre_alias_cte = cte + cte_pre_alias_name = None if is_new_cte: - self.ctes_by_name[cte_level_name] = cte + self.ctes_by_level_name[cte_level_name] = cte + self.level_name_by_cte[_reference_cte] = cte_level_name if ( "autocommit" in cte.element._execution_options @@ -2649,7 +2670,6 @@ class SQLCompiler(Compiled): ) self.ctes[cte] = text - self.level_by_ctes[cte] = cte_level if asfrom: if from_linter: @@ -3475,7 +3495,9 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): - cte_level = self.level_by_ctes[cte] + cte_level, cte_name = self.level_name_by_cte[ + cte._get_reference_cte() + ] is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) @@ -3484,14 +3506,6 @@ class SQLCompiler(Compiled): ctes[cte] = self.ctes[cte] - del self.ctes[cte] - del self.level_by_ctes[cte] - - cte_name = cte.name - if isinstance(cte_name, elements._truncated_label): - cte_name = self._truncated_identifier("alias", cte_name) - - del self.ctes_by_name[(cte_level, cte_name)] else: ctes = self.ctes @@ -3508,6 +3522,16 @@ class SQLCompiler(Compiled): cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " + + if nesting_level and nesting_level > 1: + for cte in list(ctes.keys()): + cte_level, cte_name = self.level_name_by_cte[ + cte._get_reference_cte() + ] + del self.ctes[cte] + del self.ctes_by_level_name[(cte_level, cte_name)] + del self.level_name_by_cte[cte._get_reference_cte()] + return cte_text def get_cte_preamble(self, recursive): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8e71dfb97f..616df0d05b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2049,8 +2049,9 @@ class CTE( AliasedReturnsRows._traverse_internals + [ ("_cte_alias", InternalTraversal.dp_clauseelement), - ("_restates", InternalTraversal.dp_clauseelement_list), + ("_restates", InternalTraversal.dp_clauseelement), ("recursive", InternalTraversal.dp_boolean), + ("nesting", InternalTraversal.dp_boolean), ] + HasPrefixes._has_prefixes_traverse_internals + HasSuffixes._has_suffixes_traverse_internals @@ -2075,13 +2076,14 @@ class CTE( recursive=False, nesting=False, _cte_alias=None, - _restates=(), + _restates=None, _prefixes=None, _suffixes=None, ): self.recursive = recursive self.nesting = nesting self._cte_alias = _cte_alias + # Keep recursivity reference with union/union_all self._restates = _restates if _prefixes: self._prefixes = _prefixes @@ -2125,7 +2127,7 @@ class CTE( name=self.name, recursive=self.recursive, nesting=self.nesting, - _restates=self._restates + (self,), + _restates=self, _prefixes=self._prefixes, _suffixes=self._suffixes, ) @@ -2136,11 +2138,19 @@ class CTE( name=self.name, recursive=self.recursive, nesting=self.nesting, - _restates=self._restates + (self,), + _restates=self, _prefixes=self._prefixes, _suffixes=self._suffixes, ) + def _get_reference_cte(self): + """ + A recursive CTE is updated to attach the recursive part. + Updated CTEs should still refer to the original CTE. + This function returns this reference identifier. + """ + return self._restates if self._restates is not None else self + class HasCTE(roles.HasCTERole): """Mixin that declares a class to include CTE support. diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index eaae0c448d..2db7a57446 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -502,6 +502,7 @@ class CoreFixtures(object): ), lambda: ( select(table_a.c.a).cte(), + select(table_a.c.a).cte(nesting=True), select(table_a.c.a).cte(recursive=True), select(table_a.c.a).cte(name="some_cte", recursive=True), select(table_a.c.a).cte(name="some_cte"), @@ -830,7 +831,17 @@ class CoreFixtures(object): ) return stmt - return [one(), one_diff(), two(), three()] + def four(): + stmt = select(table_a.c.a).cte(recursive=True) + stmt = stmt.union(select(stmt.c.a + 1).where(stmt.c.a < 10)) + return stmt + + def five(): + stmt = select(table_a.c.a).cte(recursive=True, nesting=True) + stmt = stmt.union(select(stmt.c.a + 1).where(stmt.c.a < 10)) + return stmt + + return [one(), one_diff(), two(), three(), four(), five()] fixtures.append(_complex_fixtures) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 5d24adff93..22107eeee5 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -444,6 +444,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): cte = s1.cte(name="cte", recursive=True) bar = select(cte).cte("bar").alias("cs1") + cte = cte.union_all(select(cte.c.x + 1).where(cte.c.x < 10)).alias( "cs2" ) @@ -1740,6 +1741,24 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.outer_cte FROM cte", ) + def test_select_with_aliased_nesting_cte_in_cte(self): + nesting_cte = ( + select(literal(1).label("inner_cte")) + .cte("nesting", nesting=True) + .alias("aliased_nested") + ) + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) " + "SELECT aliased_nested.inner_cte AS outer_cte " + "FROM nesting AS aliased_nested) " + "SELECT cte.outer_cte FROM cte", + ) + def test_nesting_cte_in_cte_with_same_name(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "some_cte", nesting=True @@ -1901,24 +1920,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True ) - stmt = select( - select(nesting_cte.c.inner_cte.label("outer_cte")).cte( - "cte", recursive=True - ) + + rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte( + "rec_cte", recursive=True ) + rec_part = select(rec_cte.c.outer_cte).where( + rec_cte.c.outer_cte == literal(1) + ) + rec_cte = rec_cte.union(rec_part) + + stmt = select(rec_cte) self.assert_compile( stmt, - "WITH RECURSIVE cte(outer_cte) AS (WITH nesting AS " + "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " "(SELECT :param_1 AS inner_cte) " - "SELECT nesting.inner_cte AS outer_cte FROM nesting) " - "SELECT cte.outer_cte FROM cte", + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = :param_2) " + "SELECT rec_cte.outer_cte FROM rec_cte", ) def test_recursive_nesting_cte_in_cte(self): - nesting_cte = select(literal(1).label("inner_cte")).cte( - "nesting", nesting=True, recursive=True + rec_root = select(literal(1).label("inner_cte")).cte( + "nesting", recursive=True, nesting=True + ) + rec_part = select(rec_root.c.inner_cte).where( + rec_root.c.inner_cte == literal(1) ) + nesting_cte = rec_root.union(rec_part) + stmt = select( select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte") ) @@ -1926,11 +1957,89 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, "WITH cte AS (WITH RECURSIVE nesting(inner_cte) AS " - "(SELECT :param_1 AS inner_cte) " + "(SELECT :param_1 AS inner_cte UNION " + "SELECT nesting.inner_cte AS inner_cte FROM nesting " + "WHERE nesting.inner_cte = :param_2) " "SELECT nesting.inner_cte AS outer_cte FROM nesting) " "SELECT cte.outer_cte FROM cte", ) + def test_anon_recursive_nesting_cte_in_cte(self): + rec_root = ( + select(literal(1).label("inner_cte")) + .cte("nesting", recursive=True, nesting=True) + .alias() + ) + rec_part = select(rec_root.c.inner_cte).where( + rec_root.c.inner_cte == literal(1) + ) + nesting_cte = rec_root.union(rec_part) + + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (WITH RECURSIVE anon_1(inner_cte) AS " + "(SELECT :param_1 AS inner_cte UNION " + "SELECT anon_1.inner_cte AS inner_cte FROM anon_1 " + "WHERE anon_1.inner_cte = :param_2) " + "SELECT anon_1.inner_cte AS outer_cte FROM anon_1) " + "SELECT cte.outer_cte FROM cte", + ) + + def test_fully_aliased_recursive_nesting_cte_in_cte(self): + rec_root = ( + select(literal(1).label("inner_cte")) + .cte("nesting", recursive=True, nesting=True) + .alias("aliased_nesting") + ) + rec_part = select(rec_root.c.inner_cte).where( + rec_root.c.inner_cte == literal(1) + ) + nesting_cte = rec_root.union(rec_part) + + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (WITH RECURSIVE aliased_nesting(inner_cte) AS " + "(SELECT :param_1 AS inner_cte UNION " + "SELECT aliased_nesting.inner_cte AS inner_cte " + "FROM aliased_nesting " + "WHERE aliased_nesting.inner_cte = :param_2) " + "SELECT aliased_nesting.inner_cte AS outer_cte " + "FROM aliased_nesting) " + "SELECT cte.outer_cte FROM cte", + ) + + def test_aliased_recursive_nesting_cte_in_cte(self): + rec_root = select(literal(1).label("inner_cte")).cte( + "nesting", recursive=True, nesting=True + ) + rec_part = select(rec_root.c.inner_cte).where( + rec_root.c.inner_cte == literal(1) + ) + nesting_cte = rec_root.union(rec_part).alias("aliased_nesting") + + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (WITH RECURSIVE nesting(inner_cte) AS " + "(SELECT :param_1 AS inner_cte UNION " + "SELECT nesting.inner_cte AS inner_cte FROM nesting " + "WHERE nesting.inner_cte = :param_2) " + "SELECT aliased_nesting.inner_cte AS outer_cte " + "FROM nesting AS aliased_nesting) " + "SELECT cte.outer_cte FROM cte", + ) + def test_same_nested_cte_is_not_generated_twice(self): # Same = name and query nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte( @@ -1976,19 +2085,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True, recursive=True ) - stmt = select( - select(nesting_cte.c.inner_cte.label("outer_cte")).cte( - "cte", recursive=True - ) + nesting_rec_part = select(nesting_cte.c.inner_cte).where( + nesting_cte.c.inner_cte == literal(1) + ) + nesting_cte = nesting_cte.union(nesting_rec_part) + + rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte( + "rec_cte", recursive=True ) + rec_part = select(rec_cte.c.outer_cte).where( + rec_cte.c.outer_cte == literal(1) + ) + rec_cte = rec_cte.union(rec_part) + + stmt = select(rec_cte) self.assert_compile( stmt, - "WITH RECURSIVE cte(outer_cte) AS " - "(WITH RECURSIVE nesting(inner_cte) " - "AS (SELECT :param_1 AS inner_cte) " - "SELECT nesting.inner_cte AS outer_cte FROM nesting) " - "SELECT cte.outer_cte FROM cte", + "WITH RECURSIVE rec_cte(outer_cte) AS (" + "WITH RECURSIVE nesting(inner_cte) AS " + "(SELECT :param_1 AS inner_cte UNION " + "SELECT nesting.inner_cte AS inner_cte FROM nesting " + "WHERE nesting.inner_cte = :param_2) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = :param_3) " + "SELECT rec_cte.outer_cte FROM rec_cte", ) def test_select_from_insert_cte_with_nesting(self):