From: Mingyu Park Date: Fri, 7 Feb 2025 19:45:26 +0000 (-0500) Subject: Support generic types for union and union_all X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fc44b5078b74081b0df94cca9d21b89ed578caf3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support generic types for union and union_all Support generic types for compound selects (:func:`_sql.union`, :func:`_sql.union_all`, :meth:`_sql.Select.union`, :meth:`_sql.Select.union_all`, etc) returning the type of the first select. Fixes: #11922 Closes: #12320 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12320 Pull-request-sha: f914a19f7201cec292056e900436d8c8431b9f87 Change-Id: I4fffa5d3fe93dd3a293b078360e326fea4207c5d --- diff --git a/doc/build/changelog/unreleased_20/11922.rst b/doc/build/changelog/unreleased_20/11922.rst new file mode 100644 index 0000000000..f0e7e3d978 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11922.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: typing, usecase + :tickets: 11922 + + Support generic types for compound selects (:func:`_sql.union`, + :func:`_sql.union_all`, :meth:`_sql.Select.union`, + :meth:`_sql.Select.union_all`, etc) returning the type of the first select. + Pull request courtesy of Mingyu Park. diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index bb553668c3..08149771b1 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -11,7 +11,6 @@ from typing import Any from typing import Optional from typing import overload from typing import TYPE_CHECKING -from typing import TypeVar from typing import Union from . import coercions @@ -48,6 +47,7 @@ if TYPE_CHECKING: from ._typing import _T7 from ._typing import _T8 from ._typing import _T9 + from ._typing import _Ts from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE @@ -56,9 +56,6 @@ if TYPE_CHECKING: from .selectable import SelectBase -_T = TypeVar("_T", bound=Any) - - def alias( selectable: FromClause, name: Optional[str] = None, flat: bool = False ) -> NamedFromClause: @@ -107,9 +104,28 @@ def cte( ) +# TODO: mypy requires the _TypedSelectable overloads in all compound select +# constructors since _SelectStatementForCompoundArgument includes +# untyped args that make it return CompoundSelect[Unpack[tuple[Never, ...]]] +# pyright does not have this issue +_TypedSelectable = Union["Select[Unpack[_Ts]]", "CompoundSelect[Unpack[_Ts]]"] + + +@overload def except_( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _TypedSelectable[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +@overload +def except_( + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +def except_( + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: r"""Return an ``EXCEPT`` of multiple selectables. The returned object is an instance of @@ -122,9 +138,21 @@ def except_( return CompoundSelect._create_except(*selects) +@overload +def except_all( + *selects: _TypedSelectable[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +@overload +def except_all( + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + def except_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: r"""Return an ``EXCEPT ALL`` of multiple selectables. The returned object is an instance of @@ -183,9 +211,21 @@ def exists( return Exists(__argument) +@overload +def intersect( + *selects: _TypedSelectable[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +@overload +def intersect( + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + def intersect( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -198,9 +238,21 @@ def intersect( return CompoundSelect._create_intersect(*selects) +@overload +def intersect_all( + *selects: _TypedSelectable[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +@overload def intersect_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +def intersect_all( + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: r"""Return an ``INTERSECT ALL`` of multiple selectables. The returned object is an instance of @@ -569,9 +621,21 @@ def tablesample( return TableSample._factory(selectable, sampling, name=name, seed=seed) +@overload +def union( + *selects: _TypedSelectable[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +@overload def union( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +def union( + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -591,9 +655,21 @@ def union( return CompoundSelect._create_union(*selects) +@overload +def union_all( + *selects: _TypedSelectable[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + +@overload +def union_all( + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: ... + + def union_all( - *selects: _SelectStatementForCompoundArgument, -) -> CompoundSelect: + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], +) -> CompoundSelect[Unpack[_Ts]]: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index f46924bf83..6fef1766c6 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -31,6 +31,7 @@ from ..inspection import Inspectable from ..util.typing import Literal from ..util.typing import TupleAny from ..util.typing import TypeAlias +from ..util.typing import TypeVarTuple from ..util.typing import Unpack if TYPE_CHECKING: @@ -57,6 +58,7 @@ if TYPE_CHECKING: from .roles import FromClauseRole from .schema import Column from .selectable import Alias + from .selectable import CompoundSelect from .selectable import CTE from .selectable import FromClause from .selectable import Join @@ -75,6 +77,7 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) +_Ts = TypeVarTuple("_Ts") _CE = TypeVar("_CE", bound="ColumnElement[Any]") @@ -246,7 +249,9 @@ come from the ORM. """ _SelectStatementForCompoundArgument = Union[ - "SelectBase", roles.CompoundElementRole + "Select[Unpack[_Ts]]", + "CompoundSelect[Unpack[_Ts]]", + roles.CompoundElementRole, ] """SELECT statement acceptable by ``union()`` and other SQL set operations""" diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index cfe491e624..c3255a8f18 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -48,6 +48,8 @@ from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument from ._typing import _no_kw +from ._typing import _T +from ._typing import _Ts from ._typing import is_column_element from ._typing import is_select_statement from ._typing import is_subquery @@ -100,15 +102,11 @@ from ..util import HasMemoized_ro_memoized_attribute from ..util.typing import Literal from ..util.typing import Self from ..util.typing import TupleAny -from ..util.typing import TypeVarTuple from ..util.typing import Unpack and_ = BooleanClauseList.and_ -_T = TypeVar("_T", bound=Any) -_Ts = TypeVarTuple("_Ts") - if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -291,7 +289,7 @@ class ExecutableReturnsRows(Executable, ReturnsRows): class TypedReturnsRows(ExecutableReturnsRows, Generic[Unpack[_Ts]]): - """base for executable statements that return rows.""" + """base for a typed executable statements that return rows.""" class Selectable(ReturnsRows): @@ -2229,7 +2227,7 @@ class CTE( _suffixes=self._suffixes, ) - def union(self, *other: _SelectStatementForCompoundArgument) -> CTE: + def union(self, *other: _SelectStatementForCompoundArgument[Any]) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION`` of the original CTE against the given selectables provided as positional arguments. @@ -2258,7 +2256,9 @@ class CTE( _suffixes=self._suffixes, ) - def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE: + def union_all( + self, *other: _SelectStatementForCompoundArgument[Any] + ) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL`` of the original CTE against the given selectables provided as positional arguments. @@ -4416,7 +4416,9 @@ class _CompoundSelectKeyword(Enum): INTERSECT_ALL = "INTERSECT ALL" -class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): +class CompoundSelect( + HasCompileState, GenerativeSelect, TypedReturnsRows[Unpack[_Ts]] +): """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations. @@ -4463,7 +4465,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): def __init__( self, keyword: _CompoundSelectKeyword, - *selects: _SelectStatementForCompoundArgument, + *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]], ): self.keyword = keyword self.selects = [ @@ -4477,38 +4479,38 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): @classmethod def _create_union( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: return CompoundSelect(_CompoundSelectKeyword.UNION, *selects) @classmethod def _create_union_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects) @classmethod def _create_except( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects) @classmethod def _create_except_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects) @classmethod def _create_intersect( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects) @classmethod def _create_intersect_all( - cls, *selects: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + cls, *selects: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects) def _scalar_type(self) -> TypeEngine[Any]: @@ -4525,7 +4527,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return True return False - def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect: + def set_label_style(self, style: SelectLabelStyle) -> Self: if self._label_style is not style: self = self._generate() select_0 = self.selects[0].set_label_style(style) @@ -4533,7 +4535,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): return self - def _ensure_disambiguated_names(self) -> CompoundSelect: + def _ensure_disambiguated_names(self) -> Self: new_select = self.selects[0]._ensure_disambiguated_names() if new_select is not self.selects[0]: self = self._generate() @@ -6572,8 +6574,8 @@ class Select( return SelectStatementGrouping(self) def union( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: r"""Return a SQL ``UNION`` of this select() construct against the given selectables provided as positional arguments. @@ -6591,8 +6593,8 @@ class Select( return CompoundSelect._create_union(self, *other) def union_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: r"""Return a SQL ``UNION ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -6610,8 +6612,8 @@ class Select( return CompoundSelect._create_union_all(self, *other) def except_( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: r"""Return a SQL ``EXCEPT`` of this select() construct against the given selectable provided as positional arguments. @@ -6626,8 +6628,8 @@ class Select( return CompoundSelect._create_except(self, *other) def except_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: r"""Return a SQL ``EXCEPT ALL`` of this select() construct against the given selectables provided as positional arguments. @@ -6642,8 +6644,8 @@ class Select( return CompoundSelect._create_except_all(self, *other) def intersect( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: r"""Return a SQL ``INTERSECT`` of this select() construct against the given selectables provided as positional arguments. @@ -6661,8 +6663,8 @@ class Select( return CompoundSelect._create_intersect(self, *other) def intersect_all( - self, *other: _SelectStatementForCompoundArgument - ) -> CompoundSelect: + self, *other: _SelectStatementForCompoundArgument[Unpack[_Ts]] + ) -> CompoundSelect[Unpack[_Ts]]: r"""Return a SQL ``INTERSECT ALL`` of this select() construct against the given selectables provided as positional arguments. diff --git a/test/typing/plain_files/sql/common_sql_element.py b/test/typing/plain_files/sql/common_sql_element.py index 7c8001a728..3428a640df 100644 --- a/test/typing/plain_files/sql/common_sql_element.py +++ b/test/typing/plain_files/sql/common_sql_element.py @@ -11,14 +11,21 @@ from __future__ import annotations from sqlalchemy import asc from sqlalchemy import Column from sqlalchemy import column +from sqlalchemy import ColumnElement from sqlalchemy import desc +from sqlalchemy import except_ +from sqlalchemy import except_all from sqlalchemy import Integer +from sqlalchemy import intersect +from sqlalchemy import intersect_all from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import SQLColumnExpression from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import union +from sqlalchemy import union_all from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -176,3 +183,75 @@ mydict = { literal("5"): "q", column("q"): "q", } + +# compound selects (issue #11922): + +str_col = ColumnElement[str]() +int_col = ColumnElement[int]() + +first_stmt = select(str_col, int_col) +second_stmt = select(str_col, int_col) +third_stmt = select(int_col, str_col) + +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(union(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(union_all(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(except_(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(except_all(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(intersect(first_stmt, second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(intersect_all(first_stmt, second_stmt)) + +# EXPECTED_TYPE: Result[str, int] +reveal_type(Session().execute(union(first_stmt, second_stmt))) +# EXPECTED_TYPE: Result[str, int] +reveal_type(Session().execute(union_all(first_stmt, second_stmt))) + +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.union(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.union_all(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.except_(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.except_all(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.intersect(second_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.intersect_all(second_stmt)) + +# TODO: the following do not error because _SelectStatementForCompoundArgument +# includes untyped elements so the type checker falls back on them when +# the type does not match. Also for the standalone functions mypy +# looses the plot and returns a random type back. See TODO in the +# overloads + +# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] +reveal_type(union(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] +reveal_type(union_all(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] +reveal_type(except_(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] +reveal_type(except_all(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] +reveal_type(intersect(first_stmt, third_stmt)) +# EXPECTED_TYPE: CompoundSelect[Unpack[tuple[Never, ...]]] +reveal_type(intersect_all(first_stmt, third_stmt)) + +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.union(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.union_all(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.except_(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.except_all(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.intersect(third_stmt)) +# EXPECTED_TYPE: CompoundSelect[str, int] +reveal_type(first_stmt.intersect_all(third_stmt))