From: Yurii Karabas <1998uriyyo@gmail.com> Date: Tue, 14 Nov 2023 16:55:24 +0000 (+0200) Subject: update typing X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=06ca301ef129640d3ebf67da1c04a50e902cf979;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git update typing --- diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 0000e28103..47597054c9 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -43,6 +43,8 @@ from .. import log from .. import util from ..sql import compiler from ..sql import util as sql_util +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack if typing.TYPE_CHECKING: from . import CursorResult @@ -80,6 +82,7 @@ if typing.TYPE_CHECKING: _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") _EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT @@ -1352,11 +1355,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @overload def execute( self, - statement: TypedReturnsRows[_T], + statement: TypedReturnsRows[Tuple[Unpack[_Ts]]], parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[_T]: + ) -> CursorResult[Unpack[_Ts]]: ... @overload @@ -1366,7 +1369,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: ... def execute( @@ -1375,7 +1378,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. @@ -1424,7 +1427,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): func: FunctionElement[Any], distilled_parameters: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: """Execute a sql.FunctionElement object.""" return self._execute_clauseelement( @@ -1495,7 +1498,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ddl: ExecutableDDLElement, distilled_parameters: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: """Execute a schema.DDL object.""" execution_options = ddl._execution_options.merge_with( @@ -1591,7 +1594,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): elem: Executable, distilled_parameters: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: """Execute a sql.ClauseElement object.""" execution_options = elem._execution_options.merge_with( @@ -1664,7 +1667,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): compiled: Compiled, distilled_parameters: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: """Execute a sql.Compiled object. TODO: why do we have this? likely deprecate or remove @@ -1714,7 +1717,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): statement: str, parameters: Optional[_DBAPIAnyExecuteParams] = None, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: r"""Executes a string SQL statement on the DBAPI cursor directly, without any SQL compilation steps. @@ -1796,7 +1799,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): execution_options: _ExecuteOptions, *args: Any, **kw: Any, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: """Create an :class:`.ExecutionContext` and execute, returning a :class:`_engine.CursorResult`.""" @@ -1855,7 +1858,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): context: ExecutionContext, statement: Union[str, Compiled], parameters: Optional[_AnyMultiExecuteParams], - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: """continue the _execute_context() method for a single DBAPI cursor.execute() or cursor.executemany() call. @@ -1995,7 +1998,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): self, dialect: Dialect, context: ExecutionContext, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: """continue the _execute_context() method for an "insertmanyvalues" operation, which will invoke DBAPI cursor.execute() one or more times with individual log and diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 5e536323ee..579528d966 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -417,7 +417,7 @@ _NO_ROW = _NoRow._NO_ROW class ResultInternal(InPlaceGenerative, Generic[_R]): __slots__ = () - _real_result: Optional[Result[Any]] = None + _real_result: Optional[Result[Unpack[Tuple[Any, ...]]]] = None _generate_rows: bool = True _row_logging_fn: Optional[Callable[[Any], Any]] @@ -450,10 +450,10 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): @HasMemoized_ro_memoized_attribute def _row_getter(self) -> Optional[Callable[..., _R]]: - real_result: Result[Any] = ( + real_result: Result[Unpack[Tuple[Any, ...]]] = ( self._real_result if self._real_result - else cast("Result[Any]", self) + else cast("Result[Unpack[Tuple[Any, ...]]]", self) ) if real_result._source_supports_scalars: @@ -517,7 +517,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def iterrows(self: Result[Any]) -> Iterator[_R]: + def iterrows( + self: Result[Unpack[Tuple[Any, ...]]] + ) -> Iterator[_R]: for raw_row in self._fetchiter_impl(): obj: _InterimRowType[Any] = ( make_row(raw_row) if make_row else raw_row @@ -532,7 +534,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): else: - def iterrows(self: Result[Any]) -> Iterator[_R]: + def iterrows( + self: Result[Unpack[Tuple[Any, ...]]] + ) -> Iterator[_R]: for raw_row in self._fetchiter_impl(): row: _InterimRowType[Any] = ( make_row(raw_row) if make_row else raw_row @@ -597,7 +601,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): if self._unique_filter_state: uniques, strategy = self._unique_strategy - def onerow(self: Result[Any]) -> Union[_NoRow, _R]: + def onerow( + self: Result[Unpack[Tuple[Any, ...]]] + ) -> Union[_NoRow, _R]: _onerow = self._fetchone_impl while True: row = _onerow() @@ -618,7 +624,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): else: - def onerow(self: Result[Any]) -> Union[_NoRow, _R]: + def onerow( + self: Result[Unpack[Tuple[Any, ...]]] + ) -> Union[_NoRow, _R]: row = self._fetchone_impl() if row is None: return _NO_ROW @@ -678,7 +686,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result - else cast("Result[Any]", self) + else cast("Result[Unpack[Tuple[Any, ...]]]", self) ) if real_result._yield_per: num_required = num = real_result._yield_per @@ -718,7 +726,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result - else cast("Result[Any]", self) + else cast("Result[Unpack[Tuple[Any, ...]]]", self) ) num = real_result._yield_per @@ -850,7 +858,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result - else cast("Result[Any]", self) + else cast("Result[Unpack[Tuple[Any, ...]]]", self) ) if not real_result._source_supports_scalars or len(indexes) != 1: @@ -868,7 +876,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): real_result = ( self._real_result if self._real_result is not None - else cast("Result[Any]", self) + else cast("Result[Unpack[Tuple[Any, ...]]]", self) ) if not strategy and self._metadata._unique_filters: @@ -1629,7 +1637,7 @@ class FilterResult(ResultInternal[_R]): _post_creational_filter: Optional[Callable[[Any], Any]] - _real_result: Result[Any] + _real_result: Result[Unpack[Tuple[Any, ...]]] def __enter__(self) -> Self: return self diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index a13e106ff3..b854e4b94f 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -31,6 +31,8 @@ from ...sql.base import _generative from ...util.concurrency import greenlet_spawn from ...util.typing import Literal from ...util.typing import Self +from ...util.typing import TypeVarTuple +from ...util.typing import Unpack if TYPE_CHECKING: from ...engine import CursorResult @@ -39,12 +41,13 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) _TP = TypeVar("_TP", bound=Tuple[Any, ...]) +_Ts = TypeVarTuple("_Ts") class AsyncCommon(FilterResult[_R]): __slots__ = () - _real_result: Result[Any] + _real_result: Result[Unpack[Tuple[Any, ...]]] _metadata: ResultMetaData async def close(self) -> None: # type: ignore[override] @@ -63,7 +66,7 @@ class AsyncCommon(FilterResult[_R]): return self._real_result.closed -class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): +class AsyncResult(_WithKeys, AsyncCommon[Row[Unpack[_Ts]]]): """An asyncio wrapper around a :class:`_result.Result` object. The :class:`_asyncio.AsyncResult` only applies to statement executions that @@ -86,9 +89,9 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): __slots__ = () - _real_result: Result[_TP] + _real_result: Result[Unpack[_Ts]] - def __init__(self, real_result: Result[_TP]): + def __init__(self, real_result: Result[Unpack[_Ts]]): self._real_result = real_result self._metadata = real_result._metadata @@ -103,7 +106,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): ) @property - def t(self) -> AsyncTupleResult[_TP]: + def t(self) -> AsyncTupleResult[Tuple[Unpack[_Ts]]]: """Apply a "typed tuple" typing filter to returned rows. The :attr:`_asyncio.AsyncResult.t` attribute is a synonym for @@ -114,7 +117,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): """ return self # type: ignore - def tuples(self) -> AsyncTupleResult[_TP]: + def tuples(self) -> AsyncTupleResult[Tuple[Unpack[_Ts]]]: """Apply a "typed tuple" typing filter to returned rows. This method returns the same :class:`_asyncio.AsyncResult` object @@ -163,7 +166,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): async def partitions( self, size: Optional[int] = None - ) -> AsyncIterator[Sequence[Row[_TP]]]: + ) -> AsyncIterator[Sequence[Row[Unpack[_Ts]]]]: """Iterate through sub-lists of rows of the size given. An async iterator is returned:: @@ -188,7 +191,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): else: break - async def fetchall(self) -> Sequence[Row[_TP]]: + async def fetchall(self) -> Sequence[Row[Unpack[_Ts]]]: """A synonym for the :meth:`_asyncio.AsyncResult.all` method. .. versionadded:: 2.0 @@ -197,7 +200,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): return await greenlet_spawn(self._allrows) - async def fetchone(self) -> Optional[Row[_TP]]: + async def fetchone(self) -> Optional[Row[Unpack[_Ts]]]: """Fetch one row. When all rows are exhausted, returns None. @@ -221,7 +224,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): async def fetchmany( self, size: Optional[int] = None - ) -> Sequence[Row[_TP]]: + ) -> Sequence[Row[Unpack[_Ts]]]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -242,7 +245,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self) -> Sequence[Row[_TP]]: + async def all(self) -> Sequence[Row[Unpack[_Ts]]]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -254,17 +257,17 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): return await greenlet_spawn(self._allrows) - def __aiter__(self) -> AsyncResult[_TP]: + def __aiter__(self) -> AsyncResult[Unpack[_Ts]]: return self - async def __anext__(self) -> Row[_TP]: + async def __anext__(self) -> Row[Unpack[_Ts]]: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self) -> Optional[Row[_TP]]: + async def first(self) -> Optional[Row[Unpack[_Ts]]]: """Fetch the first row or ``None`` if no row is present. Closes the result set and discards remaining rows. @@ -300,7 +303,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self) -> Optional[Row[_TP]]: + async def one_or_none(self) -> Optional[Row[Unpack[_Ts]]]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -371,7 +374,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): """ return await greenlet_spawn(self._only_one_row, True, False, True) - async def one(self) -> Row[_TP]: + async def one(self) -> Row[Unpack[_Ts]]: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -426,7 +429,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): """ return await greenlet_spawn(self._only_one_row, False, False, True) - async def freeze(self) -> FrozenResult[_TP]: + async def freeze(self) -> FrozenResult[Tuple[Unpack[_Ts]]]: """Return a callable object that will produce copies of this :class:`_asyncio.AsyncResult` when invoked. @@ -513,7 +516,11 @@ class AsyncScalarResult(AsyncCommon[_R]): _generate_rows = False - def __init__(self, real_result: Result[Any], index: _KeyIndexType): + def __init__( + self, + real_result: Result[Unpack[Tuple[Any, ...]]], + index: _KeyIndexType, + ): self._real_result = real_result if real_result._source_supports_scalars: @@ -644,7 +651,7 @@ class AsyncMappingResult(_WithKeys, AsyncCommon[RowMapping]): _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result: Result[Any]): + def __init__(self, result: Result[Unpack[Tuple[Any, ...]]]): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata @@ -944,7 +951,7 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): ... -_RT = TypeVar("_RT", bound="Result[Any]") +_RT = TypeVar("_RT", bound="Result[Unpack[Tuple[Any, ...]]]") async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 4c68f53ffa..3e1f7ecb33 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -31,6 +31,7 @@ from ...util import create_proxy_methods from ...util import ScopedRegistry from ...util import warn from ...util import warn_deprecated +from ...util.typing import Unpack if TYPE_CHECKING: from .engine import AsyncConnection @@ -562,7 +563,7 @@ class async_scoped_session(Generic[_AS]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: ... async def execute( @@ -573,7 +574,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: r"""Execute a statement and return a buffered :class:`_engine.Result` object. diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 30232e59cb..36a73ef6aa 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -38,6 +38,7 @@ from ...orm import Session from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +from ...util.typing import Unpack if TYPE_CHECKING: from .engine import AsyncConnection @@ -424,7 +425,7 @@ class AsyncSession(ReversibleProxy[Session]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: ... async def execute( @@ -435,7 +436,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: """Execute a statement and return a buffered :class:`_engine.Result` object. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 31caedc378..c716a25683 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -20,6 +20,7 @@ from typing import Dict from typing import Iterable from typing import Optional from typing import overload +from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -53,6 +54,7 @@ from ..sql.dml import InsertDMLState from ..sql.dml import UpdateDMLState from ..util import EMPTY_DICT from ..util.typing import Literal +from ..util.typing import Unpack if TYPE_CHECKING: from ._typing import DMLStrategyArgument @@ -249,7 +251,7 @@ def _bulk_update( update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = ..., enable_check_rowcount: bool = True, -) -> _result.Result[Any]: +) -> _result.Result[Unpack[Tuple[Any, ...]]]: ... @@ -261,7 +263,7 @@ def _bulk_update( update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = None, enable_check_rowcount: bool = True, -) -> Optional[_result.Result[Any]]: +) -> Optional[_result.Result[Unpack[Tuple[Any, ...]]]]: base_mapper = mapper.base_mapper search_keys = mapper._primary_key_propkeys @@ -1236,7 +1238,7 @@ class BulkORMInsert(ORMDMLState, InsertDMLState): "are 'raw', 'orm', 'bulk', 'auto" ) - result: _result.Result[Any] + result: _result.Result[Unpack[Tuple[Any, ...]]] if insert_options._dml_strategy == "raw": result = conn.execute( @@ -1572,7 +1574,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): "are 'orm', 'auto', 'bulk', 'core_only'" ) - result: _result.Result[Any] + result: _result.Result[Unpack[Tuple[Any, ...]]] if update_options._dml_strategy == "bulk": enable_check_rowcount = not statement._where_criteria diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index fed07334fb..247ca03874 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -72,6 +72,8 @@ from ..sql.type_api import TypeEngine from ..util import warn_deprecated from ..util.typing import RODescriptorReference from ..util.typing import TypedDict +from ..util.typing import Unpack + if typing.TYPE_CHECKING: from ._typing import _EntityType @@ -486,7 +488,7 @@ class MapperProperty( query_entity: _MapperEntity, path: AbstractEntityRegistry, mapper: Mapper[Any], - result: Result[Any], + result: Result[Unpack[Tuple[Any, ...]]], adapter: Optional[ORMAdapter], populators: _PopulatorDict, ) -> None: @@ -1056,7 +1058,7 @@ class StrategizedProperty(MapperProperty[_T]): query_entity: _MapperEntity, path: AbstractEntityRegistry, mapper: Mapper[Any], - result: Result[Any], + result: Result[Unpack[Tuple[Any, ...]]], adapter: Optional[ORMAdapter], populators: _PopulatorDict, ) -> None: @@ -1447,7 +1449,7 @@ class LoaderStrategy: path: AbstractEntityRegistry, loadopt: Optional[_LoadElement], mapper: Mapper[Any], - result: Result[Any], + result: Result[Unpack[Tuple[Any, ...]]], adapter: Optional[ORMAdapter], populators: _PopulatorDict, ) -> None: diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index cae6f0be21..8cab4a0dfc 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -53,6 +53,8 @@ from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import SelectState from ..util import EMPTY_DICT +from ..util.typing import Unpack + if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -75,7 +77,9 @@ _new_runid = util.counter() _PopulatorDict = Dict[str, List[Tuple[str, Any]]] -def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]: +def instances( + cursor: CursorResult[Unpack[Tuple[Any, ...]]], context: QueryContext +) -> Result[Unpack[Tuple[Any, ...]]]: """Return a :class:`.Result` given an ORM query context. :param cursor: a :class:`.CursorResult`, generated by a statement diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index ab632bdd56..65ab1abef7 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -32,6 +32,7 @@ from ..util import ThreadLocalRegistry from ..util import warn from ..util import warn_deprecated from ..util.typing import Protocol +from ..util.typing import Unpack if TYPE_CHECKING: from ._typing import _EntityType @@ -708,7 +709,7 @@ class scoped_session(Generic[_S]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: ... def execute( @@ -720,7 +721,7 @@ class scoped_session(Generic[_S]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: r"""Execute a SQL expression construct. .. container:: class_bases diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d861981271..45925a6d9f 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -91,6 +91,9 @@ from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import IdentitySet from ..util.typing import Literal from ..util.typing import Protocol +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack + if typing.TYPE_CHECKING: from ._typing import _EntityType @@ -134,6 +137,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import TypedReturnsRows _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") __all__ = [ "Session", @@ -385,7 +389,7 @@ class ORMExecuteState(util.MemoizedSlots): params: Optional[_CoreAnyExecuteParams] = None, execution_options: Optional[OrmExecuteOptionsParameter] = None, bind_arguments: Optional[_BindArguments] = None, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: """Execute the statement represented by this :class:`.ORMExecuteState`, without re-invoking events that have already proceeded. @@ -2071,7 +2075,7 @@ class Session(_SessionClassMethods, EventTarget): _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: bool = ..., - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: ... def _execute_internal( @@ -2147,7 +2151,9 @@ class Session(_SessionClassMethods, EventTarget): ) for idx, fn in enumerate(events_todo): orm_exec_state._starting_event_idx = idx - fn_result: Optional[Result[Any]] = fn(orm_exec_state) + fn_result: Optional[Result[Unpack[Tuple[Any, ...]]]] = fn( + orm_exec_state + ) if fn_result: if _scalar_result: return fn_result.scalar() @@ -2187,7 +2193,9 @@ class Session(_SessionClassMethods, EventTarget): ) if compile_state_cls: - result: Result[Any] = compile_state_cls.orm_execute_statement( + result: Result[ + Unpack[Tuple[Any, ...]] + ] = compile_state_cls.orm_execute_statement( self, statement, params or {}, @@ -2208,14 +2216,14 @@ class Session(_SessionClassMethods, EventTarget): @overload def execute( self, - statement: TypedReturnsRows[_T], + statement: TypedReturnsRows[Tuple[Unpack[_Ts]]], params: Optional[_CoreAnyExecuteParams] = None, *, execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: + ) -> Result[Unpack[_Ts]]: ... @overload @@ -2228,7 +2236,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: + ) -> CursorResult[Unpack[Tuple[Any, ...]]]: ... @overload @@ -2241,7 +2249,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: ... def execute( @@ -2253,7 +2261,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: + ) -> Result[Unpack[Tuple[Any, ...]]]: r"""Execute a SQL expression construct. Returns a :class:`_engine.Result` object representing diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 48dfd25829..76be953a42 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -78,6 +78,7 @@ from ..util import HasMemoized_ro_memoized_attribute from ..util import TypingOnly from ..util.typing import Literal from ..util.typing import Self +from ..util.typing import Unpack if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -509,7 +510,7 @@ class ClauseElement( connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> Result[Any]: + ) -> Result[Unpack[typing_Tuple[Any, ...]]]: if self.supports_execution: if TYPE_CHECKING: assert isinstance(self, Executable)