From: Federico Caselli Date: Mon, 28 Apr 2025 21:44:50 +0000 (+0200) Subject: add correct typing for row getitem X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=4ac02007e030232f57226aafbb9313c8ff186a62;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add correct typing for row getitem The overloads were broken in 8a4c27589500bc57605bb8f28c215f5f0ae5066d Change-Id: I3736b15e95ead28537e25169a54521e991f763da --- diff --git a/lib/sqlalchemy/engine/_row_cy.py b/lib/sqlalchemy/engine/_row_cy.py index 4319e05f0b..76659e1933 100644 --- a/lib/sqlalchemy/engine/_row_cy.py +++ b/lib/sqlalchemy/engine/_row_cy.py @@ -112,8 +112,10 @@ class BaseRow: def __hash__(self) -> int: return hash(self._data) - def __getitem__(self, key: Any) -> Any: - return self._data[key] + if not TYPE_CHECKING: + + def __getitem__(self, key: Any) -> Any: + return self._data[key] def _get_by_key_impl_mapping(self, key: _KeyType) -> Any: return self._get_by_key_impl(key, False) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 2aa0aec9cd..46c85d6f6c 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -724,6 +724,14 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): return manyrows + @overload + def _only_one_row( + self: ResultInternal[Row[_T, Unpack[TupleAny]]], + raise_for_second_row: bool, + raise_for_none: bool, + scalar: Literal[True], + ) -> _T: ... + @overload def _only_one_row( self, @@ -1463,13 +1471,7 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): raise_for_second_row=True, raise_for_none=False, scalar=False ) - @overload - def scalar_one(self: Result[_T]) -> _T: ... - - @overload - def scalar_one(self) -> Any: ... - - def scalar_one(self) -> Any: + def scalar_one(self: Result[_T, Unpack[TupleAny]]) -> _T: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_engine.Result.scalars` and @@ -1486,13 +1488,7 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): raise_for_second_row=True, raise_for_none=True, scalar=True ) - @overload - def scalar_one_or_none(self: Result[_T]) -> Optional[_T]: ... - - @overload - def scalar_one_or_none(self) -> Optional[Any]: ... - - def scalar_one_or_none(self) -> Optional[Any]: + def scalar_one_or_none(self: Result[_T, Unpack[TupleAny]]) -> Optional[_T]: """Return exactly one scalar result or ``None``. This is equivalent to calling :meth:`_engine.Result.scalars` and @@ -1542,13 +1538,7 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): raise_for_second_row=True, raise_for_none=True, scalar=False ) - @overload - def scalar(self: Result[_T]) -> Optional[_T]: ... - - @overload - def scalar(self) -> Any: ... - - def scalar(self) -> Any: + def scalar(self: Result[_T, Unpack[TupleAny]]) -> Optional[_T]: """Fetch the first column of the first row, and close the result set. Returns ``None`` if there are no rows to fetch. diff --git a/lib/sqlalchemy/testing/fixtures/mypy.py b/lib/sqlalchemy/testing/fixtures/mypy.py index 3a1ae2e9bd..4b43225789 100644 --- a/lib/sqlalchemy/testing/fixtures/mypy.py +++ b/lib/sqlalchemy/testing/fixtures/mypy.py @@ -129,7 +129,9 @@ class MypyTest(TestBase): def _collect_messages(self, path): expected_messages = [] - expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") + expected_re = re.compile( + r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)" + ) py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") with open(path) as file_: current_assert_messages = [] @@ -147,9 +149,24 @@ class MypyTest(TestBase): if m: is_mypy = bool(m.group(1)) is_re = bool(m.group(2)) - is_type = bool(m.group(3)) + is_row = bool(m.group(3)) + is_type = bool(m.group(4)) + + expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5)) + if is_row: + expected_msg = re.sub( + r"Row\[([^\]]+)\]", + lambda m: f"tuple[{m.group(1)}, fallback=s" + f"qlalchemy.engine.row.{m.group(0)}]", + expected_msg, + ) + # For some reason it does not use or syntax (|) + expected_msg = re.sub( + r"Optional\[(.*)\]", + lambda m: f"Union[{m.group(1)}, None]", + expected_msg, + ) - expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4)) if is_type: if not is_re: # the goal here is that we can cut-and-paste @@ -213,7 +230,9 @@ class MypyTest(TestBase): return expected_messages - def _check_output(self, path, expected_messages, stdout, stderr, exitcode): + def _check_output( + self, path, expected_messages, stdout: str, stderr, exitcode + ): not_located = [] filename = os.path.basename(path) if expected_messages: @@ -233,7 +252,8 @@ class MypyTest(TestBase): ): while raw_lines: ol = raw_lines.pop(0) - if not re.match(r".+\.py:\d+: note: +def \[.*", ol): + if not re.match(r".+\.py:\d+: note: +def .*", ol): + raw_lines.insert(0, ol) break elif re.match( r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I diff --git a/test/typing/plain_files/engine/engine_result.py b/test/typing/plain_files/engine/engine_result.py new file mode 100644 index 0000000000..c8731618cc --- /dev/null +++ b/test/typing/plain_files/engine/engine_result.py @@ -0,0 +1,75 @@ +from typing import reveal_type + +from sqlalchemy import column +from sqlalchemy.engine import Result +from sqlalchemy.engine import Row + + +def row_one(row: Row[int, str, bool]) -> None: + # EXPECTED_TYPE: int + reveal_type(row[0]) + # EXPECTED_TYPE: str + reveal_type(row[1]) + # EXPECTED_TYPE: bool + reveal_type(row[2]) + + # EXPECTED_MYPY: Tuple index out of range + row[3] + # EXPECTED_MYPY: No overload variant of "__getitem__" of "tuple" matches argument type "str" # noqa: E501 + row["a"] + + # EXPECTED_TYPE: RowMapping + reveal_type(row._mapping) + rm = row._mapping + # EXPECTED_TYPE: Any + reveal_type(rm["foo"]) + # EXPECTED_TYPE: Any + reveal_type(rm[column("bar")]) + + # EXPECTED_MYPY: Invalid index type "int" for "RowMapping"; expected type "str | SQLCoreOperations[Any]" # noqa: E501 + rm[3] + + +def result_one(res: Result[int, str]) -> None: + # EXPECTED_ROW_TYPE: Row[int, str] + reveal_type(res.one()) + # EXPECTED_ROW_TYPE: Optional[Row[int, str]] + reveal_type(res.one_or_none()) + # EXPECTED_ROW_TYPE: Optional[Row[int, str]] + reveal_type(res.fetchone()) + # EXPECTED_ROW_TYPE: Optional[Row[int, str]] + reveal_type(res.first()) + # EXPECTED_ROW_TYPE: Sequence[Row[int, str]] + reveal_type(res.all()) + # EXPECTED_ROW_TYPE: Sequence[Row[int, str]] + reveal_type(res.fetchmany()) + # EXPECTED_ROW_TYPE: Sequence[Row[int, str]] + reveal_type(res.fetchall()) + # EXPECTED_ROW_TYPE: Row[int, str] + reveal_type(next(res)) + for rf in res: + # EXPECTED_ROW_TYPE: Row[int, str] + reveal_type(rf) + for rp in res.partitions(): + # EXPECTED_ROW_TYPE: Sequence[Row[int, str]] + reveal_type(rp) + + # EXPECTED_TYPE: ScalarResult[int] + res_s = reveal_type(res.scalars()) + # EXPECTED_TYPE: ScalarResult[int] + res_s = reveal_type(res.scalars(0)) + # EXPECTED_TYPE: int + reveal_type(res_s.one()) + # EXPECTED_TYPE: ScalarResult[Any] + reveal_type(res.scalars(1)) + # EXPECTED_TYPE: MappingResult + reveal_type(res.mappings()) + # EXPECTED_TYPE: FrozenResult[int, str] + reveal_type(res.freeze()) + + # EXPECTED_TYPE: int + reveal_type(res.scalar_one()) + # EXPECTED_TYPE: Union[int, None] + reveal_type(res.scalar_one_or_none()) + # EXPECTED_TYPE: Union[int, None] + reveal_type(res.scalar())