]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Start working on a TypeVarTuple support
authorYurii Karabas <1998uriyyo@gmail.com>
Tue, 14 Nov 2023 15:30:38 +0000 (17:30 +0200)
committerYurii Karabas <1998uriyyo@gmail.com>
Tue, 14 Nov 2023 15:33:24 +0000 (17:33 +0200)
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/events.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/engine/row.py
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/util/typing.py
tox.ini

index ff6e311a743ac4d7e0203409a2a54f609c1fcb31..b3828361f7597b0eb3206896c8ef49d59370129b 100644 (file)
@@ -53,6 +53,8 @@ from ..sql.type_api import TypeEngine
 from ..util import compat
 from ..util.typing import Literal
 from ..util.typing import Self
+from ..util.typing import TypeVarTuple
+from ..util.typing import Unpack
 
 
 if typing.TYPE_CHECKING:
@@ -71,7 +73,7 @@ if typing.TYPE_CHECKING:
     from ..sql.type_api import _ResultProcessorType
 
 
-_T = TypeVar("_T", bound=Any)
+_Ts = TypeVarTuple("_Ts")
 
 
 # metadata entry tuple indexes.
@@ -1375,7 +1377,7 @@ def null_dml_result() -> IteratorResult[Any]:
     return it
 
 
-class CursorResult(Result[_T]):
+class CursorResult(Result[Unpack[_Ts]]):
     """A Result that is representing state from a DBAPI cursor.
 
     .. versionchanged:: 1.4  The :class:`.CursorResult``
@@ -2108,7 +2110,7 @@ class CursorResult(Result[_T]):
     def _raw_row_iterator(self):
         return self._fetchiter_impl()
 
-    def merge(self, *others: Result[Any]) -> MergedResult[Any]:
+    def merge(self, *others: Result[Unpack[Tuple[Any, ...]]]) -> MergedResult[Unpack[Tuple[Any, ...]]]:
         merged_result = super().merge(*others)
         setup_rowcounts = self.context._has_rowcount
         if setup_rowcounts:
index 553d8f0bea1fb3d2260e7afc8524e7d86752fb0d..bd55dd95c3f96bd3a8b8a782d229eb58a089c790 100644 (file)
@@ -66,6 +66,8 @@ from ..sql.compiler import SQLCompiler
 from ..sql.elements import quoted_name
 from ..util.typing import Final
 from ..util.typing import Literal
+from ..util.typing import Unpack
+
 
 if typing.TYPE_CHECKING:
     from types import ModuleType
@@ -1181,7 +1183,7 @@ class DefaultExecutionContext(ExecutionContext):
     result_column_struct: Optional[
         Tuple[List[ResultColumnsEntry], bool, bool, bool, bool]
     ] = None
-    returned_default_rows: Optional[Sequence[Row[Any]]] = None
+    returned_default_rows: Optional[Sequence[Row[Unpack[Tuple[Any, ...]]]]] = None
 
     execution_options: _ExecuteOptions = util.EMPTY_DICT
 
index aac756d18a22e606a462ac3f3e396f632935710d..7e57e92a7f493a1a624916a916c82e9d3c2437c2 100644 (file)
@@ -25,6 +25,7 @@ from .interfaces import Dialect
 from .. import event
 from .. import exc
 from ..util.typing import Literal
+from ..util.typing import Unpack
 
 if typing.TYPE_CHECKING:
     from .interfaces import _CoreMultiExecuteParams
@@ -270,7 +271,7 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
         multiparams: _CoreMultiExecuteParams,
         params: _CoreSingleExecuteParams,
         execution_options: _ExecuteOptions,
-        result: Result[Any],
+        result: Result[Unpack[Tuple[Any, ...]]],
     ) -> None:
         """Intercept high level execute() events after execute.
 
index acbe6f0923632f5e3f3b44600206a4e5658fadd5..07689564a49c51284a36c7948fbebcce9bd2ec0d 100644 (file)
@@ -45,6 +45,8 @@ from ..util import NONE_SET
 from ..util._has_cy import HAS_CYEXTENSION
 from ..util.typing import Literal
 from ..util.typing import Self
+from ..util.typing import Unpack
+from ..util.typing import TypeVarTuple
 
 if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
     from ._py_row import tuplegetter as tuplegetter
@@ -75,6 +77,7 @@ _RawRowType = Tuple[Any, ...]
 _R = TypeVar("_R", bound=_RowData)
 _T = TypeVar("_T", bound=Any)
 _TP = TypeVar("_TP", bound=Tuple[Any, ...])
+_Ts = TypeVarTuple("_Ts")
 
 _InterimRowType = Union[_R, _RawRowType]
 """a catchall "anything" kind of return type that can be applied
@@ -394,7 +397,7 @@ class SimpleResultMetaData(ResultMetaData):
 
 def result_tuple(
     fields: Sequence[str], extra: Optional[Any] = None
-) -> Callable[[Iterable[Any]], Row[Any]]:
+) -> Callable[[Iterable[Any]], Row[Unpack[_RawRowType]]]:
     parent = SimpleResultMetaData(fields, extra)
     return functools.partial(
         Row, parent, parent._effective_processors, parent._key_to_index
@@ -909,7 +912,7 @@ class _WithKeys:
         return self._metadata.keys
 
 
-class Result(_WithKeys, ResultInternal[Row[_TP]]):
+class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]):
     """Represent a set of database results.
 
     .. versionadded:: 1.4  The :class:`_engine.Result` object provides a
@@ -1132,12 +1135,12 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
         return self._column_slices(col_expressions)
 
     @overload
-    def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]:
+    def scalars(self: Result[_T, Unpack[_RawRowType]]) -> ScalarResult[_T]:
         ...
 
     @overload
     def scalars(
-        self: Result[Tuple[_T]], index: Literal[0]
+        self: Result[_T, Unpack[_RawRowType]], index: Literal[0]
     ) -> ScalarResult[_T]:
         ...
 
@@ -1212,7 +1215,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
         return MappingResult(self)
 
     @property
-    def t(self) -> TupleResult[_TP]:
+    def t(self) -> TupleResult[Tuple[Unpack[_Ts]]]:
         """Apply a "typed tuple" typing filter to returned rows.
 
         The :attr:`_engine.Result.t` attribute is a synonym for
@@ -1223,7 +1226,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
         """
         return self  # type: ignore
 
-    def tuples(self) -> TupleResult[_TP]:
+    def tuples(self) -> TupleResult[Tuple[Unpack[_Ts]]]:
         """Apply a "typed tuple" typing filter to returned rows.
 
         This method returns the same :class:`_engine.Result` object
@@ -1258,15 +1261,15 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
         """
         raise NotImplementedError()
 
-    def __iter__(self) -> Iterator[Row[_TP]]:
+    def __iter__(self) -> Iterator[Row[Unpack[_Ts]]]:
         return self._iter_impl()
 
-    def __next__(self) -> Row[_TP]:
+    def __next__(self) -> Row[Unpack[_Ts]]:
         return self._next_impl()
 
     def partitions(
         self, size: Optional[int] = None
-    ) -> Iterator[Sequence[Row[_TP]]]:
+    ) -> Iterator[Sequence[Row[Unpack[_Ts]]]]:
         """Iterate through sub-lists of rows of the size given.
 
         Each list will be of the size given, excluding the last list to
@@ -1322,12 +1325,12 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
             else:
                 break
 
-    def fetchall(self) -> Sequence[Row[_TP]]:
+    def fetchall(self) -> Sequence[Row[Unpack[_Ts]]]:
         """A synonym for the :meth:`_engine.Result.all` method."""
 
         return self._allrows()
 
-    def fetchone(self) -> Optional[Row[_TP]]:
+    def fetchone(self) -> Optional[Row[Unpack[_Ts]]]:
         """Fetch one row.
 
         When all rows are exhausted, returns None.
@@ -1349,7 +1352,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
         else:
             return row
 
-    def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]:
+    def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[Unpack[_Ts]]]:
         """Fetch many rows.
 
         When all rows are exhausted, returns an empty list.
@@ -1370,7 +1373,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
 
         return self._manyrow_getter(self, size)
 
-    def all(self) -> Sequence[Row[_TP]]:
+    def all(self) -> Sequence[Row[Unpack[_Ts]]]:
         """Return all rows in a list.
 
         Closes the result set after invocation.   Subsequent invocations
@@ -1389,7 +1392,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
 
         return self._allrows()
 
-    def first(self) -> Optional[Row[_TP]]:
+    def first(self) -> Optional[Row[Unpack[_Ts]]]:
         """Fetch the first row or ``None`` if no row is present.
 
         Closes the result set and discards remaining rows.
@@ -1428,7 +1431,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
             raise_for_second_row=False, raise_for_none=False, scalar=False
         )
 
-    def one_or_none(self) -> Optional[Row[_TP]]:
+    def one_or_none(self) -> Optional[Row[Unpack[_Ts]]]:
         """Return at most one result or raise an exception.
 
         Returns ``None`` if the result has no rows.
@@ -1503,7 +1506,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
             raise_for_second_row=True, raise_for_none=False, scalar=True
         )
 
-    def one(self) -> Row[_TP]:
+    def one(self) -> Row[Unpack[_Ts]]:
         """Return exactly one row or raise an exception.
 
         Raises :class:`.NoResultFound` if the result returns no
@@ -1562,7 +1565,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
             raise_for_second_row=False, raise_for_none=False, scalar=True
         )
 
-    def freeze(self) -> FrozenResult[_TP]:
+    def freeze(self) -> FrozenResult[Unpack[_Ts]]:
         """Return a callable object that will produce copies of this
         :class:`_engine.Result` when invoked.
 
@@ -1585,7 +1588,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
 
         return FrozenResult(self)
 
-    def merge(self, *others: Result[Any]) -> MergedResult[_TP]:
+    def merge(self, *others: Result[Any]) -> MergedResult[Unpack[_Ts]]:
         """Merge this :class:`_engine.Result` with other compatible result
         objects.
 
@@ -1720,7 +1723,7 @@ class ScalarResult(FilterResult[_R]):
 
     _post_creational_filter: Optional[Callable[[Any], Any]]
 
-    def __init__(self, real_result: Result[Any], index: _KeyIndexType):
+    def __init__(self, real_result: Result[Unpack[_RawRowType]], index: _KeyIndexType):
         self._real_result = real_result
 
         if real_result._source_supports_scalars:
@@ -2013,7 +2016,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
 
     _post_creational_filter = operator.attrgetter("_mapping")
 
-    def __init__(self, result: Result[Any]):
+    def __init__(self, result: Result[Unpack[_RawRowType]]):
         self._real_result = result
         self._unique_filter_state = result._unique_filter_state
         self._metadata = result._metadata
@@ -2140,7 +2143,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
         )
 
 
-class FrozenResult(Generic[_TP]):
+class FrozenResult(Generic[Unpack[_Ts]]):
     """Represents a :class:`_engine.Result` object in a "frozen" state suitable
     for caching.
 
@@ -2181,7 +2184,7 @@ class FrozenResult(Generic[_TP]):
 
     data: Sequence[Any]
 
-    def __init__(self, result: Result[_TP]):
+    def __init__(self, result: Result[Unpack[_Ts]]):
         self.metadata = result._metadata._for_freeze()
         self._source_supports_scalars = result._source_supports_scalars
         self._attributes = result._attributes
@@ -2198,21 +2201,21 @@ class FrozenResult(Generic[_TP]):
             return [list(row) for row in self.data]
 
     def with_new_rows(
-        self, tuple_data: Sequence[Row[_TP]]
-    ) -> FrozenResult[_TP]:
+        self, tuple_data: Sequence[Row[Unpack[_Ts]]]
+    ) -> FrozenResult[Unpack[_Ts]]:
         fr = FrozenResult.__new__(FrozenResult)
         fr.metadata = self.metadata
         fr._attributes = self._attributes
         fr._source_supports_scalars = self._source_supports_scalars
 
         if self._source_supports_scalars:
-            fr.data = [d[0] for d in tuple_data]
+            fr.data = [d[0] for d in tuple_data]  # type: ignore[misc]
         else:
             fr.data = tuple_data
         return fr
 
-    def __call__(self) -> Result[_TP]:
-        result: IteratorResult[_TP] = IteratorResult(
+    def __call__(self) -> Result[Unpack[_Ts]]:
+        result: IteratorResult[Unpack[_Ts]] = IteratorResult(
             self.metadata, iter(self.data)
         )
         result._attributes = self._attributes
@@ -2220,7 +2223,7 @@ class FrozenResult(Generic[_TP]):
         return result
 
 
-class IteratorResult(Result[_TP]):
+class IteratorResult(Result[Unpack[_Ts]]):
     """A :class:`_engine.Result` that gets data from a Python iterator of
     :class:`_engine.Row` objects or similar row-like data.
 
@@ -2307,7 +2310,7 @@ def null_result() -> IteratorResult[Any]:
     return IteratorResult(SimpleResultMetaData([]), iter([]))
 
 
-class ChunkedIteratorResult(IteratorResult[_TP]):
+class ChunkedIteratorResult(IteratorResult[Unpack[_Ts]]):
     """An :class:`_engine.IteratorResult` that works from an
     iterator-producing callable.
 
@@ -2364,7 +2367,7 @@ class ChunkedIteratorResult(IteratorResult[_TP]):
         return super()._fetchmany_impl(size=size)
 
 
-class MergedResult(IteratorResult[_TP]):
+class MergedResult(IteratorResult[Unpack[_Ts]]):
     """A :class:`_engine.Result` that is merged from any number of
     :class:`_engine.Result` objects.
 
@@ -2378,7 +2381,7 @@ class MergedResult(IteratorResult[_TP]):
     rowcount: Optional[int]
 
     def __init__(
-        self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]]
+        self, cursor_metadata: ResultMetaData, results: Sequence[Result[Unpack[_Ts]]]
     ):
         self._results = results
         super().__init__(
index d2bb2e4c9a6a716a1420bf39fda7c3b35c6ade63..7e357bb6e69eaa812e99399e2d6b284e9f7712b2 100644 (file)
@@ -31,6 +31,8 @@ from typing import Union
 
 from ..sql import util as sql_util
 from ..util import deprecated
+from ..util.typing import TypeVarTuple
+from ..util.typing import Unpack
 from ..util._has_cy import HAS_CYEXTENSION
 
 if TYPE_CHECKING or not HAS_CYEXTENSION:
@@ -42,12 +44,17 @@ if TYPE_CHECKING:
     from .result import _KeyType
     from .result import _ProcessorsType
     from .result import RMKeyView
+    from typing import Tuple as _RowBase
+else:
+    _RowBase = Sequence
+
 
 _T = TypeVar("_T", bound=Any)
 _TP = TypeVar("_TP", bound=Tuple[Any, ...])
+_Ts = TypeVarTuple("_Ts")
 
 
-class Row(BaseRow, Sequence[Any], Generic[_TP]):
+class Row(BaseRow, _RowBase[Unpack[_Ts]], Generic[Unpack[_Ts]]):
     """Represent a single result row.
 
     The :class:`.Row` object represents a row of a database result.  It is
@@ -83,7 +90,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
     def __delattr__(self, name: str) -> NoReturn:
         raise AttributeError("can't delete attribute")
 
-    def _tuple(self) -> _TP:
+    def _tuple(self) -> Tuple[Unpack[_Ts]]:
         """Return a 'tuple' form of this :class:`.Row`.
 
         At runtime, this method returns "self"; the :class:`.Row` object is
@@ -105,7 +112,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
 
 
         """
-        return self  # type: ignore
+        return self
 
     @deprecated(
         "2.0.19",
@@ -114,7 +121,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
         "methods and library-level attributes are intended to be underscored "
         "to avoid name conflicts.  Please use :meth:`Row._tuple`.",
     )
-    def tuple(self) -> _TP:
+    def tuple(self) -> Tuple[Unpack[_Ts]]:
         """Return a 'tuple' form of this :class:`.Row`.
 
         .. versionadded:: 2.0
@@ -123,7 +130,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
         return self._tuple()
 
     @property
-    def _t(self) -> _TP:
+    def _t(self) -> Tuple[Unpack[_Ts]]:
         """A synonym for :meth:`.Row._tuple`.
 
         .. versionadded:: 2.0.19 - The :attr:`.Row._t` attribute supersedes
@@ -135,7 +142,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
 
             :attr:`.Result.t`
         """
-        return self  # type: ignore
+        return self
 
     @property
     @deprecated(
@@ -145,7 +152,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
         "methods and library-level attributes are intended to be underscored "
         "to avoid name conflicts.  Please use :attr:`Row._t`.",
     )
-    def t(self) -> _TP:
+    def t(self) -> Tuple[Unpack[_Ts]]:
         """A synonym for :meth:`.Row._tuple`.
 
         .. versionadded:: 2.0
@@ -172,7 +179,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
 
     def _filter_on_values(
         self, processor: Optional[_ProcessorsType]
-    ) -> Row[Any]:
+    ) -> Row[Unpack[_Ts]]:
         return Row(self._parent, processor, self._key_to_index, self._data)
 
     if not TYPE_CHECKING:
@@ -210,19 +217,6 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
 
     __hash__ = BaseRow.__hash__
 
-    if TYPE_CHECKING:
-
-        @overload
-        def __getitem__(self, index: int) -> Any:
-            ...
-
-        @overload
-        def __getitem__(self, index: slice) -> Sequence[Any]:
-            ...
-
-        def __getitem__(self, index: Union[int, slice]) -> Any:
-            ...
-
     def __lt__(self, other: Any) -> bool:
         return self._op(other, operator.lt)
 
index 963bd005a4be3fcdbdbc006d7b221766dfbf4568..25b897176f0f1313624f5447bb0c2e97b1fee180 100644 (file)
@@ -50,6 +50,9 @@ from ..orm.session import _PKIdentityArgument
 from ..orm.session import Session
 from ..util.typing import Protocol
 from ..util.typing import Self
+from ..util.typing import TypeVarTuple
+from ..util.typing import Unpack
+
 
 if TYPE_CHECKING:
     from ..engine.base import Connection
@@ -72,6 +75,7 @@ if TYPE_CHECKING:
 __all__ = ["ShardedSession", "ShardedQuery"]
 
 _T = TypeVar("_T", bound=Any)
+_Ts = TypeVarTuple("_Ts")
 
 
 ShardIdentifier = str
@@ -427,7 +431,7 @@ class set_shard_id(ORMOption):
 
 def execute_and_instances(
     orm_context: ORMExecuteState,
-) -> Union[Result[_T], IteratorResult[_TP]]:
+) -> Union[Result[Unpack[_Ts]], IteratorResult[Unpack[_Ts]]]:
     active_options: Union[
         None,
         QueryContext.default_load_options,
@@ -449,7 +453,7 @@ def execute_and_instances(
 
     def iter_for_shard(
         shard_id: ShardIdentifier,
-    ) -> Union[Result[_T], IteratorResult[_TP]]:
+    ) -> Union[Result[Unpack[_Ts]], IteratorResult[Unpack[_Ts]]]:
         bind_arguments = dict(orm_context.bind_arguments)
         bind_arguments["shard_id"] = shard_id
 
index 28480a5d437f0a89cb8c3c524ba03d3721760f7e..15e7da265710b282ebffe4bc875e478d987e75dc 100644 (file)
@@ -70,6 +70,7 @@ from .. import exc
 from .. import util
 from ..util.typing import Literal
 from ..util.typing import Protocol
+from ..util.typing import Unpack
 
 if typing.TYPE_CHECKING:
     from ._typing import _EquivalentColumnMap
@@ -588,7 +589,7 @@ class _repr_row(_repr_base):
 
     __slots__ = ("row",)
 
-    def __init__(self, row: Row[Any], max_chars: int = 300):
+    def __init__(self, row: Row[Unpack[Tuple[Any, ...]]], max_chars: int = 300):
         self.row = row
         self.max_chars = max_chars
 
index 3d15d43db7652212d78994c51e2975ae0b987cee..ce2c0c34cefacab7ccd45fe0f3483e94cc6dd419 100644 (file)
@@ -51,10 +51,13 @@ if True:  # zimports removes the tailing comments
     from typing_extensions import TypeAlias as TypeAlias  # 3.10
     from typing_extensions import TypedDict as TypedDict  # 3.8
     from typing_extensions import TypeGuard as TypeGuard  # 3.10
+    from typing_extensions import TypeVarTuple as TypeVarTuple  # 3.11
     from typing_extensions import Self as Self  # 3.11
+    from typing_extensions import Unpack as Unpack  # 3.11
 
 
 _T = TypeVar("_T", bound=Any)
+_Ts = TypeVarTuple("_Ts")
 _KT = TypeVar("_KT")
 _KT_co = TypeVar("_KT_co", covariant=True)
 _KT_contra = TypeVar("_KT_contra", contravariant=True)
diff --git a/tox.ini b/tox.ini
index bc95175597526e423bf3b6a01205b28d19b103b6..1cb44a1a3da47ba266bb02012cc44a986c7c079b 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -175,7 +175,7 @@ commands=
 [testenv:pep484]
 deps=
      greenlet != 0.4.17
-     mypy >= 1.6.0
+     mypy >= 1.7.0
 commands =
     mypy  {env:MYPY_COLOR} ./lib/sqlalchemy
     # pyright changes too often with not-exactly-correct errors
@@ -187,7 +187,7 @@ deps=
      pytest>=7.0.0rc1,<8
      pytest-xdist
      greenlet != 0.4.17
-     mypy >= 1.2.0
+     mypy >= 1.7.0
      patch==1.*
 
 commands =