From: Mingyu Park Date: Fri, 7 Feb 2025 19:45:26 +0000 (-0500) Subject: Support generic types for union and union_all X-Git-Tag: rel_2_0_39~17^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=253b3694b7abc3b8fee82e9a83a719047885d94a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support generic types for union and union_all Support generic types for compound selects (:func:`_sql.union`, :func:`_sql.union_all`, :meth:`_sql.Select.union`, :meth:`_sql.Select.union_all`, etc) returning the type of the first select. Fixes: #11922 Closes: #12320 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12320 Pull-request-sha: f914a19f7201cec292056e900436d8c8431b9f87 Change-Id: I4fffa5d3fe93dd3a293b078360e326fea4207c5d (cherry picked from commit fc44b5078b74081b0df94cca9d21b89ed578caf3) --- diff --git a/doc/build/changelog/unreleased_20/11922.rst b/doc/build/changelog/unreleased_20/11922.rst new file mode 100644 index 0000000000..f0e7e3d978 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11922.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: typing, usecase + :tickets: 11922 + + Support generic types for compound selects (:func:`_sql.union`, + :func:`_sql.union_all`, :meth:`_sql.Select.union`, + :meth:`_sql.Select.union_all`, etc) returning the type of the first select. + Pull request courtesy of Mingyu Park. diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 1660778c56..69427334a3 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -12,7 +12,6 @@ from typing import Optional from typing import overload from typing import Tuple from typing import TYPE_CHECKING -from typing import TypeVar from typing import Union from . import coercions @@ -47,6 +46,7 @@ if TYPE_CHECKING: from ._typing import _T7 from ._typing import _T8 from ._typing import _T9 + from ._typing import _TP from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE @@ -55,9 +55,6 @@ if TYPE_CHECKING: from .selectable import SelectBase -_T = TypeVar("_T", bound=Any) - - def alias( selectable: FromClause, name: Optional[str] = None, flat: bool = False ) -> NamedFromClause: @@ -106,9 +103,28 @@ def cte( ) +# TODO: mypy requires the _TypedSelectable overloads in all compound select +# constructors since _SelectStatementForCompoundArgument includes +# untyped args that make it return CompoundSelect[Unpack[tuple[Never, ...]]] +# pyright does not have this issue +_TypedSelectable = Union["Select[_TP]", "CompoundSelect[_TP]"] + + +@overload def except_( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def except_( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + +def except_( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -121,9 +137,21 @@ def except_( return CompoundSelect._create_except(*selects) +@overload +def except_all( + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def except_all( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + def except_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return an ``EXCEPT ALL`` of multiple selectables. The returned object is an instance of @@ -181,9 +209,21 @@ def exists( return Exists(__argument) +@overload +def intersect( + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def intersect( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + def intersect( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -196,9 +236,21 @@ def intersect( return CompoundSelect._create_intersect(*selects) +@overload +def intersect_all( + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload def intersect_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + +def intersect_all( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return an ``INTERSECT ALL`` of multiple selectables. The returned object is an instance of @@ -557,9 +609,21 @@ def tablesample( return TableSample._factory(selectable, sampling, name=name, seed=seed) +@overload +def union( + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload def union( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + +def union( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -579,9 +643,21 @@ def union( return CompoundSelect._create_union(*selects) +@overload +def union_all( + *selects: _TypedSelectable[_TP], +) -> CompoundSelect[_TP]: ... + + +@overload +def union_all( + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: ... + + def union_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[_TP], +) -> CompoundSelect[_TP]: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index cf9129b479..b1af53f777 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -56,6 +56,7 @@ if TYPE_CHECKING: from .roles import FromClauseRole from .schema import Column from .selectable import Alias + from .selectable import CompoundSelect from .selectable import CTE from .selectable import FromClause from .selectable import Join @@ -247,7 +248,9 @@ come from the ORM. """ _SelectStatementForCompoundArgument = Union[ - "SelectBase", roles.CompoundElementRole + "Select[_TP]", + "CompoundSelect[_TP]", + roles.CompoundElementRole, ] """SELECT statement acceptable by ``union()`` and other SQL set operations""" diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8aa9f41eb9..5db1e729e7 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -47,6 +47,7 @@ from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument from ._typing import _no_kw +from ._typing import _T from ._typing import _TP from ._typing import is_column_element from ._typing import is_select_statement @@ -101,9 +102,9 @@ from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self + and_ = BooleanClauseList.and_ -_T = TypeVar("_T", bound=Any) if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -286,7 +287,7 @@ class ExecutableReturnsRows(Executable, ReturnsRows): class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]): - """base for executable statements that return rows.""" + """base for a typed executable statements that return rows.""" class Selectable(ReturnsRows): @@ -2224,7 +2225,7 @@ class CTE( _suffixes=self._suffixes, ) - def union(self, *other: _SelectStatementForCompoundArgument) -> CTE: + def union(self, *other: _SelectStatementForCompoundArgument[Any]) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION`` of the original CTE against the given selectables provided as positional arguments. @@ -2253,7 +2254,9 @@ class CTE( _suffixes=self._suffixes, ) - def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE: + def union_all( + self, *other: _SelectStatementForCompoundArgument[Any] + ) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL`` of the original CTE against the given selectables provided as positional arguments. @@ -4448,7 +4451,7 @@ class _CompoundSelectKeyword(Enum): INTERSECT_ALL = "INTERSECT ALL" -class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): +class CompoundSelect(HasCompileState, GenerativeSelect, TypedReturnsRows[_TP]): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -4495,7 +4498,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): def __init__( self, keyword: _CompoundSelectKeyword, - *selects: _SelectStatementForCompoundArgument, + *selects: _SelectStatementForCompoundArgument[_TP], ): self.keyword = keyword self.selects = [ @@ -4509,38 +4512,38 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): @classmethod def _create_union( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.UNION, *selects) @classmethod def _create_union_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects) @classmethod def _create_except( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects) @classmethod def _create_except_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects) @classmethod def _create_intersect( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects) @classmethod def _create_intersect_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects) def _scalar_type(self) -> TypeEngine[Any]: @@ -4557,7 +4560,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return True return False - def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect: + def set_label_style(self, style: SelectLabelStyle) -> Self: if self._label_style is not style: self = self._generate() select_0 = self.selects[0].set_label_style(style) @@ -4565,7 +4568,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return self - def _ensure_disambiguated_names(self) -> CompoundSelect: + def _ensure_disambiguated_names(self) -> Self: new_select = self.selects[0]._ensure_disambiguated_names() if new_select is not self.selects[0]: self = self._generate() @@ -6585,8 +6588,8 @@ class Select( return SelectStatementGrouping(self) def union( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``UNION`` of this select() construct against the given selectables provided as positional arguments. @@ -6604,8 +6607,8 @@ class Select( return CompoundSelect._create_union(self, *other) def union_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``UNION ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -6623,8 +6626,8 @@ class Select( return CompoundSelect._create_union_all(self, *other) def except_( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``EXCEPT`` of this select() construct against the given selectable provided as positional arguments. @@ -6639,8 +6642,8 @@ class Select( return CompoundSelect._create_except(self, *other) def except_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``EXCEPT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -6655,8 +6658,8 @@ class Select( return CompoundSelect._create_except_all(self, *other) def intersect( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``INTERSECT`` of this select() construct against the given selectables provided as positional arguments. @@ -6674,8 +6677,8 @@ class Select( return CompoundSelect._create_intersect(self, *other) def intersect_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[_TP] + ) -> CompoundSelect[_TP]: r"""Return a SQL ``INTERSECT ALL`` of this select() construct against the given selectables provided as positional arguments. diff --git a/test/typing/plain_files/sql/common_sql_element.py b/test/typing/plain_files/sql/common_sql_element.py index fb0add31d8..d5b8f88340 100644 --- a/test/typing/plain_files/sql/common_sql_element.py +++ b/test/typing/plain_files/sql/common_sql_element.py @@ -11,14 +11,21 @@ from __future__ import annotations from sqlalchemy import asc from sqlalchemy import Column from sqlalchemy import column +from sqlalchemy import ColumnElement from sqlalchemy import desc +from sqlalchemy import except_ +from sqlalchemy import except_all from sqlalchemy import Integer +from sqlalchemy import intersect +from sqlalchemy import intersect_all from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import SQLColumnExpression from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import union +from sqlalchemy import union_all from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -176,3 +183,75 @@ mydict = { literal("5"): "q", column("q"): "q", } + +# compound selects (issue #11922): + +str_col = ColumnElement[str]() +int_col = ColumnElement[int]() + +first_stmt = select(str_col, int_col) +second_stmt = select(str_col, int_col) +third_stmt = select(int_col, str_col) + +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(union(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(union_all(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(except_(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(except_all(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(intersect(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(intersect_all(first_stmt, second_stmt)) + +# EXPECTED_TYPE: Result[tuple[str, int]] +reveal_type(Session().execute(union(first_stmt, second_stmt))) +# EXPECTED_TYPE: Result[tuple[str, int]] +reveal_type(Session().execute(union_all(first_stmt, second_stmt))) + +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.union(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.union_all(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.except_(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.except_all(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.intersect(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.intersect_all(second_stmt)) + +# TODO: the following do not error because _SelectStatementForCompoundArgument +# includes untyped elements so the type checker falls back on them when +# the type does not match. Also for the standalone functions mypy +# looses the plot and returns a random type back. See TODO in the +# overloads + +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(union(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(union_all(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(except_(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(except_all(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(intersect(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Never] +reveal_type(intersect_all(first_stmt, third_stmt)) + +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.union(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.union_all(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.except_(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.except_all(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.intersect(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[tuple[str, int]] +reveal_type(first_stmt.intersect_all(third_stmt))