From: Eric Masseran Date: Fri, 29 Oct 2021 09:21:42 +0000 (+0200) Subject: Add implementation with test and doc X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=66eb5f4243a52629aa40f86c9d4bd0e798480a6d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add implementation with test and doc --- diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8e71dfb97f..c7446d75bc 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2119,9 +2119,9 @@ class CTE( _suffixes=self._suffixes, ) - def union(self, other): + def union(self, *other): return CTE._construct( - self.element.union(other), + self.element.union(*other), name=self.name, recursive=self.recursive, nesting=self.nesting, @@ -2130,9 +2130,9 @@ class CTE( _suffixes=self._suffixes, ) - def union_all(self, other): + def union_all(self, *other): return CTE._construct( - self.element.union_all(other), + self.element.union_all(*other), name=self.name, recursive=self.recursive, nesting=self.nesting, @@ -2416,6 +2416,44 @@ class HasCTE(roles.HasCTERole): SELECT value_a.n AS a, value_b.n AS b FROM value_a, value_b + Example 5, Non-Linear CTE:: + + edge = Table( + "edge", + metadata, + Column("id", Integer, primary_key=True), + Column("left", Integer), + Column("right", Integer), + ) + + root_node = select(literal(1).label("node")).cte( + "nodes", recursive=True + ) + + left_edge = select(edge.c.left).join( + root_node, edge.c.right == root_node.c.node + ) + right_edge = select(edge.c.right).join( + root_node, edge.c.left == root_node.c.node + ) + + subgraph_cte = root_node.union(left_edge, right_edge) + + subgraph = select(subgraph_cte) + + The above query will render 2 UNIONs inside the recursive CTE:: + + WITH RECURSIVE nodes(node) AS ( + SELECT 1 AS node + UNION + SELECT edge."left" AS "left" + FROM edge JOIN nodes ON edge."right" = nodes.node + UNION + SELECT edge."right" AS "right" + FROM edge JOIN nodes ON edge."left" = nodes.node + ) + SELECT nodes.node FROM nodes + .. seealso:: :meth:`_orm.Query.cte` - ORM version of @@ -6251,19 +6289,19 @@ class Select( else: return SelectStatementGrouping(self) - def union(self, other, **kwargs): + def union(self, *other, **kwargs): """Return a SQL ``UNION`` of this select() construct against the given selectable. """ - return CompoundSelect._create_union(self, other, **kwargs) + return CompoundSelect._create_union(self, *other, **kwargs) - def union_all(self, other, **kwargs): + def union_all(self, *other, **kwargs): """Return a SQL ``UNION ALL`` of this select() construct against the given selectable. """ - return CompoundSelect._create_union_all(self, other, **kwargs) + return CompoundSelect._create_union_all(self, *other, **kwargs) def except_(self, other, **kwargs): """Return a SQL ``EXCEPT`` of this select() construct against diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 5d24adff93..2b66825028 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1720,6 +1720,31 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "foo", ) + def test_multiple_recursive_unions(self): + root_query = select(literal(1).label("val")).cte( + "increasing", recursive=True + ) + rec_part_1 = select((root_query.c.val + 3).label("val")).where( + root_query.c.val < 15 + ) + rec_part_2 = select((root_query.c.val + 5).label("val")).where( + root_query.c.val < 15 + ) + rec_query = root_query.union(rec_part_1, rec_part_2) + + stmt = select(rec_query) + + self.assert_compile( + stmt, + "WITH RECURSIVE increasing(val) AS " + "(SELECT :param_1 AS val " + "UNION SELECT increasing.val + :val_1 AS val FROM increasing " + "WHERE increasing.val < :val_2 " + "UNION SELECT increasing.val + :val_3 AS val FROM increasing " + "WHERE increasing.val < :val_4) " + "SELECT increasing.val FROM increasing", + ) + class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):