From: Eric Masseran Date: Tue, 2 Nov 2021 13:33:38 +0000 (+0100) Subject: Extend to similar methods + add docs/tests X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=744ca96f3ad29dd545fac96543bfcda144b43f78;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Extend to similar methods + add docs/tests --- diff --git a/doc/build/changelog/unreleased_14/7259.rst b/doc/build/changelog/unreleased_14/7259.rst index a1bf65b705..a1f7be12e9 100644 --- a/doc/build/changelog/unreleased_14/7259.rst +++ b/doc/build/changelog/unreleased_14/7259.rst @@ -3,6 +3,9 @@ :tickets: 7259 Render Non-linear CTE with :meth:`_sql.CTE.union` and :meth:`_sql.CTE.union_all` - that accept multiple clauses to render multiple UNIONs. To make it possible, + that accept multiple selectables to render multiple UNIONs. To make it possible, :meth:`_sql.Select.union` and :meth:`_sql.Select.union_all` accept also multiple - clauses. + selectables. It has also been extended to other similar methods like + :meth:`_sql.Select.except_`, :meth:`_sql.Select.except_all`, + :meth:`_sql.Select.intersect` and :meth:`_sql.Select.intersect_all` that accept + multiple selectables. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index c7446d75bc..8d77523fac 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -6291,45 +6291,63 @@ class Select( def union(self, *other, **kwargs): """Return a SQL ``UNION`` of this select() construct against - the given selectable. + the given selectables provided as positional arguments. + + The keyword arguments are forwarded to the method + :meth:`_expression.CompoundSelect._create_union`. """ return CompoundSelect._create_union(self, *other, **kwargs) def union_all(self, *other, **kwargs): """Return a SQL ``UNION ALL`` of this select() construct against - the given selectable. + the given selectables provided as positional arguments. + + The keyword arguments are forwarded to the method + :meth:`_expression.CompoundSelect._create_union_all`. """ return CompoundSelect._create_union_all(self, *other, **kwargs) - def except_(self, other, **kwargs): + def except_(self, *other, **kwargs): """Return a SQL ``EXCEPT`` of this select() construct against - the given selectable. + the given selectable provided as positional arguments. + + The keyword arguments are forwarded to the method + :meth:`_expression.CompoundSelect._create_except`. """ - return CompoundSelect._create_except(self, other, **kwargs) + return CompoundSelect._create_except(self, *other, **kwargs) - def except_all(self, other, **kwargs): + def except_all(self, *other, **kwargs): """Return a SQL ``EXCEPT ALL`` of this select() construct against - the given selectable. + the given selectables provided as positional arguments. + + The keyword arguments are forwarded to the method + :meth:`_expression.CompoundSelect._create_except_all`. """ - return CompoundSelect._create_except_all(self, other, **kwargs) + return CompoundSelect._create_except_all(self, *other, **kwargs) - def intersect(self, other, **kwargs): + def intersect(self, *other, **kwargs): """Return a SQL ``INTERSECT`` of this select() construct against - the given selectable. + the given selectables provided as positional arguments. + + The keyword arguments are forwarded to the method + :meth:`_expression.CompoundSelect._create_intersect`. """ - return CompoundSelect._create_intersect(self, other, **kwargs) + return CompoundSelect._create_intersect(self, *other, **kwargs) - def intersect_all(self, other, **kwargs): + def intersect_all(self, *other, **kwargs): """Return a SQL ``INTERSECT ALL`` of this select() construct - against the given selectable. + against the given selectables provided as positional arguments. + + The keyword arguments are forwarded to the method + :meth:`_expression.CompoundSelect._create_intersect_all`. """ - return CompoundSelect._create_intersect_all(self, other, **kwargs) + return CompoundSelect._create_intersect_all(self, *other, **kwargs) @property @util.deprecated_20( diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 2b66825028..3701b4801c 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1720,7 +1720,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "foo", ) - def test_multiple_recursive_unions(self): + def test_recursive_cte_with_multiple_union(self): root_query = select(literal(1).label("val")).cte( "increasing", recursive=True ) @@ -1730,12 +1730,10 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): rec_part_2 = select((root_query.c.val + 5).label("val")).where( root_query.c.val < 15 ) - rec_query = root_query.union(rec_part_1, rec_part_2) - - stmt = select(rec_query) - + union_rec_query = root_query.union(rec_part_1, rec_part_2) + union_stmt = select(union_rec_query) self.assert_compile( - stmt, + union_stmt, "WITH RECURSIVE increasing(val) AS " "(SELECT :param_1 AS val " "UNION SELECT increasing.val + :val_1 AS val FROM increasing " @@ -1745,6 +1743,30 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT increasing.val FROM increasing", ) + def test_recursive_cte_with_multiple_union_all(self): + root_query = select(literal(1).label("val")).cte( + "increasing", recursive=True + ) + rec_part_1 = select((root_query.c.val + 3).label("val")).where( + root_query.c.val < 15 + ) + rec_part_2 = select((root_query.c.val + 5).label("val")).where( + root_query.c.val < 15 + ) + + union_all_rec_query = root_query.union_all(rec_part_1, rec_part_2) + union_all_stmt = select(union_all_rec_query) + self.assert_compile( + union_all_stmt, + "WITH RECURSIVE increasing(val) AS " + "(SELECT :param_1 AS val " + "UNION ALL SELECT increasing.val + :val_1 AS val FROM increasing " + "WHERE increasing.val < :val_2 " + "UNION ALL SELECT increasing.val + :val_3 AS val FROM increasing " + "WHERE increasing.val < :val_4) " + "SELECT increasing.val FROM increasing", + ) + class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/sql/test_select.py b/test/sql/test_select.py index 17b47d96de..94efbc028b 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -11,6 +11,7 @@ from sqlalchemy import Table from sqlalchemy import tuple_ from sqlalchemy import union from sqlalchemy.sql import column +from sqlalchemy.sql import literal from sqlalchemy.sql import table from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -412,3 +413,75 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT anon_1.name FROM (SELECT mytable.name AS name, " "(mytable.myid, mytable.name) AS anon_2 FROM mytable) AS anon_1", ) + + def test_select_multiple_union_all(self): + stmt_union = select(literal(1)).union_all( + select(literal(2)), select(literal(3)) + ) + + self.assert_compile( + stmt_union, + "SELECT :param_1 AS anon_1" + " UNION ALL SELECT :param_2 AS anon_2" + " UNION ALL SELECT :param_3 AS anon_3", + ) + + def test_select_multiple_union(self): + stmt_union = select(literal(1)).union( + select(literal(2)), select(literal(3)) + ) + + self.assert_compile( + stmt_union, + "SELECT :param_1 AS anon_1" + " UNION SELECT :param_2 AS anon_2" + " UNION SELECT :param_3 AS anon_3", + ) + + def test_select_multiple_except(self): + stmt_union = select(literal(1)).except_( + select(literal(2)), select(literal(3)) + ) + + self.assert_compile( + stmt_union, + "SELECT :param_1 AS anon_1" + " EXCEPT SELECT :param_2 AS anon_2" + " EXCEPT SELECT :param_3 AS anon_3", + ) + + def test_select_multiple_except_all(self): + stmt_union = select(literal(1)).except_all( + select(literal(2)), select(literal(3)) + ) + + self.assert_compile( + stmt_union, + "SELECT :param_1 AS anon_1" + " EXCEPT ALL SELECT :param_2 AS anon_2" + " EXCEPT ALL SELECT :param_3 AS anon_3", + ) + + def test_select_multiple_intersect(self): + stmt_union = select(literal(1)).intersect( + select(literal(2)), select(literal(3)) + ) + + self.assert_compile( + stmt_union, + "SELECT :param_1 AS anon_1" + " INTERSECT SELECT :param_2 AS anon_2" + " INTERSECT SELECT :param_3 AS anon_3", + ) + + def test_select_multiple_intersect_all(self): + stmt_union = select(literal(1)).intersect_all( + select(literal(2)), select(literal(3)) + ) + + self.assert_compile( + stmt_union, + "SELECT :param_1 AS anon_1" + " INTERSECT ALL SELECT :param_2 AS anon_2" + " INTERSECT ALL SELECT :param_3 AS anon_3", + )