From: Yurii Karabas <1998uriyyo@gmail.com> Date: Wed, 15 Nov 2023 13:05:56 +0000 (+0200) Subject: More typing updates X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=546e5348fd83504324a6e92f6546f3b58a1f2aa1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git More typing updates --- diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 47597054c9..6fdcb04442 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1261,7 +1261,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @overload def scalar( self, - statement: TypedReturnsRows[Tuple[_T]], + statement: TypedReturnsRows[_T], parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, @@ -1310,7 +1310,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @overload def scalars( self, - statement: TypedReturnsRows[Tuple[_T]], + statement: TypedReturnsRows[_T], parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, @@ -1355,7 +1355,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @overload def execute( self, - statement: TypedReturnsRows[Tuple[Unpack[_Ts]]], + statement: TypedReturnsRows[Unpack[_Ts]], parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index b854e4b94f..214cc1050c 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -454,12 +454,14 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[Unpack[_Ts]]]): @overload def scalars( - self: AsyncResult[Tuple[_T]], index: Literal[0] + self: AsyncResult[_T, Unpack[Tuple[Any, ...]]], index: Literal[0] ) -> AsyncScalarResult[_T]: ... @overload - def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: + def scalars( + self: AsyncResult[_T, Unpack[Tuple[Any, ...]]], + ) -> AsyncScalarResult[_T]: ... @overload diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 0f7f93ecec..7d5e1c612c 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -31,6 +31,7 @@ from ...util import create_proxy_methods from ...util import ScopedRegistry from ...util import warn from ...util import warn_deprecated +from ...util.typing import TypeVarTuple from ...util.typing import Unpack if TYPE_CHECKING: @@ -62,6 +63,7 @@ if TYPE_CHECKING: from ...sql.selectable import TypedReturnsRows _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") @create_proxy_methods( @@ -530,14 +532,14 @@ class async_scoped_session(Generic[_AS]): @overload async def execute( self, - statement: TypedReturnsRows[_T], + statement: TypedReturnsRows[Tuple[Unpack[_Ts]]], params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: + ) -> Result[Unpack[_Ts]]: ... @overload diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 6a78a0978e..2c93525167 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -38,8 +38,10 @@ from ...orm import Session from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +from ...util.typing import TypeVarTuple from ...util.typing import Unpack + if TYPE_CHECKING: from .engine import AsyncConnection from .engine import AsyncEngine @@ -73,7 +75,7 @@ if TYPE_CHECKING: _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] _T = TypeVar("_T", bound=Any) - +_Ts = TypeVarTuple("_Ts") _EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) @@ -392,14 +394,14 @@ class AsyncSession(ReversibleProxy[Session]): @overload async def execute( self, - statement: TypedReturnsRows[_T], + statement: TypedReturnsRows[Tuple[Unpack[_Ts]]], params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: + ) -> Result[Unpack[_Ts]]: ... @overload diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 2f5e4ce8b7..f9b57926a3 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -46,7 +46,6 @@ from ..sql import expression from ..sql import roles from ..sql import util as sql_util from ..sql import visitors -from ..sql._typing import _TP from ..sql._typing import is_dml from ..sql._typing import is_insert_update from ..sql._typing import is_select_base @@ -68,6 +67,9 @@ from ..sql.selectable import SelectLabelStyle from ..sql.selectable import SelectState from ..sql.selectable import TypedReturnsRows from ..sql.visitors import InternalTraversal +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack + if TYPE_CHECKING: from ._typing import _InternalEntityType @@ -91,6 +93,7 @@ if TYPE_CHECKING: from ..sql.type_api import TypeEngine _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") _path_registry = PathRegistry.root _EMPTY_DICT = util.immutabledict() @@ -147,7 +150,10 @@ class QueryContext: def __init__( self, compile_state: CompileState, - statement: Union[Select[Any], FromStatement[Any]], + statement: Union[ + Select[Unpack[Tuple[Any, ...]]], + FromStatement[Unpack[Tuple[Any, ...]]], + ], params: _CoreSingleExecuteParams, session: Session, load_options: Union[ @@ -401,8 +407,12 @@ class ORMCompileState(AbstractORMCompileState): attributes: Dict[Any, Any] global_attributes: Dict[Any, Any] - statement: Union[Select[Any], FromStatement[Any]] - select_statement: Union[Select[Any], FromStatement[Any]] + statement: Union[ + Select[Unpack[Tuple[Any, ...]]], FromStatement[Unpack[Tuple[Any, ...]]] + ] + select_statement: Union[ + Select[Unpack[Tuple[Any, ...]]], FromStatement[Unpack[Tuple[Any, ...]]] + ] _entities: List[_QueryEntity] _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter] compile_options: Union[ @@ -857,7 +867,7 @@ class ORMFromStatementCompileState(ORMCompileState): entity.setup_dml_returning_compile_state(self, adapter) -class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): +class FromStatement(GroupedElement, Generative, TypedReturnsRows[Unpack[_Ts]]): """Core construct that represents a load of ORM objects from various :class:`.ReturnsRows` and other classes including: @@ -2434,7 +2444,9 @@ def _column_descriptions( def _legacy_filter_by_entity_zero( - query_or_augmented_select: Union[Query[Any], Select[Any]] + query_or_augmented_select: Union[ + Query[Any], Select[Unpack[Tuple[Any, ...]]] + ] ) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if self._setup_joins: @@ -2449,7 +2461,9 @@ def _legacy_filter_by_entity_zero( def _entity_from_pre_ent_zero( - query_or_augmented_select: Union[Query[Any], Select[Any]] + query_or_augmented_select: Union[ + Query[Any], Select[Unpack[Tuple[Any, ...]]] + ] ) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if not self._raw_columns: diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 5880aad06a..7925f25b1c 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -715,7 +715,7 @@ class CompositeProperty( def create_row_processor( self, - query: Select[Any], + query: Select[Unpack[Tuple[Any, ...]]], procs: Sequence[Callable[[Row[Unpack[Tuple[Any, ...]]]], Any]], labels: Sequence[str], ) -> Callable[[Row[Unpack[Tuple[Any, ...]]]], Any]: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5da7ee9b22..ae96f37b93 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -74,7 +74,6 @@ from ..sql import Select from ..sql import util as sql_util from ..sql import visitors from ..sql._typing import _FromClauseArgument -from ..sql._typing import _TP from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative @@ -93,6 +92,8 @@ from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectLabelStyle from ..util.typing import Literal from ..util.typing import Self +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack if TYPE_CHECKING: @@ -150,6 +151,7 @@ if TYPE_CHECKING: __all__ = ["Query", "QueryContext"] _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") @inspection._self_inspects @@ -533,7 +535,9 @@ class Query( return stmt - def _final_statement(self, legacy_query_style: bool = True) -> Select[Any]: + def _final_statement( + self, legacy_query_style: bool = True + ) -> Select[Unpack[Tuple[Any, ...]]]: """Return the 'final' SELECT statement for this :class:`.Query`. This is used by the testing suite only and is fairly inefficient. @@ -822,7 +826,7 @@ class Query( @overload def only_return_tuples( self: Query[_O], value: Literal[True] - ) -> RowReturningQuery[Tuple[_O]]: + ) -> RowReturningQuery[_O]: ... @overload @@ -1493,13 +1497,13 @@ class Query( @overload def with_entities( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ) -> RowReturningQuery[_T0, _T1]: ... @overload def with_entities( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ) -> RowReturningQuery[_T0, _T1, _T2]: ... @overload @@ -1509,7 +1513,7 @@ class Query( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: ... @overload @@ -1520,7 +1524,7 @@ class Query( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: ... @overload @@ -1532,7 +1536,7 @@ class Query( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload @@ -1545,7 +1549,7 @@ class Query( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload @@ -1559,7 +1563,7 @@ class Query( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... # END OVERLOADED FUNCTIONS self.with_entities @@ -3407,8 +3411,8 @@ class BulkDelete(BulkUD): """BulkUD which handles DELETEs.""" -class RowReturningQuery(Query[Row[_TP]]): +class RowReturningQuery(Query[Row[Unpack[_Ts]]]): if TYPE_CHECKING: - def tuples(self) -> Query[_TP]: # type: ignore + def tuples(self) -> Query[tuple[Unpack[_Ts]]]: # type: ignore ... diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index c5cdeac171..95e47fd878 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -1584,7 +1584,7 @@ class scoped_session(Generic[_S]): @overload def query( self, _colexpr: TypedColumnsClauseRole[_T] - ) -> RowReturningQuery[Tuple[_T]]: + ) -> RowReturningQuery[_T]: ... # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 @@ -1595,13 +1595,13 @@ class scoped_session(Generic[_S]): @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ) -> RowReturningQuery[_T0, _T1]: ... @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ) -> RowReturningQuery[_T0, _T1, _T2]: ... @overload @@ -1611,7 +1611,7 @@ class scoped_session(Generic[_S]): __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: ... @overload @@ -1622,7 +1622,7 @@ class scoped_session(Generic[_S]): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: ... @overload @@ -1634,7 +1634,7 @@ class scoped_session(Generic[_S]): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload @@ -1647,7 +1647,7 @@ class scoped_session(Generic[_S]): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload @@ -1661,7 +1661,7 @@ class scoped_session(Generic[_S]): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... # END OVERLOADED FUNCTIONS self.query diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index eb4330b17c..0f3d441ee5 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2809,7 +2809,7 @@ class Session(_SessionClassMethods, EventTarget): @overload def query( self, _colexpr: TypedColumnsClauseRole[_T] - ) -> RowReturningQuery[Tuple[_T]]: + ) -> RowReturningQuery[_T]: ... # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 @@ -2820,13 +2820,13 @@ class Session(_SessionClassMethods, EventTarget): @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ) -> RowReturningQuery[_T0, _T1]: ... @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ) -> RowReturningQuery[_T0, _T1, _T2]: ... @overload @@ -2836,7 +2836,7 @@ class Session(_SessionClassMethods, EventTarget): __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: ... @overload @@ -2847,7 +2847,7 @@ class Session(_SessionClassMethods, EventTarget): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: ... @overload @@ -2859,7 +2859,7 @@ class Session(_SessionClassMethods, EventTarget): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload @@ -2872,7 +2872,7 @@ class Session(_SessionClassMethods, EventTarget): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload @@ -2886,7 +2886,7 @@ class Session(_SessionClassMethods, EventTarget): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... # END OVERLOADED FUNCTIONS self.query @@ -3127,7 +3127,9 @@ class Session(_SessionClassMethods, EventTarget): with_for_update = ForUpdateArg._from_argument(with_for_update) - stmt: Select[Any] = sql.select(object_mapper(instance)) + stmt: Select[Unpack[Tuple[Any, ...]]] = sql.select( + object_mapper(instance) + ) if ( loading.load_on_ident( self, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ae915ce3b2..a2f202aa4c 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1722,7 +1722,7 @@ class Bundle( def create_row_processor( self, - query: Select[Any], + query: Select[Unpack[Tuple[Any, ...]]], procs: Sequence[Callable[[Row[Unpack[Tuple[Any, ...]]]], Any]], labels: Sequence[str], ) -> Callable[[Row[Unpack[Tuple[Any, ...]]]], Any]: diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 41e8b6eb16..59f851f1d3 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -32,6 +32,7 @@ from .selectable import Select from .selectable import TableClause from .selectable import TableSample from .selectable import Values +from ..util.typing import Unpack if TYPE_CHECKING: from ._typing import _FromClauseArgument @@ -330,19 +331,19 @@ def outerjoin( @overload -def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: +def select(__ent0: _TCCA[_T0]) -> Select[_T0]: ... @overload -def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]: +def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[_T0, _T1]: ... @overload def select( __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] -) -> Select[Tuple[_T0, _T1, _T2]]: +) -> Select[_T0, _T1, _T2]: ... @@ -352,7 +353,7 @@ def select( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], -) -> Select[Tuple[_T0, _T1, _T2, _T3]]: +) -> Select[_T0, _T1, _T2, _T3]: ... @@ -363,7 +364,7 @@ def select( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: +) -> Select[_T0, _T1, _T2, _T3, _T4]: ... @@ -375,7 +376,7 @@ def select( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5]: ... @@ -388,7 +389,7 @@ def select( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @@ -402,7 +403,7 @@ def select( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... @@ -417,7 +418,7 @@ def select( __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], __ent8: _TCCA[_T8], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]: ... @@ -433,7 +434,7 @@ def select( __ent7: _TCCA[_T7], __ent8: _TCCA[_T8], __ent9: _TCCA[_T9], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]: ... @@ -441,11 +442,15 @@ def select( @overload -def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: +def select( + *entities: _ColumnsClauseArgument[Any], **__kw: Any +) -> Select[Unpack[Tuple[Any, ...]]]: ... -def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: +def select( + *entities: _ColumnsClauseArgument[Any], **__kw: Any +) -> Select[Unpack[Tuple[Any, ...]]]: r"""Construct a new :class:`_expression.Select`. diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index c9e183058e..4ecc44700e 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -29,6 +29,7 @@ from ..inspection import Inspectable from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import TypeAlias +from ..util.typing import Unpack if TYPE_CHECKING: from datetime import date @@ -319,7 +320,7 @@ if TYPE_CHECKING: def is_select_statement( t: Union[Executable, ReturnsRows] - ) -> TypeGuard[Select[Any]]: + ) -> TypeGuard[Select[Unpack[Tuple[Any, ...]]]]: ... def is_table(t: FromClause) -> TypeGuard[TableClause]: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cb6899c5e9..75b9a6a3ff 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -88,6 +88,7 @@ from ..util import FastIntFlag from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import TypedDict +from ..util.typing import Unpack if typing.TYPE_CHECKING: from .annotation import _AnnotationDict @@ -405,7 +406,7 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): need_result_map_for_nested: bool need_result_map_for_compound: bool select_0: ReturnsRows - insert_from_select: Select[Any] + insert_from_select: Select[Unpack[Tuple[Any, ...]]] class ExpandedState(NamedTuple): @@ -4767,7 +4768,7 @@ class SQLCompiler(Compiled): return text def _setup_select_hints( - self, select: Select[Any] + self, select: Select[Unpack[Tuple[Any, ...]]] ) -> Tuple[str, _FromHintsType]: byfrom = { from_: hinttext diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 4ca6ed338f..9a63a4123f 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -32,7 +32,6 @@ from typing import Union from . import coercions from . import roles from . import util as sql_util -from ._typing import _TP from ._typing import _unexpected_kw from ._typing import is_column_element from ._typing import is_named_from_clause @@ -67,6 +66,9 @@ from .. import exc from .. import util from ..util.typing import Self from ..util.typing import TypeGuard +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack + if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -107,6 +109,7 @@ else: _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") _DMLColumnElement = Union[str, ColumnClause[Any]] _DMLTableElement = Union[TableClause, Alias, Join] @@ -960,7 +963,7 @@ class ValuesBase(UpdateBase): _supports_multi_parameters = False - select: Optional[Select[Any]] = None + select: Optional[Select[Unpack[Tuple[Any, ...]]]] = None """SELECT statement for INSERT .. FROM SELECT""" _post_values_clause: Optional[ClauseElement] = None @@ -1295,7 +1298,7 @@ class Insert(ValuesBase): @overload def returning( self, __ent0: _TCCA[_T0], *, sort_by_parameter_order: bool = False - ) -> ReturningInsert[Tuple[_T0]]: + ) -> ReturningInsert[_T0]: ... @overload @@ -1305,7 +1308,7 @@ class Insert(ValuesBase): __ent1: _TCCA[_T1], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1]]: + ) -> ReturningInsert[_T0, _T1]: ... @overload @@ -1316,7 +1319,7 @@ class Insert(ValuesBase): __ent2: _TCCA[_T2], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: + ) -> ReturningInsert[_T0, _T1, _T2]: ... @overload @@ -1328,7 +1331,7 @@ class Insert(ValuesBase): __ent3: _TCCA[_T3], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: + ) -> ReturningInsert[_T0, _T1, _T2, _T3]: ... @overload @@ -1341,7 +1344,7 @@ class Insert(ValuesBase): __ent4: _TCCA[_T4], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4]: ... @overload @@ -1355,7 +1358,7 @@ class Insert(ValuesBase): __ent5: _TCCA[_T5], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload @@ -1370,7 +1373,7 @@ class Insert(ValuesBase): __ent6: _TCCA[_T6], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload @@ -1386,7 +1389,7 @@ class Insert(ValuesBase): __ent7: _TCCA[_T7], *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... # END OVERLOADED FUNCTIONS self.returning @@ -1409,7 +1412,7 @@ class Insert(ValuesBase): ... -class ReturningInsert(Insert, TypedReturnsRows[_TP]): +class ReturningInsert(Insert, TypedReturnsRows[Unpack[_Ts]]): """Typing-only class that establishes a generic type form of :class:`.Insert` which tracks returned column types. @@ -1596,19 +1599,19 @@ class Update(DMLWhereBase, ValuesBase): # statically generated** by tools/generate_tuple_map_overloads.py @overload - def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]: + def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[_T0]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> ReturningUpdate[Tuple[_T0, _T1]]: + ) -> ReturningUpdate[_T0, _T1]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: + ) -> ReturningUpdate[_T0, _T1, _T2]: ... @overload @@ -1618,7 +1621,7 @@ class Update(DMLWhereBase, ValuesBase): __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: + ) -> ReturningUpdate[_T0, _T1, _T2, _T3]: ... @overload @@ -1629,7 +1632,7 @@ class Update(DMLWhereBase, ValuesBase): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4]: ... @overload @@ -1641,7 +1644,7 @@ class Update(DMLWhereBase, ValuesBase): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload @@ -1654,7 +1657,7 @@ class Update(DMLWhereBase, ValuesBase): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload @@ -1668,7 +1671,7 @@ class Update(DMLWhereBase, ValuesBase): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... # END OVERLOADED FUNCTIONS self.returning @@ -1685,7 +1688,7 @@ class Update(DMLWhereBase, ValuesBase): ... -class ReturningUpdate(Update, TypedReturnsRows[_TP]): +class ReturningUpdate(Update, TypedReturnsRows[Unpack[_Ts]]): """Typing-only class that establishes a generic type form of :class:`.Update` which tracks returned column types. @@ -1734,19 +1737,19 @@ class Delete(DMLWhereBase, UpdateBase): # statically generated** by tools/generate_tuple_map_overloads.py @overload - def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]: + def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[_T0]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> ReturningDelete[Tuple[_T0, _T1]]: + ) -> ReturningDelete[_T0, _T1]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: + ) -> ReturningDelete[_T0, _T1, _T2]: ... @overload @@ -1756,7 +1759,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: + ) -> ReturningDelete[_T0, _T1, _T2, _T3]: ... @overload @@ -1767,7 +1770,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4]: ... @overload @@ -1779,7 +1782,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload @@ -1792,7 +1795,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload @@ -1806,7 +1809,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... # END OVERLOADED FUNCTIONS self.returning @@ -1814,16 +1817,16 @@ class Delete(DMLWhereBase, UpdateBase): @overload def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningDelete[Any]: + ) -> ReturningDelete[Unpack[Tuple[Any, ...]]]: ... def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningDelete[Any]: + ) -> ReturningDelete[Unpack[Tuple[Any, ...]]]: ... -class ReturningDelete(Update, TypedReturnsRows[_TP]): +class ReturningDelete(Update, TypedReturnsRows[Unpack[_Ts]]): """Typing-only class that establishes a generic type form of :class:`.Delete` which tracks returned column types. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fc23e9d215..4b680033d7 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -59,6 +59,7 @@ from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util +from ..util.typing import Unpack if TYPE_CHECKING: @@ -647,7 +648,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): joins_implicitly=joins_implicitly, ) - def select(self) -> Select[Any]: + def select(self) -> Select[Unpack[Tuple[Any, ...]]]: """Produce a :func:`_expression.select` construct against this :class:`.FunctionElement`. @@ -656,7 +657,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): s = select(function_element) """ - s: Select[Any] = Select(self) + s: Select[Unpack[Tuple[Any, ...]]] = Select(self) if self._execution_options: s = s.execution_options(**self._execution_options) return s diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 91b939e0af..b21132697b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -47,7 +47,6 @@ from . import type_api from . import visitors from ._typing import _ColumnsClauseArgument from ._typing import _no_kw -from ._typing import _TP from ._typing import is_column_element from ._typing import is_select_statement from ._typing import is_subquery @@ -100,10 +99,15 @@ from ..util import HasMemoized_ro_memoized_attribute from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack + and_ = BooleanClauseList.and_ _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") + if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -283,7 +287,7 @@ class ExecutableReturnsRows(Executable, ReturnsRows): """base for executable statements that return rows.""" -class TypedReturnsRows(ExecutableReturnsRows, Generic[_TP]): +class TypedReturnsRows(ExecutableReturnsRows, Generic[Unpack[_Ts]]): """base for executable statements that return rows.""" @@ -610,7 +614,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): _use_schema_map = False - def select(self) -> Select[Any]: + def select(self) -> Select[Unpack[Tuple[Any, ...]]]: r"""Return a SELECT of this :class:`_expression.FromClause`. @@ -1496,7 +1500,7 @@ class Join(roles.DMLTableRole, FromClause): "join explicitly." % (a.description, b.description) ) - def select(self) -> Select[Any]: + def select(self) -> Select[Unpack[Tuple[Any, ...]]]: r"""Create a :class:`_expression.Select` from this :class:`_expression.Join`. @@ -2052,7 +2056,7 @@ class CTE( def _init( self, - selectable: Select[Any], + selectable: Select[Unpack[Tuple[Any, ...]]], *, name: Optional[str] = None, recursive: bool = False, @@ -3477,7 +3481,7 @@ class SelectBase( "first in order to create " "a subquery, which then can be selected.", ) - def select(self, *arg: Any, **kw: Any) -> Select[Any]: + def select(self, *arg: Any, **kw: Any) -> Select[Unpack[Tuple[Any, ...]]]: return self._implicit_subquery.select(*arg, **kw) @HasMemoized.memoized_attribute @@ -4490,7 +4494,7 @@ class SelectState(util.MemoizedSlots, CompileState): def __init__( self, - statement: Select[Any], + statement: Select[Unpack[Tuple[Any, ...]]], compiler: Optional[SQLCompiler], **kw: Any, ): @@ -4518,7 +4522,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def get_column_descriptions( - cls, statement: Select[Any] + cls, statement: Select[Unpack[Tuple[Any, ...]]] ) -> List[Dict[str, Any]]: return [ { @@ -4533,13 +4537,15 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def from_statement( - cls, statement: Select[Any], from_statement: roles.ReturnsRowsRole + cls, + statement: Select[Unpack[Tuple[Any, ...]]], + from_statement: roles.ReturnsRowsRole, ) -> ExecutableReturnsRows: cls._plugin_not_implemented() @classmethod def get_columns_clause_froms( - cls, statement: Select[Any] + cls, statement: Select[Unpack[Tuple[Any, ...]]] ) -> List[FromClause]: return cls._normalize_froms( itertools.chain.from_iterable( @@ -4594,7 +4600,9 @@ class SelectState(util.MemoizedSlots, CompileState): return go - def _get_froms(self, statement: Select[Any]) -> List[FromClause]: + def _get_froms( + self, statement: Select[Unpack[Tuple[Any, ...]]] + ) -> List[FromClause]: ambiguous_table_name_map: _AmbiguousTableNameMap self._ambiguous_table_name_map = ambiguous_table_name_map = {} @@ -4622,7 +4630,7 @@ class SelectState(util.MemoizedSlots, CompileState): def _normalize_froms( cls, iterable_of_froms: Iterable[FromClause], - check_statement: Optional[Select[Any]] = None, + check_statement: Optional[Select[Unpack[Tuple[Any, ...]]]] = None, ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, ) -> List[FromClause]: """given an iterable of things to select FROM, reduce them to what @@ -4767,7 +4775,7 @@ class SelectState(util.MemoizedSlots, CompileState): @classmethod def determine_last_joined_entity( - cls, stmt: Select[Any] + cls, stmt: Select[Unpack[Tuple[Any, ...]]] ) -> Optional[_JoinTargetElement]: if stmt._setup_joins: return stmt._setup_joins[-1][0] @@ -4775,7 +4783,9 @@ class SelectState(util.MemoizedSlots, CompileState): return None @classmethod - def all_selected_columns(cls, statement: Select[Any]) -> _SelectIterable: + def all_selected_columns( + cls, statement: Select[Unpack[Tuple[Any, ...]]] + ) -> _SelectIterable: return [c for c in _select_iterables(statement._raw_columns)] def _setup_joins( @@ -5021,7 +5031,9 @@ class _MemoizedSelectEntities( return c @classmethod - def _generate_for_statement(cls, select_stmt: Select[Any]) -> None: + def _generate_for_statement( + cls, select_stmt: Select[Unpack[Tuple[Any, ...]]] + ) -> None: if select_stmt._setup_joins or select_stmt._with_options: self = _MemoizedSelectEntities() self._raw_columns = select_stmt._raw_columns @@ -5040,7 +5052,7 @@ class Select( HasCompileState, _SelectFromElements, GenerativeSelect, - TypedReturnsRows[_TP], + TypedReturnsRows[Unpack[_Ts]], ): """Represents a ``SELECT`` statement. @@ -5114,7 +5126,7 @@ class Select( _compile_state_factory: Type[SelectState] @classmethod - def _create_raw_select(cls, **kw: Any) -> Select[Any]: + def _create_raw_select(cls, **kw: Any) -> Select[Unpack[Tuple[Any, ...]]]: """Create a :class:`.Select` using raw ``__new__`` with no coercions. Used internally to build up :class:`.Select` constructs with @@ -5176,13 +5188,13 @@ class Select( @overload def scalar_subquery( - self: Select[Tuple[_MAYBE_ENTITY]], + self: Select[_MAYBE_ENTITY], ) -> ScalarSelect[Any]: ... @overload def scalar_subquery( - self: Select[Tuple[_NOT_ENTITY]], + self: Select[_NOT_ENTITY], ) -> ScalarSelect[_NOT_ENTITY]: ... @@ -5664,7 +5676,7 @@ class Select( @_generative def add_columns( self, *entities: _ColumnsClauseArgument[Any] - ) -> Select[Any]: + ) -> Select[Unpack[Tuple[Any, ...]]]: r"""Return a new :func:`_expression.select` construct with the given entities appended to its columns clause. @@ -5714,7 +5726,9 @@ class Select( "be removed in a future release. Please use " ":meth:`_expression.Select.add_columns`", ) - def column(self, column: _ColumnsClauseArgument[Any]) -> Select[Any]: + def column( + self, column: _ColumnsClauseArgument[Any] + ) -> Select[Unpack[Tuple[Any, ...]]]: """Return a new :func:`_expression.select` construct with the given column expression added to its columns clause. @@ -5731,7 +5745,9 @@ class Select( return self.add_columns(column) @util.preload_module("sqlalchemy.sql.util") - def reduce_columns(self, only_synonyms: bool = True) -> Select[Any]: + def reduce_columns( + self, only_synonyms: bool = True + ) -> Select[Unpack[Tuple[Any, ...]]]: """Return a new :func:`_expression.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -5754,7 +5770,7 @@ class Select( all columns that are equivalent to another are removed. """ - woc: Select[Any] + woc: Select[Unpack[Tuple[Any, ...]]] woc = self.with_only_columns( *util.preloaded.sql_util.reduce_columns( self._all_selected_columns, @@ -5853,7 +5869,7 @@ class Select( *entities: _ColumnsClauseArgument[Any], maintain_column_froms: bool = False, **__kw: Any, - ) -> Select[Any]: + ) -> Select[Unpack[Tuple[Any, ...]]]: ... @_generative @@ -5862,7 +5878,7 @@ class Select( *entities: _ColumnsClauseArgument[Any], maintain_column_froms: bool = False, **__kw: Any, - ) -> Select[Any]: + ) -> Select[Unpack[Tuple[Any, ...]]]: r"""Return a new :func:`_expression.select` construct with its columns clause replaced with the given entities. @@ -6255,7 +6271,7 @@ class Select( meth = SelectState.get_plugin_class(self).all_selected_columns return list(meth(self)) - def _ensure_disambiguated_names(self) -> Select[Any]: + def _ensure_disambiguated_names(self) -> Select[Unpack[Tuple[Any, ...]]]: if self._label_style is LABEL_STYLE_NONE: self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY) return self @@ -6515,7 +6531,9 @@ class ScalarSelect( by this :class:`_expression.ScalarSelect`. """ - self.element = cast("Select[Any]", self.element).where(crit) + self.element = cast( + "Select[Unpack[Tuple[Any, ...]]]", self.element + ).where(crit) return self @overload @@ -6537,7 +6555,7 @@ class ScalarSelect( if TYPE_CHECKING: - def _ungroup(self) -> Select[Any]: + def _ungroup(self) -> Select[Unpack[Tuple[Any, ...]]]: ... @_generative @@ -6571,9 +6589,9 @@ class ScalarSelect( """ - self.element = cast("Select[Any]", self.element).correlate( - *fromclauses - ) + self.element = cast( + "Select[Unpack[Tuple[Any, ...]]]", self.element + ).correlate(*fromclauses) return self @_generative @@ -6609,9 +6627,9 @@ class ScalarSelect( """ - self.element = cast("Select[Any]", self.element).correlate_except( - *fromclauses - ) + self.element = cast( + "Select[Unpack[Tuple[Any, ...]]]", self.element + ).correlate_except(*fromclauses) return self @@ -6626,7 +6644,10 @@ class Exists(UnaryExpression[bool]): """ inherit_cache = True - element: Union[SelectStatementGrouping[Select[Any]], ScalarSelect[Any]] + element: Union[ + SelectStatementGrouping[Select[Unpack[Tuple[Any, ...]]]], + ScalarSelect[Any], + ] def __init__( self, @@ -6660,8 +6681,11 @@ class Exists(UnaryExpression[bool]): return [] def _regroup( - self, fn: Callable[[Select[Any]], Select[Any]] - ) -> SelectStatementGrouping[Select[Any]]: + self, + fn: Callable[ + [Select[Unpack[Tuple[Any, ...]]]], Select[Unpack[Tuple[Any, ...]]] + ], + ) -> SelectStatementGrouping[Select[Unpack[Tuple[Any, ...]]]]: element = self.element._ungroup() new_element = fn(element) @@ -6669,7 +6693,7 @@ class Exists(UnaryExpression[bool]): assert isinstance(return_value, SelectStatementGrouping) return return_value - def select(self) -> Select[Any]: + def select(self) -> Select[Unpack[Tuple[Any, ...]]]: r"""Return a SELECT of this :class:`_expression.Exists`. e.g.:: diff --git a/test/typing/plain_files/orm/typed_queries.py b/test/typing/plain_files/orm/typed_queries.py index 7d8a2dd1a3..e92df3eb47 100644 --- a/test/typing/plain_files/orm/typed_queries.py +++ b/test/typing/plain_files/orm/typed_queries.py @@ -58,7 +58,7 @@ def t_select_1() -> None: result = session.execute(stmt) - # EXPECTED_TYPE: Result[Tuple[int, str]] + # EXPECTED_TYPE: Result[int, str] reveal_type(result) @@ -82,7 +82,7 @@ def t_select_2() -> None: result = session.execute(stmt) - # EXPECTED_TYPE: Result[Tuple[User]] + # EXPECTED_TYPE: Result[User] reveal_type(result) @@ -107,7 +107,7 @@ def t_select_3() -> None: result = session.execute(stmt) - # EXPECTED_TYPE: Result[Tuple[int, str]] + # EXPECTED_TYPE: Result[int, str] reveal_type(result) @@ -120,7 +120,7 @@ def t_select_4() -> None: result = session.execute(stmt) - # EXPECTED_TYPE: Result[Tuple[User, User]] + # EXPECTED_TYPE: Result[User, User] reveal_type(result) diff --git a/tools/generate_tuple_map_overloads.py b/tools/generate_tuple_map_overloads.py index 476636b1d0..c0754c67e4 100644 --- a/tools/generate_tuple_map_overloads.py +++ b/tools/generate_tuple_map_overloads.py @@ -95,7 +95,7 @@ def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: @overload def {current_fnname}( {'self, ' if use_self else ''}{", ".join(combination)}{extra_args} -) -> {return_type}[Tuple[{', '.join(f'_T{i}' for i in range(num_args))}]]: +) -> {return_type}[{', '.join(f'_T{i}' for i in range(num_args))}]: ... """, # noqa: E501