From: Yurii Karabas <1998uriyyo@gmail.com> Date: Tue, 14 Nov 2023 15:30:38 +0000 (+0200) Subject: Start working on a TypeVarTuple support X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9424d8f26f31218d00893e16b31db2e19a26707b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Start working on a TypeVarTuple support --- diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index ff6e311a74..b3828361f7 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -53,6 +53,8 @@ from ..sql.type_api import TypeEngine from ..util import compat from ..util.typing import Literal from ..util.typing import Self +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack if typing.TYPE_CHECKING: @@ -71,7 +73,7 @@ if typing.TYPE_CHECKING: from ..sql.type_api import _ResultProcessorType -_T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") # metadata entry tuple indexes. @@ -1375,7 +1377,7 @@ def null_dml_result() -> IteratorResult[Any]: return it -class CursorResult(Result[_T]): +class CursorResult(Result[Unpack[_Ts]]): """A Result that is representing state from a DBAPI cursor. .. versionchanged:: 1.4 The :class:`.CursorResult`` @@ -2108,7 +2110,7 @@ class CursorResult(Result[_T]): def _raw_row_iterator(self): return self._fetchiter_impl() - def merge(self, *others: Result[Any]) -> MergedResult[Any]: + def merge(self, *others: Result[Unpack[Tuple[Any, ...]]]) -> MergedResult[Unpack[Tuple[Any, ...]]]: merged_result = super().merge(*others) setup_rowcounts = self.context._has_rowcount if setup_rowcounts: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 553d8f0bea..bd55dd95c3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -66,6 +66,8 @@ from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name from ..util.typing import Final from ..util.typing import Literal +from ..util.typing import Unpack + if typing.TYPE_CHECKING: from types import ModuleType @@ -1181,7 +1183,7 @@ class DefaultExecutionContext(ExecutionContext): result_column_struct: Optional[ Tuple[List[ResultColumnsEntry], bool, bool, bool, bool] ] = None - returned_default_rows: Optional[Sequence[Row[Any]]] = None + returned_default_rows: Optional[Sequence[Row[Unpack[Tuple[Any, ...]]]]] = None execution_options: _ExecuteOptions = util.EMPTY_DICT diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index aac756d18a..7e57e92a7f 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -25,6 +25,7 @@ from .interfaces import Dialect from .. import event from .. import exc from ..util.typing import Literal +from ..util.typing import Unpack if typing.TYPE_CHECKING: from .interfaces import _CoreMultiExecuteParams @@ -270,7 +271,7 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]): multiparams: _CoreMultiExecuteParams, params: _CoreSingleExecuteParams, execution_options: _ExecuteOptions, - result: Result[Any], + result: Result[Unpack[Tuple[Any, ...]]], ) -> None: """Intercept high level execute() events after execute. diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index acbe6f0923..07689564a4 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -45,6 +45,8 @@ from ..util import NONE_SET from ..util._has_cy import HAS_CYEXTENSION from ..util.typing import Literal from ..util.typing import Self +from ..util.typing import Unpack +from ..util.typing import TypeVarTuple if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_row import tuplegetter as tuplegetter @@ -75,6 +77,7 @@ _RawRowType = Tuple[Any, ...] _R = TypeVar("_R", bound=_RowData) _T = TypeVar("_T", bound=Any) _TP = TypeVar("_TP", bound=Tuple[Any, ...]) +_Ts = TypeVarTuple("_Ts") _InterimRowType = Union[_R, _RawRowType] """a catchall "anything" kind of return type that can be applied @@ -394,7 +397,7 @@ class SimpleResultMetaData(ResultMetaData): def result_tuple( fields: Sequence[str], extra: Optional[Any] = None -) -> Callable[[Iterable[Any]], Row[Any]]: +) -> Callable[[Iterable[Any]], Row[Unpack[_RawRowType]]]: parent = SimpleResultMetaData(fields, extra) return functools.partial( Row, parent, parent._effective_processors, parent._key_to_index @@ -909,7 +912,7 @@ class _WithKeys: return self._metadata.keys -class Result(_WithKeys, ResultInternal[Row[_TP]]): +class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): """Represent a set of database results. .. versionadded:: 1.4 The :class:`_engine.Result` object provides a @@ -1132,12 +1135,12 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): return self._column_slices(col_expressions) @overload - def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: + def scalars(self: Result[_T, Unpack[_RawRowType]]) -> ScalarResult[_T]: ... @overload def scalars( - self: Result[Tuple[_T]], index: Literal[0] + self: Result[_T, Unpack[_RawRowType]], index: Literal[0] ) -> ScalarResult[_T]: ... @@ -1212,7 +1215,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): return MappingResult(self) @property - def t(self) -> TupleResult[_TP]: + def t(self) -> TupleResult[Tuple[Unpack[_Ts]]]: """Apply a "typed tuple" typing filter to returned rows. The :attr:`_engine.Result.t` attribute is a synonym for @@ -1223,7 +1226,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): """ return self # type: ignore - def tuples(self) -> TupleResult[_TP]: + def tuples(self) -> TupleResult[Tuple[Unpack[_Ts]]]: """Apply a "typed tuple" typing filter to returned rows. This method returns the same :class:`_engine.Result` object @@ -1258,15 +1261,15 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): """ raise NotImplementedError() - def __iter__(self) -> Iterator[Row[_TP]]: + def __iter__(self) -> Iterator[Row[Unpack[_Ts]]]: return self._iter_impl() - def __next__(self) -> Row[_TP]: + def __next__(self) -> Row[Unpack[_Ts]]: return self._next_impl() def partitions( self, size: Optional[int] = None - ) -> Iterator[Sequence[Row[_TP]]]: + ) -> Iterator[Sequence[Row[Unpack[_Ts]]]]: """Iterate through sub-lists of rows of the size given. Each list will be of the size given, excluding the last list to @@ -1322,12 +1325,12 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): else: break - def fetchall(self) -> Sequence[Row[_TP]]: + def fetchall(self) -> Sequence[Row[Unpack[_Ts]]]: """A synonym for the :meth:`_engine.Result.all` method.""" return self._allrows() - def fetchone(self) -> Optional[Row[_TP]]: + def fetchone(self) -> Optional[Row[Unpack[_Ts]]]: """Fetch one row. When all rows are exhausted, returns None. @@ -1349,7 +1352,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): else: return row - def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[Unpack[_Ts]]]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -1370,7 +1373,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): return self._manyrow_getter(self, size) - def all(self) -> Sequence[Row[_TP]]: + def all(self) -> Sequence[Row[Unpack[_Ts]]]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -1389,7 +1392,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): return self._allrows() - def first(self) -> Optional[Row[_TP]]: + def first(self) -> Optional[Row[Unpack[_Ts]]]: """Fetch the first row or ``None`` if no row is present. Closes the result set and discards remaining rows. @@ -1428,7 +1431,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): raise_for_second_row=False, raise_for_none=False, scalar=False ) - def one_or_none(self) -> Optional[Row[_TP]]: + def one_or_none(self) -> Optional[Row[Unpack[_Ts]]]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -1503,7 +1506,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): raise_for_second_row=True, raise_for_none=False, scalar=True ) - def one(self) -> Row[_TP]: + def one(self) -> Row[Unpack[_Ts]]: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -1562,7 +1565,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): raise_for_second_row=False, raise_for_none=False, scalar=True ) - def freeze(self) -> FrozenResult[_TP]: + def freeze(self) -> FrozenResult[Unpack[_Ts]]: """Return a callable object that will produce copies of this :class:`_engine.Result` when invoked. @@ -1585,7 +1588,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): return FrozenResult(self) - def merge(self, *others: Result[Any]) -> MergedResult[_TP]: + def merge(self, *others: Result[Any]) -> MergedResult[Unpack[_Ts]]: """Merge this :class:`_engine.Result` with other compatible result objects. @@ -1720,7 +1723,7 @@ class ScalarResult(FilterResult[_R]): _post_creational_filter: Optional[Callable[[Any], Any]] - def __init__(self, real_result: Result[Any], index: _KeyIndexType): + def __init__(self, real_result: Result[Unpack[_RawRowType]], index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -2013,7 +2016,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result: Result[Any]): + def __init__(self, result: Result[Unpack[_RawRowType]]): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata @@ -2140,7 +2143,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): ) -class FrozenResult(Generic[_TP]): +class FrozenResult(Generic[Unpack[_Ts]]): """Represents a :class:`_engine.Result` object in a "frozen" state suitable for caching. @@ -2181,7 +2184,7 @@ class FrozenResult(Generic[_TP]): data: Sequence[Any] - def __init__(self, result: Result[_TP]): + def __init__(self, result: Result[Unpack[_Ts]]): self.metadata = result._metadata._for_freeze() self._source_supports_scalars = result._source_supports_scalars self._attributes = result._attributes @@ -2198,21 +2201,21 @@ class FrozenResult(Generic[_TP]): return [list(row) for row in self.data] def with_new_rows( - self, tuple_data: Sequence[Row[_TP]] - ) -> FrozenResult[_TP]: + self, tuple_data: Sequence[Row[Unpack[_Ts]]] + ) -> FrozenResult[Unpack[_Ts]]: fr = FrozenResult.__new__(FrozenResult) fr.metadata = self.metadata fr._attributes = self._attributes fr._source_supports_scalars = self._source_supports_scalars if self._source_supports_scalars: - fr.data = [d[0] for d in tuple_data] + fr.data = [d[0] for d in tuple_data] # type: ignore[misc] else: fr.data = tuple_data return fr - def __call__(self) -> Result[_TP]: - result: IteratorResult[_TP] = IteratorResult( + def __call__(self) -> Result[Unpack[_Ts]]: + result: IteratorResult[Unpack[_Ts]] = IteratorResult( self.metadata, iter(self.data) ) result._attributes = self._attributes @@ -2220,7 +2223,7 @@ class FrozenResult(Generic[_TP]): return result -class IteratorResult(Result[_TP]): +class IteratorResult(Result[Unpack[_Ts]]): """A :class:`_engine.Result` that gets data from a Python iterator of :class:`_engine.Row` objects or similar row-like data. @@ -2307,7 +2310,7 @@ def null_result() -> IteratorResult[Any]: return IteratorResult(SimpleResultMetaData([]), iter([])) -class ChunkedIteratorResult(IteratorResult[_TP]): +class ChunkedIteratorResult(IteratorResult[Unpack[_Ts]]): """An :class:`_engine.IteratorResult` that works from an iterator-producing callable. @@ -2364,7 +2367,7 @@ class ChunkedIteratorResult(IteratorResult[_TP]): return super()._fetchmany_impl(size=size) -class MergedResult(IteratorResult[_TP]): +class MergedResult(IteratorResult[Unpack[_Ts]]): """A :class:`_engine.Result` that is merged from any number of :class:`_engine.Result` objects. @@ -2378,7 +2381,7 @@ class MergedResult(IteratorResult[_TP]): rowcount: Optional[int] def __init__( - self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]] + self, cursor_metadata: ResultMetaData, results: Sequence[Result[Unpack[_Ts]]] ): self._results = results super().__init__( diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index d2bb2e4c9a..7e357bb6e6 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -31,6 +31,8 @@ from typing import Union from ..sql import util as sql_util from ..util import deprecated +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack from ..util._has_cy import HAS_CYEXTENSION if TYPE_CHECKING or not HAS_CYEXTENSION: @@ -42,12 +44,17 @@ if TYPE_CHECKING: from .result import _KeyType from .result import _ProcessorsType from .result import RMKeyView + from typing import Tuple as _RowBase +else: + _RowBase = Sequence + _T = TypeVar("_T", bound=Any) _TP = TypeVar("_TP", bound=Tuple[Any, ...]) +_Ts = TypeVarTuple("_Ts") -class Row(BaseRow, Sequence[Any], Generic[_TP]): +class Row(BaseRow, _RowBase[Unpack[_Ts]], Generic[Unpack[_Ts]]): """Represent a single result row. The :class:`.Row` object represents a row of a database result. It is @@ -83,7 +90,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): def __delattr__(self, name: str) -> NoReturn: raise AttributeError("can't delete attribute") - def _tuple(self) -> _TP: + def _tuple(self) -> Tuple[Unpack[_Ts]]: """Return a 'tuple' form of this :class:`.Row`. At runtime, this method returns "self"; the :class:`.Row` object is @@ -105,7 +112,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): """ - return self # type: ignore + return self @deprecated( "2.0.19", @@ -114,7 +121,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): "methods and library-level attributes are intended to be underscored " "to avoid name conflicts. Please use :meth:`Row._tuple`.", ) - def tuple(self) -> _TP: + def tuple(self) -> Tuple[Unpack[_Ts]]: """Return a 'tuple' form of this :class:`.Row`. .. versionadded:: 2.0 @@ -123,7 +130,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): return self._tuple() @property - def _t(self) -> _TP: + def _t(self) -> Tuple[Unpack[_Ts]]: """A synonym for :meth:`.Row._tuple`. .. versionadded:: 2.0.19 - The :attr:`.Row._t` attribute supersedes @@ -135,7 +142,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): :attr:`.Result.t` """ - return self # type: ignore + return self @property @deprecated( @@ -145,7 +152,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): "methods and library-level attributes are intended to be underscored " "to avoid name conflicts. Please use :attr:`Row._t`.", ) - def t(self) -> _TP: + def t(self) -> Tuple[Unpack[_Ts]]: """A synonym for :meth:`.Row._tuple`. .. versionadded:: 2.0 @@ -172,7 +179,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): def _filter_on_values( self, processor: Optional[_ProcessorsType] - ) -> Row[Any]: + ) -> Row[Unpack[_Ts]]: return Row(self._parent, processor, self._key_to_index, self._data) if not TYPE_CHECKING: @@ -210,19 +217,6 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): __hash__ = BaseRow.__hash__ - if TYPE_CHECKING: - - @overload - def __getitem__(self, index: int) -> Any: - ... - - @overload - def __getitem__(self, index: slice) -> Sequence[Any]: - ... - - def __getitem__(self, index: Union[int, slice]) -> Any: - ... - def __lt__(self, other: Any) -> bool: return self._op(other, operator.lt) diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 963bd005a4..25b897176f 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -50,6 +50,9 @@ from ..orm.session import _PKIdentityArgument from ..orm.session import Session from ..util.typing import Protocol from ..util.typing import Self +from ..util.typing import TypeVarTuple +from ..util.typing import Unpack + if TYPE_CHECKING: from ..engine.base import Connection @@ -72,6 +75,7 @@ if TYPE_CHECKING: __all__ = ["ShardedSession", "ShardedQuery"] _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") ShardIdentifier = str @@ -427,7 +431,7 @@ class set_shard_id(ORMOption): def execute_and_instances( orm_context: ORMExecuteState, -) -> Union[Result[_T], IteratorResult[_TP]]: +) -> Union[Result[Unpack[_Ts]], IteratorResult[Unpack[_Ts]]]: active_options: Union[ None, QueryContext.default_load_options, @@ -449,7 +453,7 @@ def execute_and_instances( def iter_for_shard( shard_id: ShardIdentifier, - ) -> Union[Result[_T], IteratorResult[_TP]]: + ) -> Union[Result[Unpack[_Ts]], IteratorResult[Unpack[_Ts]]]: bind_arguments = dict(orm_context.bind_arguments) bind_arguments["shard_id"] = shard_id diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 28480a5d43..15e7da2657 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -70,6 +70,7 @@ from .. import exc from .. import util from ..util.typing import Literal from ..util.typing import Protocol +from ..util.typing import Unpack if typing.TYPE_CHECKING: from ._typing import _EquivalentColumnMap @@ -588,7 +589,7 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row: Row[Any], max_chars: int = 300): + def __init__(self, row: Row[Unpack[Tuple[Any, ...]]], max_chars: int = 300): self.row = row self.max_chars = max_chars diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 3d15d43db7..ce2c0c34ce 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -51,10 +51,13 @@ if True: # zimports removes the tailing comments from typing_extensions import TypeAlias as TypeAlias # 3.10 from typing_extensions import TypedDict as TypedDict # 3.8 from typing_extensions import TypeGuard as TypeGuard # 3.10 + from typing_extensions import TypeVarTuple as TypeVarTuple # 3.11 from typing_extensions import Self as Self # 3.11 + from typing_extensions import Unpack as Unpack # 3.11 _T = TypeVar("_T", bound=Any) +_Ts = TypeVarTuple("_Ts") _KT = TypeVar("_KT") _KT_co = TypeVar("_KT_co", covariant=True) _KT_contra = TypeVar("_KT_contra", contravariant=True) diff --git a/tox.ini b/tox.ini index bc95175597..1cb44a1a3d 100644 --- a/tox.ini +++ b/tox.ini @@ -175,7 +175,7 @@ commands= [testenv:pep484] deps= greenlet != 0.4.17 - mypy >= 1.6.0 + mypy >= 1.7.0 commands = mypy {env:MYPY_COLOR} ./lib/sqlalchemy # pyright changes too often with not-exactly-correct errors @@ -187,7 +187,7 @@ deps= pytest>=7.0.0rc1,<8 pytest-xdist greenlet != 0.4.17 - mypy >= 1.2.0 + mypy >= 1.7.0 patch==1.* commands =