]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support generic types for union and union_all
authorMingyu Park <mingyuu.dev@gmail.com>
Fri, 7 Feb 2025 19:45:26 +0000 (14:45 -0500)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 12 Feb 2025 20:08:49 +0000 (21:08 +0100)
Support generic types for compound selects (:func:`_sql.union`,
:func:`_sql.union_all`, :meth:`_sql.Select.union`,
:meth:`_sql.Select.union_all`, etc) returning the type of the first select.

Fixes: #11922
Closes: #12320
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12320
Pull-request-sha: f914a19f7201cec292056e900436d8c8431b9f87

Change-Id: I4fffa5d3fe93dd3a293b078360e326fea4207c5d
(cherry picked from commit fc44b5078b74081b0df94cca9d21b89ed578caf3)

doc/build/changelog/unreleased_20/11922.rst [new file with mode: 0644]
lib/sqlalchemy/sql/_selectable_constructors.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/selectable.py
test/typing/plain_files/sql/common_sql_element.py

diff --git a/doc/build/changelog/unreleased_20/11922.rst b/doc/build/changelog/unreleased_20/11922.rst
new file mode 100644 (file)
index 0000000..f0e7e3d
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: typing, usecase
+    :tickets: 11922
+
+    Support generic types for compound selects (:func:`_sql.union`,
+    :func:`_sql.union_all`, :meth:`_sql.Select.union`,
+    :meth:`_sql.Select.union_all`, etc) returning the type of the first select.
+    Pull request courtesy of Mingyu Park.
index 1660778c56fccfab4e6375d82c45020a8f3f36a7..69427334a32212f309b90ba3f7a536ce8e399248 100644 (file)
@@ -12,7 +12,6 @@ from typing import Optional
 from typing import overload
 from typing import Tuple
 from typing import TYPE_CHECKING
-from typing import TypeVar
 from typing import Union
 
 from . import coercions
@@ -47,6 +46,7 @@ if TYPE_CHECKING:
     from ._typing import _T7
     from ._typing import _T8
     from ._typing import _T9
+    from ._typing import _TP
     from ._typing import _TypedColumnClauseArgument as _TCCA
     from .functions import Function
     from .selectable import CTE
@@ -55,9 +55,6 @@ if TYPE_CHECKING:
     from .selectable import SelectBase
 
 
-_T = TypeVar("_T", bound=Any)
-
-
 def alias(
     selectable: FromClause, name: Optional[str] = None, flat: bool = False
 ) -> NamedFromClause:
@@ -106,9 +103,28 @@ def cte(
     )
 
 
+# TODO: mypy requires the _TypedSelectable overloads in all compound select
+# constructors since _SelectStatementForCompoundArgument includes
+# untyped args that make it return CompoundSelect[Unpack[tuple[Never, ...]]]
+# pyright does not have this issue
+_TypedSelectable = Union["Select[_TP]", "CompoundSelect[_TP]"]
+
+
+@overload
 def except_(
-    *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+    *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
+def except_(
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+def except_(
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
     r"""Return an ``EXCEPT`` of multiple selectables.
 
     The returned object is an instance of
@@ -121,9 +137,21 @@ def except_(
     return CompoundSelect._create_except(*selects)
 
 
+@overload
+def except_all(
+    *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
+def except_all(
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
 def except_all(
-    *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
     r"""Return an ``EXCEPT ALL`` of multiple selectables.
 
     The returned object is an instance of
@@ -181,9 +209,21 @@ def exists(
     return Exists(__argument)
 
 
+@overload
+def intersect(
+    *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
+def intersect(
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
 def intersect(
-    *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
     r"""Return an ``INTERSECT`` of multiple selectables.
 
     The returned object is an instance of
@@ -196,9 +236,21 @@ def intersect(
     return CompoundSelect._create_intersect(*selects)
 
 
+@overload
+def intersect_all(
+    *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
 def intersect_all(
-    *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+def intersect_all(
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
     r"""Return an ``INTERSECT ALL`` of multiple selectables.
 
     The returned object is an instance of
@@ -557,9 +609,21 @@ def tablesample(
     return TableSample._factory(selectable, sampling, name=name, seed=seed)
 
 
+@overload
+def union(
+    *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
 def union(
-    *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+def union(
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
     r"""Return a ``UNION`` of multiple selectables.
 
     The returned object is an instance of
@@ -579,9 +643,21 @@ def union(
     return CompoundSelect._create_union(*selects)
 
 
+@overload
+def union_all(
+    *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
+def union_all(
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
 def union_all(
-    *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+    *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
     r"""Return a ``UNION ALL`` of multiple selectables.
 
     The returned object is an instance of
index cf9129b479b575803d682ed2810e97f76c936ba5..b1af53f77772cdc445bfb29aeda461e9a01d83e3 100644 (file)
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
     from .roles import FromClauseRole
     from .schema import Column
     from .selectable import Alias
+    from .selectable import CompoundSelect
     from .selectable import CTE
     from .selectable import FromClause
     from .selectable import Join
@@ -247,7 +248,9 @@ come from the ORM.
 """
 
 _SelectStatementForCompoundArgument = Union[
-    "SelectBase", roles.CompoundElementRole
+    "Select[_TP]",
+    "CompoundSelect[_TP]",
+    roles.CompoundElementRole,
 ]
 """SELECT statement acceptable by ``union()`` and other SQL set operations"""
 
index 8aa9f41eb9f63b1b49f491cb43db658d37653769..5db1e729e7a3bb4251bbdfdce7d931e854331ded 100644 (file)
@@ -47,6 +47,7 @@ from . import type_api
 from . import visitors
 from ._typing import _ColumnsClauseArgument
 from ._typing import _no_kw
+from ._typing import _T
 from ._typing import _TP
 from ._typing import is_column_element
 from ._typing import is_select_statement
@@ -101,9 +102,9 @@ from ..util.typing import Literal
 from ..util.typing import Protocol
 from ..util.typing import Self
 
+
 and_ = BooleanClauseList.and_
 
-_T = TypeVar("_T", bound=Any)
 
 if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
@@ -286,7 +287,7 @@ class ExecutableReturnsRows(Executable, ReturnsRows):
 
 
 class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]):
-    """base for executable statements that return rows."""
+    """base for a typed executable statements that return rows."""
 
 
 class Selectable(ReturnsRows):
@@ -2224,7 +2225,7 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def union(self, *other: _SelectStatementForCompoundArgument) -> CTE:
+    def union(self, *other: _SelectStatementForCompoundArgument[Any]) -> CTE:
         r"""Return a new :class:`_expression.CTE` with a SQL ``UNION``
         of the original CTE against the given selectables provided
         as positional arguments.
@@ -2253,7 +2254,9 @@ class CTE(
             _suffixes=self._suffixes,
         )
 
-    def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE:
+    def union_all(
+        self, *other: _SelectStatementForCompoundArgument[Any]
+    ) -> CTE:
         r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL``
         of the original CTE against the given selectables provided
         as positional arguments.
@@ -4448,7 +4451,7 @@ class _CompoundSelectKeyword(Enum):
     INTERSECT_ALL = "INTERSECT ALL"
 
 
-class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
+class CompoundSelect(HasCompileState, GenerativeSelect, TypedReturnsRows[_TP]):
     """Forms the basis of ``UNION``, ``UNION ALL``, and other
     SELECT-based set operations.
 
@@ -4495,7 +4498,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
     def __init__(
         self,
         keyword: _CompoundSelectKeyword,
-        *selects: _SelectStatementForCompoundArgument,
+        *selects: _SelectStatementForCompoundArgument[_TP],
     ):
         self.keyword = keyword
         self.selects = [
@@ -4509,38 +4512,38 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
 
     @classmethod
     def _create_union(
-        cls, *selects: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        cls, *selects: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         return CompoundSelect(_CompoundSelectKeyword.UNION, *selects)
 
     @classmethod
     def _create_union_all(
-        cls, *selects: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        cls, *selects: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects)
 
     @classmethod
     def _create_except(
-        cls, *selects: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        cls, *selects: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects)
 
     @classmethod
     def _create_except_all(
-        cls, *selects: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        cls, *selects: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects)
 
     @classmethod
     def _create_intersect(
-        cls, *selects: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        cls, *selects: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects)
 
     @classmethod
     def _create_intersect_all(
-        cls, *selects: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        cls, *selects: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects)
 
     def _scalar_type(self) -> TypeEngine[Any]:
@@ -4557,7 +4560,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
                 return True
         return False
 
-    def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect:
+    def set_label_style(self, style: SelectLabelStyle) -> Self:
         if self._label_style is not style:
             self = self._generate()
             select_0 = self.selects[0].set_label_style(style)
@@ -4565,7 +4568,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
 
         return self
 
-    def _ensure_disambiguated_names(self) -> CompoundSelect:
+    def _ensure_disambiguated_names(self) -> Self:
         new_select = self.selects[0]._ensure_disambiguated_names()
         if new_select is not self.selects[0]:
             self = self._generate()
@@ -6585,8 +6588,8 @@ class Select(
             return SelectStatementGrouping(self)
 
     def union(
-        self, *other: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        self, *other: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         r"""Return a SQL ``UNION`` of this select() construct against
         the given selectables provided as positional arguments.
 
@@ -6604,8 +6607,8 @@ class Select(
         return CompoundSelect._create_union(self, *other)
 
     def union_all(
-        self, *other: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        self, *other: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         r"""Return a SQL ``UNION ALL`` of this select() construct against
         the given selectables provided as positional arguments.
 
@@ -6623,8 +6626,8 @@ class Select(
         return CompoundSelect._create_union_all(self, *other)
 
     def except_(
-        self, *other: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        self, *other: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         r"""Return a SQL ``EXCEPT`` of this select() construct against
         the given selectable provided as positional arguments.
 
@@ -6639,8 +6642,8 @@ class Select(
         return CompoundSelect._create_except(self, *other)
 
     def except_all(
-        self, *other: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        self, *other: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         r"""Return a SQL ``EXCEPT ALL`` of this select() construct against
         the given selectables provided as positional arguments.
 
@@ -6655,8 +6658,8 @@ class Select(
         return CompoundSelect._create_except_all(self, *other)
 
     def intersect(
-        self, *other: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        self, *other: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         r"""Return a SQL ``INTERSECT`` of this select() construct against
         the given selectables provided as positional arguments.
 
@@ -6674,8 +6677,8 @@ class Select(
         return CompoundSelect._create_intersect(self, *other)
 
     def intersect_all(
-        self, *other: _SelectStatementForCompoundArgument
-    ) -> CompoundSelect:
+        self, *other: _SelectStatementForCompoundArgument[_TP]
+    ) -> CompoundSelect[_TP]:
         r"""Return a SQL ``INTERSECT ALL`` of this select() construct
         against the given selectables provided as positional arguments.
 
index fb0add31d819ca996d662dd99a75b1520bb0d94c..d5b8f8834003b5d50e59fd516f1f215b029ab048 100644 (file)
@@ -11,14 +11,21 @@ from __future__ import annotations
 from sqlalchemy import asc
 from sqlalchemy import Column
 from sqlalchemy import column
+from sqlalchemy import ColumnElement
 from sqlalchemy import desc
+from sqlalchemy import except_
+from sqlalchemy import except_all
 from sqlalchemy import Integer
+from sqlalchemy import intersect
+from sqlalchemy import intersect_all
 from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import SQLColumnExpression
 from sqlalchemy import String
 from sqlalchemy import Table
+from sqlalchemy import union
+from sqlalchemy import union_all
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -176,3 +183,75 @@ mydict = {
     literal("5"): "q",
     column("q"): "q",
 }
+
+# compound selects (issue #11922):
+
+str_col = ColumnElement[str]()
+int_col = ColumnElement[int]()
+
+first_stmt = select(str_col, int_col)
+second_stmt = select(str_col, int_col)
+third_stmt = select(int_col, str_col)
+
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(union(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(union_all(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(except_(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(except_all(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(intersect(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(intersect_all(first_stmt, second_stmt))
+
+# EXPECTED_TYPE: Result[tuple[str, int]]
+reveal_type(Session().execute(union(first_stmt, second_stmt)))
+# EXPECTED_TYPE: Result[tuple[str, int]]
+reveal_type(Session().execute(union_all(first_stmt, second_stmt)))
+
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.union(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.union_all(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.except_(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.except_all(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.intersect(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.intersect_all(second_stmt))
+
+# TODO: the following do not error because _SelectStatementForCompoundArgument
+# includes untyped elements so the type checker falls back on them when
+# the type does not match. Also for the standalone functions mypy
+# looses the plot and returns a random type back. See TODO in the
+# overloads
+
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(union(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(union_all(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(except_(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(except_all(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(intersect(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(intersect_all(first_stmt, third_stmt))
+
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.union(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.union_all(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.except_(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.except_all(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.intersect(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.intersect_all(third_stmt))