]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Extend to similar methods + add docs/tests
authorEric Masseran <eric.masseran@gmail.com>
Tue, 2 Nov 2021 13:33:38 +0000 (14:33 +0100)
committerEric Masseran <eric.masseran@gmail.com>
Tue, 2 Nov 2021 13:44:21 +0000 (14:44 +0100)
doc/build/changelog/unreleased_14/7259.rst
lib/sqlalchemy/sql/selectable.py
test/sql/test_cte.py
test/sql/test_select.py

index a1bf65b705520db95cd365257d4becf7a56ba0f1..a1f7be12e98699e99aeeb0b37690523d74bb56c5 100644 (file)
@@ -3,6 +3,9 @@
     :tickets: 7259
 
     Render Non-linear CTE with :meth:`_sql.CTE.union` and :meth:`_sql.CTE.union_all`
-    that accept multiple clauses to render multiple UNIONs. To make it possible,
+    that accept multiple selectables to render multiple UNIONs. To make it possible,
     :meth:`_sql.Select.union` and :meth:`_sql.Select.union_all` accept also multiple
-    clauses.
+    selectables. It has also been extended to other similar methods like
+    :meth:`_sql.Select.except_`, :meth:`_sql.Select.except_all`,
+    :meth:`_sql.Select.intersect` and :meth:`_sql.Select.intersect_all` that accept
+    multiple selectables.
index c7446d75bc87dda48595a4f3c626d4d5226dec39..8d77523facf0230de6d9d61645551521c4d1fa22 100644 (file)
@@ -6291,45 +6291,63 @@ class Select(
 
     def union(self, *other, **kwargs):
         """Return a SQL ``UNION`` of this select() construct against
-        the given selectable.
+        the given selectables provided as positional arguments.
+
+        The keyword arguments are forwarded to the method
+        :meth:`_expression.CompoundSelect._create_union`.
 
         """
         return CompoundSelect._create_union(self, *other, **kwargs)
 
     def union_all(self, *other, **kwargs):
         """Return a SQL ``UNION ALL`` of this select() construct against
-        the given selectable.
+        the given selectables provided as positional arguments.
+
+        The keyword arguments are forwarded to the method
+        :meth:`_expression.CompoundSelect._create_union_all`.
 
         """
         return CompoundSelect._create_union_all(self, *other, **kwargs)
 
-    def except_(self, other, **kwargs):
+    def except_(self, *other, **kwargs):
         """Return a SQL ``EXCEPT`` of this select() construct against
-        the given selectable.
+        the given selectable provided as positional arguments.
+
+        The keyword arguments are forwarded to the method
+        :meth:`_expression.CompoundSelect._create_except`.
 
         """
-        return CompoundSelect._create_except(self, other, **kwargs)
+        return CompoundSelect._create_except(self, *other, **kwargs)
 
-    def except_all(self, other, **kwargs):
+    def except_all(self, *other, **kwargs):
         """Return a SQL ``EXCEPT ALL`` of this select() construct against
-        the given selectable.
+        the given selectables provided as positional arguments.
+
+        The keyword arguments are forwarded to the method
+        :meth:`_expression.CompoundSelect._create_except_all`.
 
         """
-        return CompoundSelect._create_except_all(self, other, **kwargs)
+        return CompoundSelect._create_except_all(self, *other, **kwargs)
 
-    def intersect(self, other, **kwargs):
+    def intersect(self, *other, **kwargs):
         """Return a SQL ``INTERSECT`` of this select() construct against
-        the given selectable.
+        the given selectables provided as positional arguments.
+
+        The keyword arguments are forwarded to the method
+        :meth:`_expression.CompoundSelect._create_intersect`.
 
         """
-        return CompoundSelect._create_intersect(self, other, **kwargs)
+        return CompoundSelect._create_intersect(self, *other, **kwargs)
 
-    def intersect_all(self, other, **kwargs):
+    def intersect_all(self, *other, **kwargs):
         """Return a SQL ``INTERSECT ALL`` of this select() construct
-        against the given selectable.
+        against the given selectables provided as positional arguments.
+
+        The keyword arguments are forwarded to the method
+        :meth:`_expression.CompoundSelect._create_intersect_all`.
 
         """
-        return CompoundSelect._create_intersect_all(self, other, **kwargs)
+        return CompoundSelect._create_intersect_all(self, *other, **kwargs)
 
     @property
     @util.deprecated_20(
index 2b66825028fcddfb06925e4934c9489294a794ef..3701b4801c9a2785fd9c7ba5fe07ee00b4dc2e79 100644 (file)
@@ -1720,7 +1720,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "foo",
         )
 
-    def test_multiple_recursive_unions(self):
+    def test_recursive_cte_with_multiple_union(self):
         root_query = select(literal(1).label("val")).cte(
             "increasing", recursive=True
         )
@@ -1730,12 +1730,10 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
         rec_part_2 = select((root_query.c.val + 5).label("val")).where(
             root_query.c.val < 15
         )
-        rec_query = root_query.union(rec_part_1, rec_part_2)
-
-        stmt = select(rec_query)
-
+        union_rec_query = root_query.union(rec_part_1, rec_part_2)
+        union_stmt = select(union_rec_query)
         self.assert_compile(
-            stmt,
+            union_stmt,
             "WITH RECURSIVE increasing(val) AS "
             "(SELECT :param_1 AS val "
             "UNION SELECT increasing.val + :val_1 AS val FROM increasing "
@@ -1745,6 +1743,30 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT increasing.val FROM increasing",
         )
 
+    def test_recursive_cte_with_multiple_union_all(self):
+        root_query = select(literal(1).label("val")).cte(
+            "increasing", recursive=True
+        )
+        rec_part_1 = select((root_query.c.val + 3).label("val")).where(
+            root_query.c.val < 15
+        )
+        rec_part_2 = select((root_query.c.val + 5).label("val")).where(
+            root_query.c.val < 15
+        )
+
+        union_all_rec_query = root_query.union_all(rec_part_1, rec_part_2)
+        union_all_stmt = select(union_all_rec_query)
+        self.assert_compile(
+            union_all_stmt,
+            "WITH RECURSIVE increasing(val) AS "
+            "(SELECT :param_1 AS val "
+            "UNION ALL SELECT increasing.val + :val_1 AS val FROM increasing "
+            "WHERE increasing.val < :val_2 "
+            "UNION ALL SELECT increasing.val + :val_3 AS val FROM increasing "
+            "WHERE increasing.val < :val_4) "
+            "SELECT increasing.val FROM increasing",
+        )
+
 
 class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
 
index 17b47d96de77549404dad65bb8d0e5eb2fab493c..94efbc028b09ae833a4c7695706ba2c74e6182b9 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy import Table
 from sqlalchemy import tuple_
 from sqlalchemy import union
 from sqlalchemy.sql import column
+from sqlalchemy.sql import literal
 from sqlalchemy.sql import table
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -412,3 +413,75 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT anon_1.name FROM (SELECT mytable.name AS name, "
             "(mytable.myid, mytable.name) AS anon_2 FROM mytable) AS anon_1",
         )
+
+    def test_select_multiple_union_all(self):
+        stmt_union = select(literal(1)).union_all(
+            select(literal(2)), select(literal(3))
+        )
+
+        self.assert_compile(
+            stmt_union,
+            "SELECT :param_1 AS anon_1"
+            " UNION ALL SELECT :param_2 AS anon_2"
+            " UNION ALL SELECT :param_3 AS anon_3",
+        )
+
+    def test_select_multiple_union(self):
+        stmt_union = select(literal(1)).union(
+            select(literal(2)), select(literal(3))
+        )
+
+        self.assert_compile(
+            stmt_union,
+            "SELECT :param_1 AS anon_1"
+            " UNION SELECT :param_2 AS anon_2"
+            " UNION SELECT :param_3 AS anon_3",
+        )
+
+    def test_select_multiple_except(self):
+        stmt_union = select(literal(1)).except_(
+            select(literal(2)), select(literal(3))
+        )
+
+        self.assert_compile(
+            stmt_union,
+            "SELECT :param_1 AS anon_1"
+            " EXCEPT SELECT :param_2 AS anon_2"
+            " EXCEPT SELECT :param_3 AS anon_3",
+        )
+
+    def test_select_multiple_except_all(self):
+        stmt_union = select(literal(1)).except_all(
+            select(literal(2)), select(literal(3))
+        )
+
+        self.assert_compile(
+            stmt_union,
+            "SELECT :param_1 AS anon_1"
+            " EXCEPT ALL SELECT :param_2 AS anon_2"
+            " EXCEPT ALL SELECT :param_3 AS anon_3",
+        )
+
+    def test_select_multiple_intersect(self):
+        stmt_union = select(literal(1)).intersect(
+            select(literal(2)), select(literal(3))
+        )
+
+        self.assert_compile(
+            stmt_union,
+            "SELECT :param_1 AS anon_1"
+            " INTERSECT SELECT :param_2 AS anon_2"
+            " INTERSECT SELECT :param_3 AS anon_3",
+        )
+
+    def test_select_multiple_intersect_all(self):
+        stmt_union = select(literal(1)).intersect_all(
+            select(literal(2)), select(literal(3))
+        )
+
+        self.assert_compile(
+            stmt_union,
+            "SELECT :param_1 AS anon_1"
+            " INTERSECT ALL SELECT :param_2 AS anon_2"
+            " INTERSECT ALL SELECT :param_3 AS anon_3",
+        )