]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add more nesting features to add_cte()
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 Feb 2022 17:24:31 +0000 (12:24 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 24 Feb 2022 23:43:50 +0000 (18:43 -0500)
Added new parameter :paramref:`.HasCTE.add_cte.nest_here` to
:meth:`.HasCTE.add_cte` which will "nest" a given :class:`.CTE` at the
level of the parent statement. This parameter is equivalent to using the
:paramref:`.HasCTE.cte.nesting` parameter, but may be more intuitive in
some scenarios as it allows the nesting attribute to be set simultaneously
along with the explicit level of the CTE.

The :meth:`.HasCTE.add_cte` method also accepts multiple CTE objects.

Fixes: #7759
Change-Id: I263c015f5a3f452cb54819aee12bc9bf2953a7bb

doc/build/changelog/unreleased_20/7759.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_20/7759.rst b/doc/build/changelog/unreleased_20/7759.rst
new file mode 100644 (file)
index 0000000..b7f3bff
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: usecase, sql
+    :tickets: 7759
+
+    Added new parameter :paramref:`.HasCTE.add_cte.nest_here` to
+    :meth:`.HasCTE.add_cte` which will "nest" a given :class:`.CTE` at the
+    level of the parent statement. This parameter is equivalent to using the
+    :paramref:`.HasCTE.cte.nesting` parameter, but may be more intuitive in
+    some scenarios as it allows the nesting attribute to be set simultaneously
+    along with the explicit level of the CTE.
+
+    The :meth:`.HasCTE.add_cte` method also accepts multiple CTE objects.
\ No newline at end of file
index b140f9297576cf390495d164d65590fd18d76ce6..77bc1ea38d10a7f030ccf1125353bf690b202ed8 100644 (file)
@@ -31,6 +31,13 @@ import itertools
 import operator
 import re
 from time import perf_counter
+import typing
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Tuple
 
 from . import base
 from . import coercions
@@ -47,6 +54,12 @@ from .elements import quoted_name
 from .. import exc
 from .. import util
 
+if typing.TYPE_CHECKING:
+    from .selectable import CTE
+    from .selectable import FromClause
+
+_FromHintsType = Dict["FromClause", str]
+
 RESERVED_WORDS = set(
     [
         "all",
@@ -842,7 +855,7 @@ class SQLCompiler(Compiled):
         return {}
 
     @util.memoized_instancemethod
-    def _init_cte_state(self):
+    def _init_cte_state(self) -> None:
         """Initialize collections related to CTEs only if
         a CTE is located, to save on the overhead of
         these collections otherwise.
@@ -850,19 +863,21 @@ class SQLCompiler(Compiled):
         """
         # collect CTEs to tack on top of a SELECT
         # To store the query to print - Dict[cte, text_query]
-        self.ctes = util.OrderedDict()
+        self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
 
         # Detect same CTE references - Dict[(level, name), cte]
         # Level is required for supporting nesting
-        self.ctes_by_level_name = {}
+        self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
 
         # To retrieve key/level in ctes_by_level_name -
-        # Dict[cte_reference, (level, cte_name)]
-        self.level_name_by_cte = {}
+        # Dict[cte_reference, (level, cte_name, cte_opts)]
+        self.level_name_by_cte: Dict[
+            CTE, Tuple[int, str, selectable._CTEOpts]
+        ] = {}
 
-        self.ctes_recursive = False
+        self.ctes_recursive: bool = False
         if self.positional:
-            self.cte_positional = {}
+            self.cte_positional: Dict[CTE, List[str]] = {}
 
     @contextlib.contextmanager
     def _nested_result(self):
@@ -1604,8 +1619,7 @@ class SQLCompiler(Compiled):
         self.stack.append(new_entry)
 
         if taf._independent_ctes:
-            for cte in taf._independent_ctes:
-                cte._compiler_dispatch(self, **kw)
+            self._dispatch_independent_ctes(taf, kw)
 
         populate_result_map = (
             toplevel
@@ -1879,8 +1893,7 @@ class SQLCompiler(Compiled):
         )
 
         if compound_stmt._independent_ctes:
-            for cte in compound_stmt._independent_ctes:
-                cte._compiler_dispatch(self, **kwargs)
+            self._dispatch_independent_ctes(compound_stmt, kwargs)
 
         keyword = self.compound_keywords.get(cs.keyword)
 
@@ -2671,16 +2684,25 @@ class SQLCompiler(Compiled):
 
         return ret
 
+    def _dispatch_independent_ctes(self, stmt, kw):
+        local_kw = kw.copy()
+        local_kw.pop("cte_opts", None)
+        for cte, opt in zip(
+            stmt._independent_ctes, stmt._independent_ctes_opts
+        ):
+            cte._compiler_dispatch(self, cte_opts=opt, **local_kw)
+
     def visit_cte(
         self,
-        cte,
-        asfrom=False,
-        ashint=False,
-        fromhints=None,
-        visiting_cte=None,
-        from_linter=None,
-        **kwargs,
-    ):
+        cte: CTE,
+        asfrom: bool = False,
+        ashint: bool = False,
+        fromhints: Optional[_FromHintsType] = None,
+        visiting_cte: Optional[CTE] = None,
+        from_linter: Optional[FromLinter] = None,
+        cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
+        **kwargs: Any,
+    ) -> Optional[str]:
         self._init_cte_state()
 
         kwargs["visiting_cte"] = cte
@@ -2695,15 +2717,48 @@ class SQLCompiler(Compiled):
 
         _reference_cte = cte._get_reference_cte()
 
+        nesting = cte.nesting or cte_opts.nesting
+
+        # check for CTE already encountered
         if _reference_cte in self.level_name_by_cte:
-            cte_level, _ = self.level_name_by_cte[_reference_cte]
+            cte_level, _, existing_cte_opts = self.level_name_by_cte[
+                _reference_cte
+            ]
             assert _ == cte_name
-        else:
-            cte_level = len(self.stack) if cte.nesting else 1
 
-        cte_level_name = (cte_level, cte_name)
-        if cte_level_name in self.ctes_by_level_name:
+            cte_level_name = (cte_level, cte_name)
             existing_cte = self.ctes_by_level_name[cte_level_name]
+
+            # check if we are receiving it here with a specific
+            # "nest_here" location; if so, move it to this location
+
+            if cte_opts.nesting:
+                if existing_cte_opts.nesting:
+                    raise exc.CompileError(
+                        "CTE is stated as 'nest_here' in "
+                        "more than one location"
+                    )
+
+                old_level_name = (cte_level, cte_name)
+                cte_level = len(self.stack) if nesting else 1
+                cte_level_name = new_level_name = (cte_level, cte_name)
+
+                del self.ctes_by_level_name[old_level_name]
+                self.ctes_by_level_name[new_level_name] = existing_cte
+                self.level_name_by_cte[_reference_cte] = new_level_name + (
+                    cte_opts,
+                )
+
+        else:
+            cte_level = len(self.stack) if nesting else 1
+            cte_level_name = (cte_level, cte_name)
+
+            if cte_level_name in self.ctes_by_level_name:
+                existing_cte = self.ctes_by_level_name[cte_level_name]
+            else:
+                existing_cte = None
+
+        if existing_cte is not None:
             embedded_in_current_named_cte = visiting_cte is existing_cte
 
             # we've generated a same-named CTE that we are enclosed in,
@@ -2718,10 +2773,8 @@ class SQLCompiler(Compiled):
 
                 existing_cte_reference_cte = existing_cte._get_reference_cte()
 
-                # TODO: determine if these assertions are correct.  they
-                # pass for current test cases
-                # assert existing_cte_reference_cte is _reference_cte
-                # assert existing_cte_reference_cte is existing_cte
+                assert existing_cte_reference_cte is _reference_cte
+                assert existing_cte_reference_cte is existing_cte
 
                 del self.level_name_by_cte[existing_cte_reference_cte]
             else:
@@ -2746,19 +2799,9 @@ class SQLCompiler(Compiled):
 
         if is_new_cte:
             self.ctes_by_level_name[cte_level_name] = cte
-            self.level_name_by_cte[_reference_cte] = cte_level_name
-
-            if (
-                "autocommit" in cte.element._execution_options
-                and "autocommit" not in self.execution_options
-            ):
-                self.execution_options = self.execution_options.union(
-                    {
-                        "autocommit": cte.element._execution_options[
-                            "autocommit"
-                        ]
-                    }
-                )
+            self.level_name_by_cte[_reference_cte] = cte_level_name + (
+                cte_opts,
+            )
 
             if pre_alias_cte not in self.ctes:
                 self.visit_cte(pre_alias_cte, **kwargs)
@@ -3378,8 +3421,7 @@ class SQLCompiler(Compiled):
             byfrom = None
 
         if select_stmt._independent_ctes:
-            for cte in select_stmt._independent_ctes:
-                cte._compiler_dispatch(self, **kwargs)
+            self._dispatch_independent_ctes(select_stmt, kwargs)
 
         if select_stmt._prefixes:
             text += self._generate_prefixes(
@@ -3485,7 +3527,9 @@ class SQLCompiler(Compiled):
 
         return text
 
-    def _setup_select_hints(self, select):
+    def _setup_select_hints(
+        self, select: Select
+    ) -> Tuple[str, _FromHintsType]:
         byfrom = dict(
             [
                 (
@@ -3663,13 +3707,14 @@ class SQLCompiler(Compiled):
         if nesting_level and nesting_level > 1:
             ctes = util.OrderedDict()
             for cte in list(self.ctes.keys()):
-                cte_level, cte_name = self.level_name_by_cte[
+                cte_level, cte_name, cte_opts = self.level_name_by_cte[
                     cte._get_reference_cte()
                 ]
+                nesting = cte.nesting or cte_opts.nesting
                 is_rendered_level = cte_level == nesting_level or (
                     include_following_stack and cte_level == nesting_level + 1
                 )
-                if not (cte.nesting and is_rendered_level):
+                if not (nesting and is_rendered_level):
                     continue
 
                 ctes[cte] = self.ctes[cte]
@@ -3693,7 +3738,7 @@ class SQLCompiler(Compiled):
 
         if nesting_level and nesting_level > 1:
             for cte in list(ctes.keys()):
-                cte_level, cte_name = self.level_name_by_cte[
+                cte_level, cte_name, cte_opts = self.level_name_by_cte[
                     cte._get_reference_cte()
                 ]
                 del self.ctes[cte]
@@ -3939,8 +3984,7 @@ class SQLCompiler(Compiled):
             _, table_text = self._setup_crud_hints(insert_stmt, table_text)
 
         if insert_stmt._independent_ctes:
-            for cte in insert_stmt._independent_ctes:
-                cte._compiler_dispatch(self, **kw)
+            self._dispatch_independent_ctes(insert_stmt, kw)
 
         text += table_text
 
@@ -4108,8 +4152,7 @@ class SQLCompiler(Compiled):
             dialect_hints = None
 
         if update_stmt._independent_ctes:
-            for cte in update_stmt._independent_ctes:
-                cte._compiler_dispatch(self, **kw)
+            self._dispatch_independent_ctes(update_stmt, kw)
 
         text += table_text
 
@@ -4221,8 +4264,7 @@ class SQLCompiler(Compiled):
             dialect_hints = None
 
         if delete_stmt._independent_ctes:
-            for cte in delete_stmt._independent_ctes:
-                cte._compiler_dispatch(self, **kw)
+            self._dispatch_independent_ctes(delete_stmt, kw)
 
         text += table_text
 
index 7f6360edb07b76460f5b14cc244a6aac8b18a2bf..836c30af7420558ff3b15d7aefb1371bd24fecd9 100644 (file)
@@ -19,6 +19,7 @@ import itertools
 from operator import attrgetter
 import typing
 from typing import Any as TODO_Any
+from typing import NamedTuple
 from typing import Optional
 from typing import Tuple
 
@@ -1809,6 +1810,10 @@ class CTE(
 SelfHasCTE = typing.TypeVar("SelfHasCTE", bound="HasCTE")
 
 
+class _CTEOpts(NamedTuple):
+    nesting: bool
+
+
 class HasCTE(roles.HasCTERole):
     """Mixin that declares a class to include CTE support.
 
@@ -1818,20 +1823,36 @@ class HasCTE(roles.HasCTERole):
 
     _has_ctes_traverse_internals = [
         ("_independent_ctes", InternalTraversal.dp_clauseelement_list),
+        ("_independent_ctes_opts", InternalTraversal.dp_plain_obj),
     ]
 
     _independent_ctes = ()
+    _independent_ctes_opts = ()
 
     @_generative
-    def add_cte(self: SelfHasCTE, cte) -> SelfHasCTE:
-        """Add a :class:`_sql.CTE` to this statement object that will be
-        independently rendered even if not referenced in the statement
-        otherwise.
+    def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE:
+        r"""Add one or more :class:`_sql.CTE` constructs to this statement.
+
+        This method will associate the given :class:`_sql.CTE` constructs with
+        the parent statement such that they will each be unconditionally
+        rendered in the WITH clause of the final statement, even if not
+        referenced elsewhere within the statement or any sub-selects.
+
+        The optional :paramref:`.HasCTE.add_cte.nest_here` parameter when set
+        to True will have the effect that each given :class:`_sql.CTE` will
+        render in a WITH clause rendered directly along with this statement,
+        rather than being moved to the top of the ultimate rendered statement,
+        even if this statement is rendered as a subquery within a larger
+        statement.
 
-        This feature is useful for the use case of embedding a DML statement
-        such as an INSERT or UPDATE as a CTE inline with a primary statement
-        that may draw from its results indirectly; while PostgreSQL is known
-        to support this usage, it may not be supported by other backends.
+        This method has two general uses. One is to embed CTE statements that
+        serve some purpose without being referenced explicitly, such as the use
+        case of embedding a DML statement such as an INSERT or UPDATE as a CTE
+        inline with a primary statement that may draw from its results
+        indirectly.  The other is to provide control over the exact placement
+        of a particular series of CTE constructs that should remain rendered
+        directly in terms of a particular statement that may be nested in a
+        larger statement.
 
         E.g.::
 
@@ -1885,9 +1906,32 @@ class HasCTE(roles.HasCTERole):
 
         .. versionadded:: 1.4.21
 
+        :param \*ctes: zero or more :class:`.CTE` constructs.
+
+         .. versionchanged:: 2.0  Multiple CTE instances are accepted
+
+        :param nest_here: if True, the given CTE or CTEs will be rendered
+         as though they specified the :paramref:`.HasCTE.cte.nesting` flag
+         to ``True`` when they were added to this :class:`.HasCTE`.
+         Assuming the given CTEs are not referenced in an outer-enclosing
+         statement as well, the CTEs given should render at the level of
+         this statement when this flag is given.
+
+         .. versionadded:: 2.0
+
+         .. seealso::
+
+            :paramref:`.HasCTE.cte.nesting`
+
+
         """
-        cte = coercions.expect(roles.IsCTERole, cte)
-        self._independent_ctes += (cte,)
+        opt = _CTEOpts(
+            nest_here,
+        )
+        for cte in ctes:
+            cte = coercions.expect(roles.IsCTERole, cte)
+            self._independent_ctes += (cte,)
+            self._independent_ctes_opts += (opt,)
         return self
 
     def cte(self, name=None, recursive=False, nesting=False):
@@ -1931,10 +1975,18 @@ class HasCTE(roles.HasCTERole):
          conjunction with UNION ALL in order to derive rows
          from those already selected.
         :param nesting: if ``True``, will render the CTE locally to the
-         actual statement.
+         statement in which it is referenced.   For more complex scenarios,
+         the :meth:`.HasCTE.add_cte` method using the
+         :paramref:`.HasCTE.add_cte.nest_here`
+         parameter may also be used to more carefully
+         control the exact placement of a particular CTE.
 
          .. versionadded:: 1.4.24
 
+         .. seealso::
+
+            :meth:`.HasCTE.add_cte`
+
         The following examples include two from PostgreSQL's documentation at
         https://www.postgresql.org/docs/current/static/queries-with.html,
         as well as additional examples.
@@ -2084,6 +2136,28 @@ class HasCTE(roles.HasCTERole):
             SELECT value_a.n AS a, value_b.n AS b
             FROM value_a, value_b
 
+        The same CTE can be set up using the :meth:`.HasCTE.add_cte` method
+        as follows (SQLAlchemy 2.0 and above)::
+
+            value_a = select(
+                literal("root").label("n")
+            ).cte("value_a")
+
+            # A nested CTE with the same name as the root one
+            value_a_nested = select(
+                literal("nesting").label("n")
+            ).cte("value_a")
+
+            # Nesting CTEs takes ascendency locally
+            # over the CTEs at a higher level
+            value_b = (
+                select(value_a_nested.c.n).
+                add_cte(value_a_nested, nest_here=True).
+                cte("value_b")
+            )
+
+            value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b"))
+
         Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above)::
 
             edge = Table(
index b0569250485f91665d8b64941857d8e28901ebe7..2ee6fa9f312c6aabc49fcbb15a6dd7f75567e64a 100644 (file)
@@ -1,11 +1,13 @@
 from sqlalchemy import Column
 from sqlalchemy import delete
+from sqlalchemy import exc
 from sqlalchemy import Integer
 from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy import MetaData
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy import update
 from sqlalchemy.dialects import mssql
 from sqlalchemy.engine import default
@@ -25,6 +27,7 @@ from sqlalchemy.sql.visitors import cloned_traverse
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 
 
@@ -1869,6 +1872,21 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT cte.outer_cte FROM cte",
         )
 
+    def test_select_with_nesting_cte_in_cte_w_add_cte(self):
+        nesting_cte = select(literal(1).label("inner_cte")).cte("nesting")
+        stmt = select(
+            select(nesting_cte.c.inner_cte.label("outer_cte"))
+            .add_cte(nesting_cte, nest_here=True)
+            .cte("cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting) "
+            "SELECT cte.outer_cte FROM cte",
+        )
+
     def test_select_with_aliased_nesting_cte_in_cte(self):
         nesting_cte = (
             select(literal(1).label("inner_cte"))
@@ -1887,6 +1905,25 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT cte.outer_cte FROM cte",
         )
 
+    def test_select_with_aliased_nesting_cte_in_cte_w_add_cte(self):
+        inner_nesting_cte = select(literal(1).label("inner_cte")).cte(
+            "nesting"
+        )
+        outer_cte = select().add_cte(inner_nesting_cte, nest_here=True)
+        nesting_cte = inner_nesting_cte.alias("aliased_nested")
+        outer_cte = outer_cte.add_columns(
+            nesting_cte.c.inner_cte.label("outer_cte")
+        ).cte("cte")
+        stmt = select(outer_cte)
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
+            "SELECT aliased_nested.inner_cte AS outer_cte "
+            "FROM nesting AS aliased_nested) "
+            "SELECT cte.outer_cte FROM cte",
+        )
+
     def test_nesting_cte_in_cte_with_same_name(self):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "some_cte", nesting=True
@@ -1904,6 +1941,23 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT some_cte.outer_cte FROM some_cte",
         )
 
+    def test_nesting_cte_in_cte_with_same_name_w_add_cte(self):
+        nesting_cte = select(literal(1).label("inner_cte")).cte("some_cte")
+        stmt = select(
+            select(nesting_cte.c.inner_cte.label("outer_cte"))
+            .add_cte(nesting_cte, nest_here=True)
+            .cte("some_cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH some_cte AS (WITH some_cte AS "
+            "(SELECT :param_1 AS inner_cte) "
+            "SELECT some_cte.inner_cte AS outer_cte "
+            "FROM some_cte) "
+            "SELECT some_cte.outer_cte FROM some_cte",
+        )
+
     def test_nesting_cte_at_top_level(self):
         nesting_cte = select(literal(1).label("val")).cte(
             "nesting_cte", nesting=True
@@ -1918,6 +1972,20 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             " SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte",
         )
 
+    def test_nesting_cte_at_top_level_w_add_cte(self):
+        nesting_cte = select(literal(1).label("val")).cte("nesting_cte")
+        cte = select(literal(2).label("val")).cte("cte")
+        stmt = select(nesting_cte.c.val, cte.c.val).add_cte(
+            nesting_cte, nest_here=True
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH nesting_cte AS (SELECT :param_1 AS val)"
+            ", cte AS (SELECT :param_2 AS val)"
+            " SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte",
+        )
+
     def test_double_nesting_cte_in_cte(self):
         """
         Validate that the SELECT in the 2nd nesting CTE does not render
@@ -1950,6 +2018,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             ") SELECT cte.outer_1, cte.outer_2 FROM cte",
         )
 
+    def test_double_nesting_cte_in_cte_w_add_cte(self):
+        """
+        Validate that the SELECT in the 2nd nesting CTE does not render
+        the 1st CTE.
+
+        It implies that nesting CTE level is taken in account.
+        """
+        select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1")
+        select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2")
+
+        stmt = select(
+            select(
+                select_1_cte.c.inner_cte.label("outer_1"),
+                select_2_cte.c.inner_cte.label("outer_2"),
+            )
+            .add_cte(select_1_cte, select_2_cte, nest_here=True)
+            .cte("cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS ("
+            "WITH nesting_1 AS (SELECT :param_1 AS inner_cte)"
+            ", nesting_2 AS (SELECT :param_2 AS inner_cte)"
+            " SELECT nesting_1.inner_cte AS outer_1"
+            ", nesting_2.inner_cte AS outer_2"
+            " FROM nesting_1, nesting_2"
+            ") SELECT cte.outer_1, cte.outer_2 FROM cte",
+        )
+
     def test_double_nesting_cte_with_cross_reference_in_cte(self):
         select_1_cte = select(literal(1).label("inner_cte_1")).cte(
             "nesting_1", nesting=True
@@ -1993,6 +2091,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             ") SELECT cte.inner_cte_2, cte.inner_cte_1 FROM cte",
         )
 
+    def test_double_nesting_cte_with_cross_reference_in_cte_w_add_cte(self):
+        select_1_cte = select(literal(1).label("inner_cte_1")).cte("nesting_1")
+        select_2_cte = select(
+            (select_1_cte.c.inner_cte_1 + 1).label("inner_cte_2")
+        ).cte("nesting_2")
+
+        # 1 next 2
+
+        nesting_cte_1_2 = (
+            select(select_1_cte, select_2_cte)
+            .add_cte(select_1_cte, select_2_cte, nest_here=True)
+            .cte("cte")
+        )
+        stmt_1_2 = select(nesting_cte_1_2)
+        self.assert_compile(
+            stmt_1_2,
+            "WITH cte AS ("
+            "WITH nesting_1 AS (SELECT :param_1 AS inner_cte_1)"
+            ", nesting_2 AS (SELECT nesting_1.inner_cte_1 + :inner_cte_1_1"
+            " AS inner_cte_2 FROM nesting_1)"
+            " SELECT nesting_1.inner_cte_1 AS inner_cte_1"
+            ", nesting_2.inner_cte_2 AS inner_cte_2"
+            " FROM nesting_1, nesting_2"
+            ") SELECT cte.inner_cte_1, cte.inner_cte_2 FROM cte",
+        )
+
     def test_nesting_cte_in_nesting_cte_in_cte(self):
         select_1_cte = select(literal(1).label("inner_cte")).cte(
             "nesting_1", nesting=True
@@ -2069,6 +2193,31 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT rec_cte.outer_cte FROM rec_cte",
         )
 
+    def test_nesting_cte_in_recursive_cte_w_add_cte(self):
+        nesting_cte = select(literal(1).label("inner_cte")).cte(
+            "nesting", nesting=True
+        )
+
+        rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+            "rec_cte", recursive=True
+        )
+        rec_part = select(rec_cte.c.outer_cte).where(
+            rec_cte.c.outer_cte == literal(1)
+        )
+        rec_cte = rec_cte.union(rec_part)
+
+        stmt = select(rec_cte)
+
+        self.assert_compile(
+            stmt,
+            "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
+            "(SELECT :param_1 AS inner_cte) "
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = :param_2) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
+        )
+
     def test_recursive_nesting_cte_in_cte(self):
         rec_root = select(literal(1).label("inner_cte")).cte(
             "nesting", recursive=True, nesting=True
@@ -2209,6 +2358,80 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM nesting_cte",
         )
 
+    def test_add_cte_dont_nest_in_two_places(self):
+        nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
+            "nesting_cte"
+        )
+        select_add_cte = select(
+            (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
+        ).cte("nesting_2")
+
+        union_cte = (
+            select(
+                (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
+            )
+            .add_cte(nesting_cte_used_twice, nest_here=True)
+            .union(
+                select(select_add_cte).add_cte(select_add_cte, nest_here=True)
+            )
+            .cte("wrapper")
+        )
+
+        stmt = (
+            select(union_cte)
+            .add_cte(nesting_cte_used_twice, nest_here=True)
+            .union(select(nesting_cte_used_twice))
+        )
+        with expect_raises_message(
+            exc.CompileError,
+            "CTE is stated as 'nest_here' in more than one location",
+        ):
+            stmt.compile()
+
+    def test_same_nested_cte_is_not_generated_twice_w_add_cte(self):
+        # Same = name and query
+        nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
+            "nesting_cte"
+        )
+        select_add_cte = select(
+            (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
+        ).cte("nesting_2")
+
+        union_cte = (
+            select(
+                (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
+            )
+            .add_cte(nesting_cte_used_twice)
+            .union(
+                select(select_add_cte).add_cte(select_add_cte, nest_here=True)
+            )
+            .cte("wrapper")
+        )
+
+        stmt = (
+            select(union_cte)
+            .add_cte(nesting_cte_used_twice, nest_here=True)
+            .union(select(nesting_cte_used_twice))
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH nesting_cte AS "
+            "(SELECT :param_1 AS inner_cte_1)"
+            ", wrapper AS "
+            "(WITH nesting_2 AS "
+            "(SELECT nesting_cte.inner_cte_1 + :inner_cte_1_2 "
+            "AS next_value "
+            "FROM nesting_cte)"
+            " SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 "
+            "AS next_value "
+            "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value "
+            "FROM nesting_2)"
+            " SELECT wrapper.next_value "
+            "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 "
+            "FROM nesting_cte",
+        )
+
     def test_recursive_nesting_cte_in_recursive_cte(self):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "nesting", nesting=True, recursive=True
@@ -2363,6 +2586,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             ") SELECT cte.outer_cte FROM cte",
         )
 
+    def test_compound_select_with_nesting_cte_in_custom_order_w_add_cte(self):
+        select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1")
+        select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2")
+
+        nesting_cte = (
+            select(select_1_cte)
+            .add_cte(select_1_cte, nest_here=True)
+            .union(select(select_2_cte))
+            # Generate "select_2_cte" first
+            .add_cte(select_2_cte, nest_here=True)
+            .subquery()
+        )
+
+        stmt = select(
+            select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte AS ("
+            "SELECT anon_1.inner_cte AS outer_cte FROM ("
+            "WITH nesting_2 AS (SELECT :param_1 AS inner_cte)"
+            ", nesting_1 AS (SELECT :param_2 AS inner_cte)"
+            " SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1"
+            " UNION"
+            " SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2"
+            ") AS anon_1"
+            ") SELECT cte.outer_cte FROM cte",
+        )
+
     def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self):
         rec_root = select(literal(1).label("the_value")).cte(
             "recursive_cte", recursive=True
@@ -2411,3 +2664,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             " WHERE should_continue.val != true))"
             " SELECT recursive_cte.the_value FROM recursive_cte",
         )
+
+    @testing.combinations(True, False)
+    def test_correlated_cte_in_lateral_w_add_cte(self, reverse_direction):
+        """this is the original use case that led to #7759"""
+        contracts = table("contracts", column("id"))
+
+        invoices = table("invoices", column("id"), column("contract_id"))
+
+        contracts_alias = contracts.alias()
+        cte1 = (
+            select(contracts_alias)
+            .where(contracts_alias.c.id == contracts.c.id)
+            .correlate(contracts)
+            .cte(name="cte1")
+        )
+        cte2 = (
+            select(invoices)
+            .join(cte1, invoices.c.contract_id == cte1.c.id)
+            .cte(name="cte2")
+        )
+
+        if reverse_direction:
+            subq = select(cte1, cte2).add_cte(cte2, cte1, nest_here=True)
+        else:
+            subq = select(cte1, cte2).add_cte(cte1, cte2, nest_here=True)
+        stmt = select(contracts).outerjoin(subq.lateral(), true())
+
+        self.assert_compile(
+            stmt,
+            "SELECT contracts.id FROM contracts LEFT OUTER JOIN LATERAL "
+            "(WITH cte1 AS (SELECT contracts_1.id AS id "
+            "FROM contracts AS contracts_1 "
+            "WHERE contracts_1.id = contracts.id), "
+            "cte2 AS (SELECT invoices.id AS id, "
+            "invoices.contract_id AS contract_id FROM invoices "
+            "JOIN cte1 ON invoices.contract_id = cte1.id) "
+            "SELECT cte1.id AS id, cte2.id AS id_1, "
+            "cte2.contract_id AS contract_id "
+            "FROM cte1, cte2) AS anon_1 ON true",
+        )