]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add Non linear CTE support
authorEric Masseran <eric.masseran@gmail.com>
Tue, 2 Nov 2021 20:40:04 +0000 (16:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Nov 2021 19:45:35 +0000 (14:45 -0500)
"Compound select" methods like :meth:`_sql.Select.union`,
:meth:`_sql.Select.intersect_all` etc. now accept ``*other`` as an argument
rather than ``other`` to allow for multiple additional SELECTs to be
compounded with the parent statement at once. In particular, the change as
applied to :meth:`_sql.CTE.union` and :meth:`_sql.CTE.union_all` now allow
for a so-called "non-linear CTE" to be created with the :class:`_sql.CTE`
construct, whereas previously there was no way to have more than two CTE
sub-elements in a UNION together while still correctly calling upon the CTE
in recursive fashion. Pull request courtesy Eric Masseran.

Allow:

```sql
WITH RECURSIVE nodes(x) AS (
   SELECT 59
   UNION
   SELECT aa FROM edge JOIN nodes ON bb=x
   UNION
   SELECT bb FROM edge JOIN nodes ON aa=x
)
SELECT x FROM nodes;
```

Based on @zzzeek suggestion: https://github.com/sqlalchemy/sqlalchemy/pull/7133#issuecomment-933882348

Fixes: #7259
Closes: #7260
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7260
Pull-request-sha: 2565a5fd4b1940e92125e53aeaa731cc682f49bb

Change-Id: I685c8379762b5fb6ab4107ff8f4d8a4de70c0ca6
(cherry picked from commit 958f902b1fc528fed0be550bc573545de47ed854)

doc/build/changelog/unreleased_14/7259.rst [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
test/sql/test_cte.py
test/sql/test_select.py

diff --git a/doc/build/changelog/unreleased_14/7259.rst b/doc/build/changelog/unreleased_14/7259.rst
new file mode 100644 (file)
index 0000000..477714e
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: sql, usecase
+    :tickets: 7259
+
+    "Compound select" methods like :meth:`_sql.Select.union`,
+    :meth:`_sql.Select.intersect_all` etc. now accept ``*other`` as an argument
+    rather than ``other`` to allow for multiple additional SELECTs to be
+    compounded with the parent statement at once. In particular, the change as
+    applied to :meth:`_sql.CTE.union` and :meth:`_sql.CTE.union_all` now allow
+    for a so-called "non-linear CTE" to be created with the :class:`_sql.CTE`
+    construct, whereas previously there was no way to have more than two CTE
+    sub-elements in a UNION together while still correctly calling upon the CTE
+    in recursive fashion. Pull request courtesy Eric Masseran.
index 9143602970624355486a7ba6aad1c08c3260bbe7..95fca267c65d7d34bb5c0dcbb8bc63ab802f20b0 100644 (file)
@@ -2121,9 +2121,23 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def union(self, other):
+    def union(self, *other):
+        r"""Return a new :class:`_expression.CTE` with a SQL ``UNION``
+        of the original CTE against the given selectables provided
+        as positional arguments.
+
+        :param \*other: one or more elements with which to create a
+         UNION.
+
+         .. versionchanged:: 1.4.28 multiple elements are now accepted.
+
+        .. seealso::
+
+            :meth:`_sql.HasCTE.cte` - examples of calling styles
+
+        """
         return CTE._construct(
-            self.element.union(other),
+            self.element.union(*other),
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
@@ -2132,9 +2146,23 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def union_all(self, other):
+    def union_all(self, *other):
+        r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL``
+        of the original CTE against the given selectables provided
+        as positional arguments.
+
+        :param \*other: one or more elements with which to create a
+         UNION.
+
+         .. versionchanged:: 1.4.28 multiple elements are now accepted.
+
+        .. seealso::
+
+            :meth:`_sql.HasCTE.cte` - examples of calling styles
+
+        """
         return CTE._construct(
-            self.element.union_all(other),
+            self.element.union_all(*other),
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
@@ -2396,7 +2424,7 @@ class HasCTE(roles.HasCTERole):
 
             connection.execute(upsert)
 
-        Example 4, Nesting CTE::
+        Example 4, Nesting CTE (SQLAlchemy 1.4.24 and above)::
 
             value_a = select(
                 literal("root").label("n")
@@ -2426,6 +2454,44 @@ class HasCTE(roles.HasCTERole):
             SELECT value_a.n AS a, value_b.n AS b
             FROM value_a, value_b
 
+        Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above)::
+
+            edge = Table(
+                "edge",
+                metadata,
+                Column("id", Integer, primary_key=True),
+                Column("left", Integer),
+                Column("right", Integer),
+            )
+
+            root_node = select(literal(1).label("node")).cte(
+                "nodes", recursive=True
+            )
+
+            left_edge = select(edge.c.left).join(
+                root_node, edge.c.right == root_node.c.node
+            )
+            right_edge = select(edge.c.right).join(
+                root_node, edge.c.left == root_node.c.node
+            )
+
+            subgraph_cte = root_node.union(left_edge, right_edge)
+
+            subgraph = select(subgraph_cte)
+
+        The above query will render 2 UNIONs inside the recursive CTE::
+
+            WITH RECURSIVE nodes(node) AS (
+                    SELECT 1 AS node
+                UNION
+                    SELECT edge."left" AS "left"
+                    FROM edge JOIN nodes ON edge."right" = nodes.node
+                UNION
+                    SELECT edge."right" AS "right"
+                    FROM edge JOIN nodes ON edge."left" = nodes.node
+            )
+            SELECT nodes.node FROM nodes
+
         .. seealso::
 
             :meth:`_orm.Query.cte` - ORM version of
@@ -6270,47 +6336,107 @@ class Select(
         else:
             return SelectStatementGrouping(self)
 
-    def union(self, other, **kwargs):
-        """Return a SQL ``UNION`` of this select() construct against
-        the given selectable.
+    def union(self, *other, **kwargs):
+        r"""Return a SQL ``UNION`` of this select() construct against
+        the given selectables provided as positional arguments.
+
+        :param \*other: one or more elements with which to create a
+         UNION.
+
+         .. versionchanged:: 1.4.28
+
+            multiple elements are now accepted.
+
+        :param \**kwargs: keyword arguments are forwarded to the constructor
+         for the newly created :class:`_sql.CompoundSelect` object.
 
         """
-        return CompoundSelect._create_union(self, other, **kwargs)
+        return CompoundSelect._create_union(self, *other, **kwargs)
+
+    def union_all(self, *other, **kwargs):
+        r"""Return a SQL ``UNION ALL`` of this select() construct against
+        the given selectables provided as positional arguments.
+
+        :param \*other: one or more elements with which to create a
+         UNION.
 
-    def union_all(self, other, **kwargs):
-        """Return a SQL ``UNION ALL`` of this select() construct against
-        the given selectable.
+         .. versionchanged:: 1.4.28
+
+            multiple elements are now accepted.
+
+        :param \**kwargs: keyword arguments are forwarded to the constructor
+         for the newly created :class:`_sql.CompoundSelect` object.
 
         """
-        return CompoundSelect._create_union_all(self, other, **kwargs)
+        return CompoundSelect._create_union_all(self, *other, **kwargs)
+
+    def except_(self, *other, **kwargs):
+        r"""Return a SQL ``EXCEPT`` of this select() construct against
+        the given selectable provided as positional arguments.
+
+        :param \*other: one or more elements with which to create a
+         UNION.
+
+         .. versionchanged:: 1.4.28
+
+            multiple elements are now accepted.
 
-    def except_(self, other, **kwargs):
-        """Return a SQL ``EXCEPT`` of this select() construct against
-        the given selectable.
+        :param \**kwargs: keyword arguments are forwarded to the constructor
+         for the newly created :class:`_sql.CompoundSelect` object.
 
         """
-        return CompoundSelect._create_except(self, other, **kwargs)
+        return CompoundSelect._create_except(self, *other, **kwargs)
 
-    def except_all(self, other, **kwargs):
-        """Return a SQL ``EXCEPT ALL`` of this select() construct against
-        the given selectable.
+    def except_all(self, *other, **kwargs):
+        r"""Return a SQL ``EXCEPT ALL`` of this select() construct against
+        the given selectables provided as positional arguments.
+
+        :param \*other: one or more elements with which to create a
+         UNION.
+
+         .. versionchanged:: 1.4.28
+
+            multiple elements are now accepted.
+
+        :param \**kwargs: keyword arguments are forwarded to the constructor
+         for the newly created :class:`_sql.CompoundSelect` object.
 
         """
-        return CompoundSelect._create_except_all(self, other, **kwargs)
+        return CompoundSelect._create_except_all(self, *other, **kwargs)
+
+    def intersect(self, *other, **kwargs):
+        r"""Return a SQL ``INTERSECT`` of this select() construct against
+        the given selectables provided as positional arguments.
+
+        :param \*other: one or more elements with which to create a
+         UNION.
 
-    def intersect(self, other, **kwargs):
-        """Return a SQL ``INTERSECT`` of this select() construct against
-        the given selectable.
+         .. versionchanged:: 1.4.28
+
+            multiple elements are now accepted.
+
+        :param \**kwargs: keyword arguments are forwarded to the constructor
+         for the newly created :class:`_sql.CompoundSelect` object.
 
         """
-        return CompoundSelect._create_intersect(self, other, **kwargs)
+        return CompoundSelect._create_intersect(self, *other, **kwargs)
+
+    def intersect_all(self, *other, **kwargs):
+        r"""Return a SQL ``INTERSECT ALL`` of this select() construct
+        against the given selectables provided as positional arguments.
+
+        :param \*other: one or more elements with which to create a
+         UNION.
+
+         .. versionchanged:: 1.4.28
+
+            multiple elements are now accepted.
 
-    def intersect_all(self, other, **kwargs):
-        """Return a SQL ``INTERSECT ALL`` of this select() construct
-        against the given selectable.
+        :param \**kwargs: keyword arguments are forwarded to the constructor
+         for the newly created :class:`_sql.CompoundSelect` object.
 
         """
-        return CompoundSelect._create_intersect_all(self, other, **kwargs)
+        return CompoundSelect._create_intersect_all(self, *other, **kwargs)
 
     @property
     @util.deprecated_20(
index 10fe81b5530ff0510814271906b74535b8231ac1..df9f065acc849b78331d2b4430109e808f5632b3 100644 (file)
@@ -1769,6 +1769,53 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "foo",
         )
 
+    def test_recursive_cte_with_multiple_union(self):
+        root_query = select(literal(1).label("val")).cte(
+            "increasing", recursive=True
+        )
+        rec_part_1 = select((root_query.c.val + 3).label("val")).where(
+            root_query.c.val < 15
+        )
+        rec_part_2 = select((root_query.c.val + 5).label("val")).where(
+            root_query.c.val < 15
+        )
+        union_rec_query = root_query.union(rec_part_1, rec_part_2)
+        union_stmt = select(union_rec_query)
+        self.assert_compile(
+            union_stmt,
+            "WITH RECURSIVE increasing(val) AS "
+            "(SELECT :param_1 AS val "
+            "UNION SELECT increasing.val + :val_1 AS val FROM increasing "
+            "WHERE increasing.val < :val_2 "
+            "UNION SELECT increasing.val + :val_3 AS val FROM increasing "
+            "WHERE increasing.val < :val_4) "
+            "SELECT increasing.val FROM increasing",
+        )
+
+    def test_recursive_cte_with_multiple_union_all(self):
+        root_query = select(literal(1).label("val")).cte(
+            "increasing", recursive=True
+        )
+        rec_part_1 = select((root_query.c.val + 3).label("val")).where(
+            root_query.c.val < 15
+        )
+        rec_part_2 = select((root_query.c.val + 5).label("val")).where(
+            root_query.c.val < 15
+        )
+
+        union_all_rec_query = root_query.union_all(rec_part_1, rec_part_2)
+        union_all_stmt = select(union_all_rec_query)
+        self.assert_compile(
+            union_all_stmt,
+            "WITH RECURSIVE increasing(val) AS "
+            "(SELECT :param_1 AS val "
+            "UNION ALL SELECT increasing.val + :val_1 AS val FROM increasing "
+            "WHERE increasing.val < :val_2 "
+            "UNION ALL SELECT increasing.val + :val_3 AS val FROM increasing "
+            "WHERE increasing.val < :val_4) "
+            "SELECT increasing.val FROM increasing",
+        )
+
 
 class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
 
index 17b47d96de77549404dad65bb8d0e5eb2fab493c..c9abb7fb8b451ab102df80fc73a130b933fe8d30 100644 (file)
@@ -8,15 +8,16 @@ from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
+from sqlalchemy import testing
 from sqlalchemy import tuple_
 from sqlalchemy import union
 from sqlalchemy.sql import column
+from sqlalchemy.sql import literal
 from sqlalchemy.sql import table
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import fixtures
 
-
 table1 = table(
     "mytable",
     column("myid", Integer),
@@ -412,3 +413,23 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT anon_1.name FROM (SELECT mytable.name AS name, "
             "(mytable.myid, mytable.name) AS anon_2 FROM mytable) AS anon_1",
         )
+
+    @testing.combinations(
+        ("union_all", "UNION ALL"),
+        ("union", "UNION"),
+        ("intersect_all", "INTERSECT ALL"),
+        ("intersect", "INTERSECT"),
+        ("except_all", "EXCEPT ALL"),
+        ("except_", "EXCEPT"),
+    )
+    def test_select_multiple_compound_elements(self, methname, joiner):
+        stmt = select(literal(1))
+        meth = getattr(stmt, methname)
+        stmt = meth(select(literal(2)), select(literal(3)))
+
+        self.assert_compile(
+            stmt,
+            "SELECT :param_1 AS anon_1"
+            " %(joiner)s SELECT :param_2 AS anon_2"
+            " %(joiner)s SELECT :param_3 AS anon_3" % {"joiner": joiner},
+        )