]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add implementation with test and doc
authorEric Masseran <eric.masseran@gmail.com>
Fri, 29 Oct 2021 09:21:42 +0000 (11:21 +0200)
committerEric Masseran <eric.masseran@gmail.com>
Fri, 29 Oct 2021 09:21:42 +0000 (11:21 +0200)
lib/sqlalchemy/sql/selectable.py
test/sql/test_cte.py

index 8e71dfb97faf1659f5e6fc9635a9b6e3c1ecd5de..c7446d75bc87dda48595a4f3c626d4d5226dec39 100644 (file)
@@ -2119,9 +2119,9 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def union(self, other):
+    def union(self, *other):
         return CTE._construct(
-            self.element.union(other),
+            self.element.union(*other),
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
@@ -2130,9 +2130,9 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def union_all(self, other):
+    def union_all(self, *other):
         return CTE._construct(
-            self.element.union_all(other),
+            self.element.union_all(*other),
             name=self.name,
             recursive=self.recursive,
             nesting=self.nesting,
@@ -2416,6 +2416,44 @@ class HasCTE(roles.HasCTERole):
             SELECT value_a.n AS a, value_b.n AS b
             FROM value_a, value_b
 
+        Example 5, Non-Linear CTE::
+
+            edge = Table(
+                "edge",
+                metadata,
+                Column("id", Integer, primary_key=True),
+                Column("left", Integer),
+                Column("right", Integer),
+            )
+
+            root_node = select(literal(1).label("node")).cte(
+                "nodes", recursive=True
+            )
+
+            left_edge = select(edge.c.left).join(
+                root_node, edge.c.right == root_node.c.node
+            )
+            right_edge = select(edge.c.right).join(
+                root_node, edge.c.left == root_node.c.node
+            )
+
+            subgraph_cte = root_node.union(left_edge, right_edge)
+
+            subgraph = select(subgraph_cte)
+
+        The above query will render 2 UNIONs inside the recursive CTE::
+
+            WITH RECURSIVE nodes(node) AS (
+                    SELECT 1 AS node
+                UNION
+                    SELECT edge."left" AS "left"
+                    FROM edge JOIN nodes ON edge."right" = nodes.node
+                UNION
+                    SELECT edge."right" AS "right"
+                    FROM edge JOIN nodes ON edge."left" = nodes.node
+            )
+            SELECT nodes.node FROM nodes
+
         .. seealso::
 
             :meth:`_orm.Query.cte` - ORM version of
@@ -6251,19 +6289,19 @@ class Select(
         else:
             return SelectStatementGrouping(self)
 
-    def union(self, other, **kwargs):
+    def union(self, *other, **kwargs):
         """Return a SQL ``UNION`` of this select() construct against
         the given selectable.
 
         """
-        return CompoundSelect._create_union(self, other, **kwargs)
+        return CompoundSelect._create_union(self, *other, **kwargs)
 
-    def union_all(self, other, **kwargs):
+    def union_all(self, *other, **kwargs):
         """Return a SQL ``UNION ALL`` of this select() construct against
         the given selectable.
 
         """
-        return CompoundSelect._create_union_all(self, other, **kwargs)
+        return CompoundSelect._create_union_all(self, *other, **kwargs)
 
     def except_(self, other, **kwargs):
         """Return a SQL ``EXCEPT`` of this select() construct against
index 5d24adff9302548fbd37c59441765ec2db68ae10..2b66825028fcddfb06925e4934c9489294a794ef 100644 (file)
@@ -1720,6 +1720,31 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "foo",
         )
 
+    def test_multiple_recursive_unions(self):
+        root_query = select(literal(1).label("val")).cte(
+            "increasing", recursive=True
+        )
+        rec_part_1 = select((root_query.c.val + 3).label("val")).where(
+            root_query.c.val < 15
+        )
+        rec_part_2 = select((root_query.c.val + 5).label("val")).where(
+            root_query.c.val < 15
+        )
+        rec_query = root_query.union(rec_part_1, rec_part_2)
+
+        stmt = select(rec_query)
+
+        self.assert_compile(
+            stmt,
+            "WITH RECURSIVE increasing(val) AS "
+            "(SELECT :param_1 AS val "
+            "UNION SELECT increasing.val + :val_1 AS val FROM increasing "
+            "WHERE increasing.val < :val_2 "
+            "UNION SELECT increasing.val + :val_3 AS val FROM increasing "
+            "WHERE increasing.val < :val_4) "
+            "SELECT increasing.val FROM increasing",
+        )
+
 
 class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):