--- /dev/null
+.. change::
+ :tags: typing, usecase
+ :tickets: 11922
+
+ Support generic types for compound selects (:func:`_sql.union`,
+ :func:`_sql.union_all`, :meth:`_sql.Select.union`,
+ :meth:`_sql.Select.union_all`, etc) returning the type of the first select.
+ Pull request courtesy of Mingyu Park.
from typing import overload
from typing import Tuple
from typing import TYPE_CHECKING
-from typing import TypeVar
from typing import Union
from . import coercions
from ._typing import _T7
from ._typing import _T8
from ._typing import _T9
+ from ._typing import _TP
from ._typing import _TypedColumnClauseArgument as _TCCA
from .functions import Function
from .selectable import CTE
from .selectable import SelectBase
-_T = TypeVar("_T", bound=Any)
-
-
def alias(
selectable: FromClause, name: Optional[str] = None, flat: bool = False
) -> NamedFromClause:
)
+# TODO: mypy requires the _TypedSelectable overloads in all compound select
+# constructors since _SelectStatementForCompoundArgument includes
+# untyped args that make it return CompoundSelect[Unpack[tuple[Never, ...]]]
+# pyright does not have this issue
+_TypedSelectable = Union["Select[_TP]", "CompoundSelect[_TP]"]
+
+
+@overload
def except_(
- *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+ *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
+def except_(
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+def except_(
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
r"""Return an ``EXCEPT`` of multiple selectables.
The returned object is an instance of
return CompoundSelect._create_except(*selects)
+@overload
+def except_all(
+ *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
+def except_all(
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
def except_all(
- *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
r"""Return an ``EXCEPT ALL`` of multiple selectables.
The returned object is an instance of
return Exists(__argument)
+@overload
+def intersect(
+ *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
+def intersect(
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
def intersect(
- *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
r"""Return an ``INTERSECT`` of multiple selectables.
The returned object is an instance of
return CompoundSelect._create_intersect(*selects)
+@overload
+def intersect_all(
+ *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
def intersect_all(
- *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+def intersect_all(
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
r"""Return an ``INTERSECT ALL`` of multiple selectables.
The returned object is an instance of
return TableSample._factory(selectable, sampling, name=name, seed=seed)
+@overload
+def union(
+ *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
def union(
- *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+def union(
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
r"""Return a ``UNION`` of multiple selectables.
The returned object is an instance of
return CompoundSelect._create_union(*selects)
+@overload
+def union_all(
+ *selects: _TypedSelectable[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
+@overload
+def union_all(
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]: ...
+
+
def union_all(
- *selects: _SelectStatementForCompoundArgument,
-) -> CompoundSelect:
+ *selects: _SelectStatementForCompoundArgument[_TP],
+) -> CompoundSelect[_TP]:
r"""Return a ``UNION ALL`` of multiple selectables.
The returned object is an instance of
from .roles import FromClauseRole
from .schema import Column
from .selectable import Alias
+ from .selectable import CompoundSelect
from .selectable import CTE
from .selectable import FromClause
from .selectable import Join
"""
_SelectStatementForCompoundArgument = Union[
- "SelectBase", roles.CompoundElementRole
+ "Select[_TP]",
+ "CompoundSelect[_TP]",
+ roles.CompoundElementRole,
]
"""SELECT statement acceptable by ``union()`` and other SQL set operations"""
from . import visitors
from ._typing import _ColumnsClauseArgument
from ._typing import _no_kw
+from ._typing import _T
from ._typing import _TP
from ._typing import is_column_element
from ._typing import is_select_statement
from ..util.typing import Protocol
from ..util.typing import Self
+
and_ = BooleanClauseList.and_
-_T = TypeVar("_T", bound=Any)
if TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]):
- """base for executable statements that return rows."""
+ """base for a typed executable statements that return rows."""
class Selectable(ReturnsRows):
_suffixes=self._suffixes,
)
- def union(self, *other: _SelectStatementForCompoundArgument) -> CTE:
+ def union(self, *other: _SelectStatementForCompoundArgument[Any]) -> CTE:
r"""Return a new :class:`_expression.CTE` with a SQL ``UNION``
of the original CTE against the given selectables provided
as positional arguments.
_suffixes=self._suffixes,
)
- def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE:
+ def union_all(
+ self, *other: _SelectStatementForCompoundArgument[Any]
+ ) -> CTE:
r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL``
of the original CTE against the given selectables provided
as positional arguments.
INTERSECT_ALL = "INTERSECT ALL"
-class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
+class CompoundSelect(HasCompileState, GenerativeSelect, TypedReturnsRows[_TP]):
"""Forms the basis of ``UNION``, ``UNION ALL``, and other
SELECT-based set operations.
def __init__(
self,
keyword: _CompoundSelectKeyword,
- *selects: _SelectStatementForCompoundArgument,
+ *selects: _SelectStatementForCompoundArgument[_TP],
):
self.keyword = keyword
self.selects = [
@classmethod
def _create_union(
- cls, *selects: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ cls, *selects: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
return CompoundSelect(_CompoundSelectKeyword.UNION, *selects)
@classmethod
def _create_union_all(
- cls, *selects: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ cls, *selects: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
return CompoundSelect(_CompoundSelectKeyword.UNION_ALL, *selects)
@classmethod
def _create_except(
- cls, *selects: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ cls, *selects: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
return CompoundSelect(_CompoundSelectKeyword.EXCEPT, *selects)
@classmethod
def _create_except_all(
- cls, *selects: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ cls, *selects: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
return CompoundSelect(_CompoundSelectKeyword.EXCEPT_ALL, *selects)
@classmethod
def _create_intersect(
- cls, *selects: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ cls, *selects: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
return CompoundSelect(_CompoundSelectKeyword.INTERSECT, *selects)
@classmethod
def _create_intersect_all(
- cls, *selects: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ cls, *selects: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
return CompoundSelect(_CompoundSelectKeyword.INTERSECT_ALL, *selects)
def _scalar_type(self) -> TypeEngine[Any]:
return True
return False
- def set_label_style(self, style: SelectLabelStyle) -> CompoundSelect:
+ def set_label_style(self, style: SelectLabelStyle) -> Self:
if self._label_style is not style:
self = self._generate()
select_0 = self.selects[0].set_label_style(style)
return self
- def _ensure_disambiguated_names(self) -> CompoundSelect:
+ def _ensure_disambiguated_names(self) -> Self:
new_select = self.selects[0]._ensure_disambiguated_names()
if new_select is not self.selects[0]:
self = self._generate()
return SelectStatementGrouping(self)
def union(
- self, *other: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ self, *other: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
r"""Return a SQL ``UNION`` of this select() construct against
the given selectables provided as positional arguments.
return CompoundSelect._create_union(self, *other)
def union_all(
- self, *other: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ self, *other: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
r"""Return a SQL ``UNION ALL`` of this select() construct against
the given selectables provided as positional arguments.
return CompoundSelect._create_union_all(self, *other)
def except_(
- self, *other: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ self, *other: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
r"""Return a SQL ``EXCEPT`` of this select() construct against
the given selectable provided as positional arguments.
return CompoundSelect._create_except(self, *other)
def except_all(
- self, *other: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ self, *other: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
r"""Return a SQL ``EXCEPT ALL`` of this select() construct against
the given selectables provided as positional arguments.
return CompoundSelect._create_except_all(self, *other)
def intersect(
- self, *other: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ self, *other: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
r"""Return a SQL ``INTERSECT`` of this select() construct against
the given selectables provided as positional arguments.
return CompoundSelect._create_intersect(self, *other)
def intersect_all(
- self, *other: _SelectStatementForCompoundArgument
- ) -> CompoundSelect:
+ self, *other: _SelectStatementForCompoundArgument[_TP]
+ ) -> CompoundSelect[_TP]:
r"""Return a SQL ``INTERSECT ALL`` of this select() construct
against the given selectables provided as positional arguments.
from sqlalchemy import asc
from sqlalchemy import Column
from sqlalchemy import column
+from sqlalchemy import ColumnElement
from sqlalchemy import desc
+from sqlalchemy import except_
+from sqlalchemy import except_all
from sqlalchemy import Integer
+from sqlalchemy import intersect
+from sqlalchemy import intersect_all
from sqlalchemy import literal
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy import String
from sqlalchemy import Table
+from sqlalchemy import union
+from sqlalchemy import union_all
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
literal("5"): "q",
column("q"): "q",
}
+
+# compound selects (issue #11922):
+
+str_col = ColumnElement[str]()
+int_col = ColumnElement[int]()
+
+first_stmt = select(str_col, int_col)
+second_stmt = select(str_col, int_col)
+third_stmt = select(int_col, str_col)
+
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(union(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(union_all(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(except_(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(except_all(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(intersect(first_stmt, second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(intersect_all(first_stmt, second_stmt))
+
+# EXPECTED_TYPE: Result[tuple[str, int]]
+reveal_type(Session().execute(union(first_stmt, second_stmt)))
+# EXPECTED_TYPE: Result[tuple[str, int]]
+reveal_type(Session().execute(union_all(first_stmt, second_stmt)))
+
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.union(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.union_all(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.except_(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.except_all(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.intersect(second_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.intersect_all(second_stmt))
+
+# TODO: the following do not error because _SelectStatementForCompoundArgument
+# includes untyped elements so the type checker falls back on them when
+# the type does not match. Also for the standalone functions mypy
+# looses the plot and returns a random type back. See TODO in the
+# overloads
+
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(union(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(union_all(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(except_(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(except_all(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(intersect(first_stmt, third_stmt))
+# EXPECTED_TYPE: CompoundSelect[Never]
+reveal_type(intersect_all(first_stmt, third_stmt))
+
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.union(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.union_all(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.except_(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.except_all(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.intersect(third_stmt))
+# EXPECTED_TYPE: CompoundSelect[tuple[str, int]]
+reveal_type(first_stmt.intersect_all(third_stmt))