From 24a53fd8fce2cdfb0154caa687ef893bcff120a7 Mon Sep 17 00:00:00 2001 From: Eric Masseran Date: Tue, 2 Nov 2021 16:40:04 -0400 Subject: [PATCH] Add Non linear CTE support "Compound select" methods like :meth:`_sql.Select.union`, :meth:`_sql.Select.intersect_all` etc. now accept ``*other`` as an argument rather than ``other`` to allow for multiple additional SELECTs to be compounded with the parent statement at once. In particular, the change as applied to :meth:`_sql.CTE.union` and :meth:`_sql.CTE.union_all` now allow for a so-called "non-linear CTE" to be created with the :class:`_sql.CTE` construct, whereas previously there was no way to have more than two CTE sub-elements in a UNION together while still correctly calling upon the CTE in recursive fashion. Pull request courtesy Eric Masseran. Allow: ```sql WITH RECURSIVE nodes(x) AS ( SELECT 59 UNION SELECT aa FROM edge JOIN nodes ON bb=x UNION SELECT bb FROM edge JOIN nodes ON aa=x ) SELECT x FROM nodes; ``` Based on @zzzeek suggestion: https://github.com/sqlalchemy/sqlalchemy/pull/7133#issuecomment-933882348 Fixes: #7259 Closes: #7260 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7260 Pull-request-sha: 2565a5fd4b1940e92125e53aeaa731cc682f49bb Change-Id: I685c8379762b5fb6ab4107ff8f4d8a4de70c0ca6 (cherry picked from commit 958f902b1fc528fed0be550bc573545de47ed854) --- doc/build/changelog/unreleased_14/7259.rst | 13 ++ lib/sqlalchemy/sql/selectable.py | 184 +++++++++++++++++---- test/sql/test_cte.py | 47 ++++++ test/sql/test_select.py | 23 ++- 4 files changed, 237 insertions(+), 30 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/7259.rst diff --git a/doc/build/changelog/unreleased_14/7259.rst b/doc/build/changelog/unreleased_14/7259.rst new file mode 100644 index 0000000000..477714edd9 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7259.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: sql, usecase + :tickets: 7259 + + "Compound select" methods like :meth:`_sql.Select.union`, + :meth:`_sql.Select.intersect_all` etc. now accept ``*other`` as an argument + rather than ``other`` to allow for multiple additional SELECTs to be + compounded with the parent statement at once. In particular, the change as + applied to :meth:`_sql.CTE.union` and :meth:`_sql.CTE.union_all` now allow + for a so-called "non-linear CTE" to be created with the :class:`_sql.CTE` + construct, whereas previously there was no way to have more than two CTE + sub-elements in a UNION together while still correctly calling upon the CTE + in recursive fashion. Pull request courtesy Eric Masseran. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 9143602970..95fca267c6 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2121,9 +2121,23 @@ class CTE( _suffixes=self._suffixes, ) - def union(self, other): + def union(self, *other): + r"""Return a new :class:`_expression.CTE` with a SQL ``UNION`` + of the original CTE against the given selectables provided + as positional arguments. + + :param \*other: one or more elements with which to create a + UNION. + + .. versionchanged:: 1.4.28 multiple elements are now accepted. + + .. seealso:: + + :meth:`_sql.HasCTE.cte` - examples of calling styles + + """ return CTE._construct( - self.element.union(other), + self.element.union(*other), name=self.name, recursive=self.recursive, nesting=self.nesting, @@ -2132,9 +2146,23 @@ class CTE( _suffixes=self._suffixes, ) - def union_all(self, other): + def union_all(self, *other): + r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL`` + of the original CTE against the given selectables provided + as positional arguments. + + :param \*other: one or more elements with which to create a + UNION. + + .. versionchanged:: 1.4.28 multiple elements are now accepted. + + .. seealso:: + + :meth:`_sql.HasCTE.cte` - examples of calling styles + + """ return CTE._construct( - self.element.union_all(other), + self.element.union_all(*other), name=self.name, recursive=self.recursive, nesting=self.nesting, @@ -2396,7 +2424,7 @@ class HasCTE(roles.HasCTERole): connection.execute(upsert) - Example 4, Nesting CTE:: + Example 4, Nesting CTE (SQLAlchemy 1.4.24 and above):: value_a = select( literal("root").label("n") @@ -2426,6 +2454,44 @@ class HasCTE(roles.HasCTERole): SELECT value_a.n AS a, value_b.n AS b FROM value_a, value_b + Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above):: + + edge = Table( + "edge", + metadata, + Column("id", Integer, primary_key=True), + Column("left", Integer), + Column("right", Integer), + ) + + root_node = select(literal(1).label("node")).cte( + "nodes", recursive=True + ) + + left_edge = select(edge.c.left).join( + root_node, edge.c.right == root_node.c.node + ) + right_edge = select(edge.c.right).join( + root_node, edge.c.left == root_node.c.node + ) + + subgraph_cte = root_node.union(left_edge, right_edge) + + subgraph = select(subgraph_cte) + + The above query will render 2 UNIONs inside the recursive CTE:: + + WITH RECURSIVE nodes(node) AS ( + SELECT 1 AS node + UNION + SELECT edge."left" AS "left" + FROM edge JOIN nodes ON edge."right" = nodes.node + UNION + SELECT edge."right" AS "right" + FROM edge JOIN nodes ON edge."left" = nodes.node + ) + SELECT nodes.node FROM nodes + .. seealso:: :meth:`_orm.Query.cte` - ORM version of @@ -6270,47 +6336,107 @@ class Select( else: return SelectStatementGrouping(self) - def union(self, other, **kwargs): - """Return a SQL ``UNION`` of this select() construct against - the given selectable. + def union(self, *other, **kwargs): + r"""Return a SQL ``UNION`` of this select() construct against + the given selectables provided as positional arguments. + + :param \*other: one or more elements with which to create a + UNION. + + .. versionchanged:: 1.4.28 + + multiple elements are now accepted. + + :param \**kwargs: keyword arguments are forwarded to the constructor + for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_union(self, other, **kwargs) + return CompoundSelect._create_union(self, *other, **kwargs) + + def union_all(self, *other, **kwargs): + r"""Return a SQL ``UNION ALL`` of this select() construct against + the given selectables provided as positional arguments. + + :param \*other: one or more elements with which to create a + UNION. - def union_all(self, other, **kwargs): - """Return a SQL ``UNION ALL`` of this select() construct against - the given selectable. + .. versionchanged:: 1.4.28 + + multiple elements are now accepted. + + :param \**kwargs: keyword arguments are forwarded to the constructor + for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_union_all(self, other, **kwargs) + return CompoundSelect._create_union_all(self, *other, **kwargs) + + def except_(self, *other, **kwargs): + r"""Return a SQL ``EXCEPT`` of this select() construct against + the given selectable provided as positional arguments. + + :param \*other: one or more elements with which to create a + UNION. + + .. versionchanged:: 1.4.28 + + multiple elements are now accepted. - def except_(self, other, **kwargs): - """Return a SQL ``EXCEPT`` of this select() construct against - the given selectable. + :param \**kwargs: keyword arguments are forwarded to the constructor + for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_except(self, other, **kwargs) + return CompoundSelect._create_except(self, *other, **kwargs) - def except_all(self, other, **kwargs): - """Return a SQL ``EXCEPT ALL`` of this select() construct against - the given selectable. + def except_all(self, *other, **kwargs): + r"""Return a SQL ``EXCEPT ALL`` of this select() construct against + the given selectables provided as positional arguments. + + :param \*other: one or more elements with which to create a + UNION. + + .. versionchanged:: 1.4.28 + + multiple elements are now accepted. + + :param \**kwargs: keyword arguments are forwarded to the constructor + for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_except_all(self, other, **kwargs) + return CompoundSelect._create_except_all(self, *other, **kwargs) + + def intersect(self, *other, **kwargs): + r"""Return a SQL ``INTERSECT`` of this select() construct against + the given selectables provided as positional arguments. + + :param \*other: one or more elements with which to create a + UNION. - def intersect(self, other, **kwargs): - """Return a SQL ``INTERSECT`` of this select() construct against - the given selectable. + .. versionchanged:: 1.4.28 + + multiple elements are now accepted. + + :param \**kwargs: keyword arguments are forwarded to the constructor + for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_intersect(self, other, **kwargs) + return CompoundSelect._create_intersect(self, *other, **kwargs) + + def intersect_all(self, *other, **kwargs): + r"""Return a SQL ``INTERSECT ALL`` of this select() construct + against the given selectables provided as positional arguments. + + :param \*other: one or more elements with which to create a + UNION. + + .. versionchanged:: 1.4.28 + + multiple elements are now accepted. - def intersect_all(self, other, **kwargs): - """Return a SQL ``INTERSECT ALL`` of this select() construct - against the given selectable. + :param \**kwargs: keyword arguments are forwarded to the constructor + for the newly created :class:`_sql.CompoundSelect` object. """ - return CompoundSelect._create_intersect_all(self, other, **kwargs) + return CompoundSelect._create_intersect_all(self, *other, **kwargs) @property @util.deprecated_20( diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 10fe81b553..df9f065acc 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1769,6 +1769,53 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "foo", ) + def test_recursive_cte_with_multiple_union(self): + root_query = select(literal(1).label("val")).cte( + "increasing", recursive=True + ) + rec_part_1 = select((root_query.c.val + 3).label("val")).where( + root_query.c.val < 15 + ) + rec_part_2 = select((root_query.c.val + 5).label("val")).where( + root_query.c.val < 15 + ) + union_rec_query = root_query.union(rec_part_1, rec_part_2) + union_stmt = select(union_rec_query) + self.assert_compile( + union_stmt, + "WITH RECURSIVE increasing(val) AS " + "(SELECT :param_1 AS val " + "UNION SELECT increasing.val + :val_1 AS val FROM increasing " + "WHERE increasing.val < :val_2 " + "UNION SELECT increasing.val + :val_3 AS val FROM increasing " + "WHERE increasing.val < :val_4) " + "SELECT increasing.val FROM increasing", + ) + + def test_recursive_cte_with_multiple_union_all(self): + root_query = select(literal(1).label("val")).cte( + "increasing", recursive=True + ) + rec_part_1 = select((root_query.c.val + 3).label("val")).where( + root_query.c.val < 15 + ) + rec_part_2 = select((root_query.c.val + 5).label("val")).where( + root_query.c.val < 15 + ) + + union_all_rec_query = root_query.union_all(rec_part_1, rec_part_2) + union_all_stmt = select(union_all_rec_query) + self.assert_compile( + union_all_stmt, + "WITH RECURSIVE increasing(val) AS " + "(SELECT :param_1 AS val " + "UNION ALL SELECT increasing.val + :val_1 AS val FROM increasing " + "WHERE increasing.val < :val_2 " + "UNION ALL SELECT increasing.val + :val_3 AS val FROM increasing " + "WHERE increasing.val < :val_4) " + "SELECT increasing.val FROM increasing", + ) + class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/sql/test_select.py b/test/sql/test_select.py index 17b47d96de..c9abb7fb8b 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -8,15 +8,16 @@ from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import testing from sqlalchemy import tuple_ from sqlalchemy import union from sqlalchemy.sql import column +from sqlalchemy.sql import literal from sqlalchemy.sql import table from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import fixtures - table1 = table( "mytable", column("myid", Integer), @@ -412,3 +413,23 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT anon_1.name FROM (SELECT mytable.name AS name, " "(mytable.myid, mytable.name) AS anon_2 FROM mytable) AS anon_1", ) + + @testing.combinations( + ("union_all", "UNION ALL"), + ("union", "UNION"), + ("intersect_all", "INTERSECT ALL"), + ("intersect", "INTERSECT"), + ("except_all", "EXCEPT ALL"), + ("except_", "EXCEPT"), + ) + def test_select_multiple_compound_elements(self, methname, joiner): + stmt = select(literal(1)) + meth = getattr(stmt, methname) + stmt = meth(select(literal(2)), select(literal(3))) + + self.assert_compile( + stmt, + "SELECT :param_1 AS anon_1" + " %(joiner)s SELECT :param_2 AS anon_2" + " %(joiner)s SELECT :param_3 AS anon_3" % {"joiner": joiner}, + ) -- 2.47.3