From 0fe8f4a3e79c8fc805e7a84849920c7258177f41 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 23 Feb 2022 12:24:31 -0500 Subject: [PATCH] Add more nesting features to add_cte() Added new parameter :paramref:`.HasCTE.add_cte.nest_here` to :meth:`.HasCTE.add_cte` which will "nest" a given :class:`.CTE` at the level of the parent statement. This parameter is equivalent to using the :paramref:`.HasCTE.cte.nesting` parameter, but may be more intuitive in some scenarios as it allows the nesting attribute to be set simultaneously along with the explicit level of the CTE. The :meth:`.HasCTE.add_cte` method also accepts multiple CTE objects. Fixes: #7759 Change-Id: I263c015f5a3f452cb54819aee12bc9bf2953a7bb --- doc/build/changelog/unreleased_20/7759.rst | 12 + lib/sqlalchemy/sql/compiler.py | 148 +++++++---- lib/sqlalchemy/sql/selectable.py | 96 ++++++- test/sql/test_cte.py | 293 +++++++++++++++++++++ 4 files changed, 485 insertions(+), 64 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/7759.rst diff --git a/doc/build/changelog/unreleased_20/7759.rst b/doc/build/changelog/unreleased_20/7759.rst new file mode 100644 index 0000000000..b7f3bff8d0 --- /dev/null +++ b/doc/build/changelog/unreleased_20/7759.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, sql + :tickets: 7759 + + Added new parameter :paramref:`.HasCTE.add_cte.nest_here` to + :meth:`.HasCTE.add_cte` which will "nest" a given :class:`.CTE` at the + level of the parent statement. This parameter is equivalent to using the + :paramref:`.HasCTE.cte.nesting` parameter, but may be more intuitive in + some scenarios as it allows the nesting attribute to be set simultaneously + along with the explicit level of the CTE. + + The :meth:`.HasCTE.add_cte` method also accepts multiple CTE objects. \ No newline at end of file diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b140f92975..77bc1ea38d 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -31,6 +31,13 @@ import itertools import operator import re from time import perf_counter +import typing +from typing import Any +from typing import Dict +from typing import List +from typing import MutableMapping +from typing import Optional +from typing import Tuple from . import base from . import coercions @@ -47,6 +54,12 @@ from .elements import quoted_name from .. import exc from .. import util +if typing.TYPE_CHECKING: + from .selectable import CTE + from .selectable import FromClause + +_FromHintsType = Dict["FromClause", str] + RESERVED_WORDS = set( [ "all", @@ -842,7 +855,7 @@ class SQLCompiler(Compiled): return {} @util.memoized_instancemethod - def _init_cte_state(self): + def _init_cte_state(self) -> None: """Initialize collections related to CTEs only if a CTE is located, to save on the overhead of these collections otherwise. @@ -850,19 +863,21 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT # To store the query to print - Dict[cte, text_query] - self.ctes = util.OrderedDict() + self.ctes: MutableMapping[CTE, str] = util.OrderedDict() # Detect same CTE references - Dict[(level, name), cte] # Level is required for supporting nesting - self.ctes_by_level_name = {} + self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {} # To retrieve key/level in ctes_by_level_name - - # Dict[cte_reference, (level, cte_name)] - self.level_name_by_cte = {} + # Dict[cte_reference, (level, cte_name, cte_opts)] + self.level_name_by_cte: Dict[ + CTE, Tuple[int, str, selectable._CTEOpts] + ] = {} - self.ctes_recursive = False + self.ctes_recursive: bool = False if self.positional: - self.cte_positional = {} + self.cte_positional: Dict[CTE, List[str]] = {} @contextlib.contextmanager def _nested_result(self): @@ -1604,8 +1619,7 @@ class SQLCompiler(Compiled): self.stack.append(new_entry) if taf._independent_ctes: - for cte in taf._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(taf, kw) populate_result_map = ( toplevel @@ -1879,8 +1893,7 @@ class SQLCompiler(Compiled): ) if compound_stmt._independent_ctes: - for cte in compound_stmt._independent_ctes: - cte._compiler_dispatch(self, **kwargs) + self._dispatch_independent_ctes(compound_stmt, kwargs) keyword = self.compound_keywords.get(cs.keyword) @@ -2671,16 +2684,25 @@ class SQLCompiler(Compiled): return ret + def _dispatch_independent_ctes(self, stmt, kw): + local_kw = kw.copy() + local_kw.pop("cte_opts", None) + for cte, opt in zip( + stmt._independent_ctes, stmt._independent_ctes_opts + ): + cte._compiler_dispatch(self, cte_opts=opt, **local_kw) + def visit_cte( self, - cte, - asfrom=False, - ashint=False, - fromhints=None, - visiting_cte=None, - from_linter=None, - **kwargs, - ): + cte: CTE, + asfrom: bool = False, + ashint: bool = False, + fromhints: Optional[_FromHintsType] = None, + visiting_cte: Optional[CTE] = None, + from_linter: Optional[FromLinter] = None, + cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), + **kwargs: Any, + ) -> Optional[str]: self._init_cte_state() kwargs["visiting_cte"] = cte @@ -2695,15 +2717,48 @@ class SQLCompiler(Compiled): _reference_cte = cte._get_reference_cte() + nesting = cte.nesting or cte_opts.nesting + + # check for CTE already encountered if _reference_cte in self.level_name_by_cte: - cte_level, _ = self.level_name_by_cte[_reference_cte] + cte_level, _, existing_cte_opts = self.level_name_by_cte[ + _reference_cte + ] assert _ == cte_name - else: - cte_level = len(self.stack) if cte.nesting else 1 - cte_level_name = (cte_level, cte_name) - if cte_level_name in self.ctes_by_level_name: + cte_level_name = (cte_level, cte_name) existing_cte = self.ctes_by_level_name[cte_level_name] + + # check if we are receiving it here with a specific + # "nest_here" location; if so, move it to this location + + if cte_opts.nesting: + if existing_cte_opts.nesting: + raise exc.CompileError( + "CTE is stated as 'nest_here' in " + "more than one location" + ) + + old_level_name = (cte_level, cte_name) + cte_level = len(self.stack) if nesting else 1 + cte_level_name = new_level_name = (cte_level, cte_name) + + del self.ctes_by_level_name[old_level_name] + self.ctes_by_level_name[new_level_name] = existing_cte + self.level_name_by_cte[_reference_cte] = new_level_name + ( + cte_opts, + ) + + else: + cte_level = len(self.stack) if nesting else 1 + cte_level_name = (cte_level, cte_name) + + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] + else: + existing_cte = None + + if existing_cte is not None: embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, @@ -2718,10 +2773,8 @@ class SQLCompiler(Compiled): existing_cte_reference_cte = existing_cte._get_reference_cte() - # TODO: determine if these assertions are correct. they - # pass for current test cases - # assert existing_cte_reference_cte is _reference_cte - # assert existing_cte_reference_cte is existing_cte + assert existing_cte_reference_cte is _reference_cte + assert existing_cte_reference_cte is existing_cte del self.level_name_by_cte[existing_cte_reference_cte] else: @@ -2746,19 +2799,9 @@ class SQLCompiler(Compiled): if is_new_cte: self.ctes_by_level_name[cte_level_name] = cte - self.level_name_by_cte[_reference_cte] = cte_level_name - - if ( - "autocommit" in cte.element._execution_options - and "autocommit" not in self.execution_options - ): - self.execution_options = self.execution_options.union( - { - "autocommit": cte.element._execution_options[ - "autocommit" - ] - } - ) + self.level_name_by_cte[_reference_cte] = cte_level_name + ( + cte_opts, + ) if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -3378,8 +3421,7 @@ class SQLCompiler(Compiled): byfrom = None if select_stmt._independent_ctes: - for cte in select_stmt._independent_ctes: - cte._compiler_dispatch(self, **kwargs) + self._dispatch_independent_ctes(select_stmt, kwargs) if select_stmt._prefixes: text += self._generate_prefixes( @@ -3485,7 +3527,9 @@ class SQLCompiler(Compiled): return text - def _setup_select_hints(self, select): + def _setup_select_hints( + self, select: Select + ) -> Tuple[str, _FromHintsType]: byfrom = dict( [ ( @@ -3663,13 +3707,14 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: ctes = util.OrderedDict() for cte in list(self.ctes.keys()): - cte_level, cte_name = self.level_name_by_cte[ + cte_level, cte_name, cte_opts = self.level_name_by_cte[ cte._get_reference_cte() ] + nesting = cte.nesting or cte_opts.nesting is_rendered_level = cte_level == nesting_level or ( include_following_stack and cte_level == nesting_level + 1 ) - if not (cte.nesting and is_rendered_level): + if not (nesting and is_rendered_level): continue ctes[cte] = self.ctes[cte] @@ -3693,7 +3738,7 @@ class SQLCompiler(Compiled): if nesting_level and nesting_level > 1: for cte in list(ctes.keys()): - cte_level, cte_name = self.level_name_by_cte[ + cte_level, cte_name, cte_opts = self.level_name_by_cte[ cte._get_reference_cte() ] del self.ctes[cte] @@ -3939,8 +3984,7 @@ class SQLCompiler(Compiled): _, table_text = self._setup_crud_hints(insert_stmt, table_text) if insert_stmt._independent_ctes: - for cte in insert_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(insert_stmt, kw) text += table_text @@ -4108,8 +4152,7 @@ class SQLCompiler(Compiled): dialect_hints = None if update_stmt._independent_ctes: - for cte in update_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(update_stmt, kw) text += table_text @@ -4221,8 +4264,7 @@ class SQLCompiler(Compiled): dialect_hints = None if delete_stmt._independent_ctes: - for cte in delete_stmt._independent_ctes: - cte._compiler_dispatch(self, **kw) + self._dispatch_independent_ctes(delete_stmt, kw) text += table_text diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 7f6360edb0..836c30af74 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -19,6 +19,7 @@ import itertools from operator import attrgetter import typing from typing import Any as TODO_Any +from typing import NamedTuple from typing import Optional from typing import Tuple @@ -1809,6 +1810,10 @@ class CTE( SelfHasCTE = typing.TypeVar("SelfHasCTE", bound="HasCTE") +class _CTEOpts(NamedTuple): + nesting: bool + + class HasCTE(roles.HasCTERole): """Mixin that declares a class to include CTE support. @@ -1818,20 +1823,36 @@ class HasCTE(roles.HasCTERole): _has_ctes_traverse_internals = [ ("_independent_ctes", InternalTraversal.dp_clauseelement_list), + ("_independent_ctes_opts", InternalTraversal.dp_plain_obj), ] _independent_ctes = () + _independent_ctes_opts = () @_generative - def add_cte(self: SelfHasCTE, cte) -> SelfHasCTE: - """Add a :class:`_sql.CTE` to this statement object that will be - independently rendered even if not referenced in the statement - otherwise. + def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE: + r"""Add one or more :class:`_sql.CTE` constructs to this statement. + + This method will associate the given :class:`_sql.CTE` constructs with + the parent statement such that they will each be unconditionally + rendered in the WITH clause of the final statement, even if not + referenced elsewhere within the statement or any sub-selects. + + The optional :paramref:`.HasCTE.add_cte.nest_here` parameter when set + to True will have the effect that each given :class:`_sql.CTE` will + render in a WITH clause rendered directly along with this statement, + rather than being moved to the top of the ultimate rendered statement, + even if this statement is rendered as a subquery within a larger + statement. - This feature is useful for the use case of embedding a DML statement - such as an INSERT or UPDATE as a CTE inline with a primary statement - that may draw from its results indirectly; while PostgreSQL is known - to support this usage, it may not be supported by other backends. + This method has two general uses. One is to embed CTE statements that + serve some purpose without being referenced explicitly, such as the use + case of embedding a DML statement such as an INSERT or UPDATE as a CTE + inline with a primary statement that may draw from its results + indirectly. The other is to provide control over the exact placement + of a particular series of CTE constructs that should remain rendered + directly in terms of a particular statement that may be nested in a + larger statement. E.g.:: @@ -1885,9 +1906,32 @@ class HasCTE(roles.HasCTERole): .. versionadded:: 1.4.21 + :param \*ctes: zero or more :class:`.CTE` constructs. + + .. versionchanged:: 2.0 Multiple CTE instances are accepted + + :param nest_here: if True, the given CTE or CTEs will be rendered + as though they specified the :paramref:`.HasCTE.cte.nesting` flag + to ``True`` when they were added to this :class:`.HasCTE`. + Assuming the given CTEs are not referenced in an outer-enclosing + statement as well, the CTEs given should render at the level of + this statement when this flag is given. + + .. versionadded:: 2.0 + + .. seealso:: + + :paramref:`.HasCTE.cte.nesting` + + """ - cte = coercions.expect(roles.IsCTERole, cte) - self._independent_ctes += (cte,) + opt = _CTEOpts( + nest_here, + ) + for cte in ctes: + cte = coercions.expect(roles.IsCTERole, cte) + self._independent_ctes += (cte,) + self._independent_ctes_opts += (opt,) return self def cte(self, name=None, recursive=False, nesting=False): @@ -1931,10 +1975,18 @@ class HasCTE(roles.HasCTERole): conjunction with UNION ALL in order to derive rows from those already selected. :param nesting: if ``True``, will render the CTE locally to the - actual statement. + statement in which it is referenced. For more complex scenarios, + the :meth:`.HasCTE.add_cte` method using the + :paramref:`.HasCTE.add_cte.nest_here` + parameter may also be used to more carefully + control the exact placement of a particular CTE. .. versionadded:: 1.4.24 + .. seealso:: + + :meth:`.HasCTE.add_cte` + The following examples include two from PostgreSQL's documentation at https://www.postgresql.org/docs/current/static/queries-with.html, as well as additional examples. @@ -2084,6 +2136,28 @@ class HasCTE(roles.HasCTERole): SELECT value_a.n AS a, value_b.n AS b FROM value_a, value_b + The same CTE can be set up using the :meth:`.HasCTE.add_cte` method + as follows (SQLAlchemy 2.0 and above):: + + value_a = select( + literal("root").label("n") + ).cte("value_a") + + # A nested CTE with the same name as the root one + value_a_nested = select( + literal("nesting").label("n") + ).cte("value_a") + + # Nesting CTEs takes ascendency locally + # over the CTEs at a higher level + value_b = ( + select(value_a_nested.c.n). + add_cte(value_a_nested, nest_here=True). + cte("value_b") + ) + + value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b")) + Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above):: edge = Table( diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index b056925048..2ee6fa9f31 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,11 +1,13 @@ from sqlalchemy import Column from sqlalchemy import delete +from sqlalchemy import exc from sqlalchemy import Integer from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL from sqlalchemy import MetaData from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import update from sqlalchemy.dialects import mssql from sqlalchemy.engine import default @@ -25,6 +27,7 @@ from sqlalchemy.sql.visitors import cloned_traverse from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures @@ -1869,6 +1872,21 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.outer_cte FROM cte", ) + def test_select_with_nesting_cte_in_cte_w_add_cte(self): + nesting_cte = select(literal(1).label("inner_cte")).cte("nesting") + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")) + .add_cte(nesting_cte, nest_here=True) + .cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting) " + "SELECT cte.outer_cte FROM cte", + ) + def test_select_with_aliased_nesting_cte_in_cte(self): nesting_cte = ( select(literal(1).label("inner_cte")) @@ -1887,6 +1905,25 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.outer_cte FROM cte", ) + def test_select_with_aliased_nesting_cte_in_cte_w_add_cte(self): + inner_nesting_cte = select(literal(1).label("inner_cte")).cte( + "nesting" + ) + outer_cte = select().add_cte(inner_nesting_cte, nest_here=True) + nesting_cte = inner_nesting_cte.alias("aliased_nested") + outer_cte = outer_cte.add_columns( + nesting_cte.c.inner_cte.label("outer_cte") + ).cte("cte") + stmt = select(outer_cte) + + self.assert_compile( + stmt, + "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) " + "SELECT aliased_nested.inner_cte AS outer_cte " + "FROM nesting AS aliased_nested) " + "SELECT cte.outer_cte FROM cte", + ) + def test_nesting_cte_in_cte_with_same_name(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "some_cte", nesting=True @@ -1904,6 +1941,23 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT some_cte.outer_cte FROM some_cte", ) + def test_nesting_cte_in_cte_with_same_name_w_add_cte(self): + nesting_cte = select(literal(1).label("inner_cte")).cte("some_cte") + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")) + .add_cte(nesting_cte, nest_here=True) + .cte("some_cte") + ) + + self.assert_compile( + stmt, + "WITH some_cte AS (WITH some_cte AS " + "(SELECT :param_1 AS inner_cte) " + "SELECT some_cte.inner_cte AS outer_cte " + "FROM some_cte) " + "SELECT some_cte.outer_cte FROM some_cte", + ) + def test_nesting_cte_at_top_level(self): nesting_cte = select(literal(1).label("val")).cte( "nesting_cte", nesting=True @@ -1918,6 +1972,20 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): " SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte", ) + def test_nesting_cte_at_top_level_w_add_cte(self): + nesting_cte = select(literal(1).label("val")).cte("nesting_cte") + cte = select(literal(2).label("val")).cte("cte") + stmt = select(nesting_cte.c.val, cte.c.val).add_cte( + nesting_cte, nest_here=True + ) + + self.assert_compile( + stmt, + "WITH nesting_cte AS (SELECT :param_1 AS val)" + ", cte AS (SELECT :param_2 AS val)" + " SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte", + ) + def test_double_nesting_cte_in_cte(self): """ Validate that the SELECT in the 2nd nesting CTE does not render @@ -1950,6 +2018,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.outer_1, cte.outer_2 FROM cte", ) + def test_double_nesting_cte_in_cte_w_add_cte(self): + """ + Validate that the SELECT in the 2nd nesting CTE does not render + the 1st CTE. + + It implies that nesting CTE level is taken in account. + """ + select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1") + select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2") + + stmt = select( + select( + select_1_cte.c.inner_cte.label("outer_1"), + select_2_cte.c.inner_cte.label("outer_2"), + ) + .add_cte(select_1_cte, select_2_cte, nest_here=True) + .cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (" + "WITH nesting_1 AS (SELECT :param_1 AS inner_cte)" + ", nesting_2 AS (SELECT :param_2 AS inner_cte)" + " SELECT nesting_1.inner_cte AS outer_1" + ", nesting_2.inner_cte AS outer_2" + " FROM nesting_1, nesting_2" + ") SELECT cte.outer_1, cte.outer_2 FROM cte", + ) + def test_double_nesting_cte_with_cross_reference_in_cte(self): select_1_cte = select(literal(1).label("inner_cte_1")).cte( "nesting_1", nesting=True @@ -1993,6 +2091,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.inner_cte_2, cte.inner_cte_1 FROM cte", ) + def test_double_nesting_cte_with_cross_reference_in_cte_w_add_cte(self): + select_1_cte = select(literal(1).label("inner_cte_1")).cte("nesting_1") + select_2_cte = select( + (select_1_cte.c.inner_cte_1 + 1).label("inner_cte_2") + ).cte("nesting_2") + + # 1 next 2 + + nesting_cte_1_2 = ( + select(select_1_cte, select_2_cte) + .add_cte(select_1_cte, select_2_cte, nest_here=True) + .cte("cte") + ) + stmt_1_2 = select(nesting_cte_1_2) + self.assert_compile( + stmt_1_2, + "WITH cte AS (" + "WITH nesting_1 AS (SELECT :param_1 AS inner_cte_1)" + ", nesting_2 AS (SELECT nesting_1.inner_cte_1 + :inner_cte_1_1" + " AS inner_cte_2 FROM nesting_1)" + " SELECT nesting_1.inner_cte_1 AS inner_cte_1" + ", nesting_2.inner_cte_2 AS inner_cte_2" + " FROM nesting_1, nesting_2" + ") SELECT cte.inner_cte_1, cte.inner_cte_2 FROM cte", + ) + def test_nesting_cte_in_nesting_cte_in_cte(self): select_1_cte = select(literal(1).label("inner_cte")).cte( "nesting_1", nesting=True @@ -2069,6 +2193,31 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT rec_cte.outer_cte FROM rec_cte", ) + def test_nesting_cte_in_recursive_cte_w_add_cte(self): + nesting_cte = select(literal(1).label("inner_cte")).cte( + "nesting", nesting=True + ) + + rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte( + "rec_cte", recursive=True + ) + rec_part = select(rec_cte.c.outer_cte).where( + rec_cte.c.outer_cte == literal(1) + ) + rec_cte = rec_cte.union(rec_part) + + stmt = select(rec_cte) + + self.assert_compile( + stmt, + "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " + "(SELECT :param_1 AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = :param_2) " + "SELECT rec_cte.outer_cte FROM rec_cte", + ) + def test_recursive_nesting_cte_in_cte(self): rec_root = select(literal(1).label("inner_cte")).cte( "nesting", recursive=True, nesting=True @@ -2209,6 +2358,80 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "FROM nesting_cte", ) + def test_add_cte_dont_nest_in_two_places(self): + nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte( + "nesting_cte" + ) + select_add_cte = select( + (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value") + ).cte("nesting_2") + + union_cte = ( + select( + (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value") + ) + .add_cte(nesting_cte_used_twice, nest_here=True) + .union( + select(select_add_cte).add_cte(select_add_cte, nest_here=True) + ) + .cte("wrapper") + ) + + stmt = ( + select(union_cte) + .add_cte(nesting_cte_used_twice, nest_here=True) + .union(select(nesting_cte_used_twice)) + ) + with expect_raises_message( + exc.CompileError, + "CTE is stated as 'nest_here' in more than one location", + ): + stmt.compile() + + def test_same_nested_cte_is_not_generated_twice_w_add_cte(self): + # Same = name and query + nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte( + "nesting_cte" + ) + select_add_cte = select( + (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value") + ).cte("nesting_2") + + union_cte = ( + select( + (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value") + ) + .add_cte(nesting_cte_used_twice) + .union( + select(select_add_cte).add_cte(select_add_cte, nest_here=True) + ) + .cte("wrapper") + ) + + stmt = ( + select(union_cte) + .add_cte(nesting_cte_used_twice, nest_here=True) + .union(select(nesting_cte_used_twice)) + ) + + self.assert_compile( + stmt, + "WITH nesting_cte AS " + "(SELECT :param_1 AS inner_cte_1)" + ", wrapper AS " + "(WITH nesting_2 AS " + "(SELECT nesting_cte.inner_cte_1 + :inner_cte_1_2 " + "AS next_value " + "FROM nesting_cte)" + " SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 " + "AS next_value " + "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value " + "FROM nesting_2)" + " SELECT wrapper.next_value " + "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " + "FROM nesting_cte", + ) + def test_recursive_nesting_cte_in_recursive_cte(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True, recursive=True @@ -2363,6 +2586,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.outer_cte FROM cte", ) + def test_compound_select_with_nesting_cte_in_custom_order_w_add_cte(self): + select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1") + select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2") + + nesting_cte = ( + select(select_1_cte) + .add_cte(select_1_cte, nest_here=True) + .union(select(select_2_cte)) + # Generate "select_2_cte" first + .add_cte(select_2_cte, nest_here=True) + .subquery() + ) + + stmt = select( + select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte") + ) + + self.assert_compile( + stmt, + "WITH cte AS (" + "SELECT anon_1.inner_cte AS outer_cte FROM (" + "WITH nesting_2 AS (SELECT :param_1 AS inner_cte)" + ", nesting_1 AS (SELECT :param_2 AS inner_cte)" + " SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1" + " UNION" + " SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2" + ") AS anon_1" + ") SELECT cte.outer_cte FROM cte", + ) + def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self): rec_root = select(literal(1).label("the_value")).cte( "recursive_cte", recursive=True @@ -2411,3 +2664,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): " WHERE should_continue.val != true))" " SELECT recursive_cte.the_value FROM recursive_cte", ) + + @testing.combinations(True, False) + def test_correlated_cte_in_lateral_w_add_cte(self, reverse_direction): + """this is the original use case that led to #7759""" + contracts = table("contracts", column("id")) + + invoices = table("invoices", column("id"), column("contract_id")) + + contracts_alias = contracts.alias() + cte1 = ( + select(contracts_alias) + .where(contracts_alias.c.id == contracts.c.id) + .correlate(contracts) + .cte(name="cte1") + ) + cte2 = ( + select(invoices) + .join(cte1, invoices.c.contract_id == cte1.c.id) + .cte(name="cte2") + ) + + if reverse_direction: + subq = select(cte1, cte2).add_cte(cte2, cte1, nest_here=True) + else: + subq = select(cte1, cte2).add_cte(cte1, cte2, nest_here=True) + stmt = select(contracts).outerjoin(subq.lateral(), true()) + + self.assert_compile( + stmt, + "SELECT contracts.id FROM contracts LEFT OUTER JOIN LATERAL " + "(WITH cte1 AS (SELECT contracts_1.id AS id " + "FROM contracts AS contracts_1 " + "WHERE contracts_1.id = contracts.id), " + "cte2 AS (SELECT invoices.id AS id, " + "invoices.contract_id AS contract_id FROM invoices " + "JOIN cte1 ON invoices.contract_id = cte1.id) " + "SELECT cte1.id AS id, cte2.id AS id_1, " + "cte2.contract_id AS contract_id " + "FROM cte1, cte2) AS anon_1 ON true", + ) -- 2.47.2