From: Federico Caselli Date: Mon, 28 Apr 2025 21:44:50 +0000 (+0200) Subject: add correct typing for row getitem X-Git-Tag: rel_2_0_41~15 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=6b222d772400500ca7efbb02350bb6d8608f6bf1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add correct typing for row getitem The overloads were broken in 8a4c27589500bc57605bb8f28c215f5f0ae5066d Change-Id: I3736b15e95ead28537e25169a54521e991f763da (cherry picked from commit 4ac02007e030232f57226aafbb9313c8ff186a62) --- diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 5d597fd5f4..b84fb3d1cb 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -722,6 +722,14 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): return manyrows + @overload + def _only_one_row( + self: ResultInternal[Row[Any]], + raise_for_second_row: bool, + raise_for_none: bool, + scalar: Literal[True], + ) -> Any: ... + @overload def _only_one_row( self, diff --git a/lib/sqlalchemy/testing/fixtures/mypy.py b/lib/sqlalchemy/testing/fixtures/mypy.py index 0832d89246..849df4dc30 100644 --- a/lib/sqlalchemy/testing/fixtures/mypy.py +++ b/lib/sqlalchemy/testing/fixtures/mypy.py @@ -143,7 +143,9 @@ class MypyTest(TestBase): from sqlalchemy.ext.mypy.util import mypy_14 expected_messages = [] - expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)") + expected_re = re.compile( + r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)" + ) py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)") with open(path) as file_: current_assert_messages = [] @@ -161,9 +163,24 @@ class MypyTest(TestBase): if m: is_mypy = bool(m.group(1)) is_re = bool(m.group(2)) - is_type = bool(m.group(3)) + is_row = bool(m.group(3)) + is_type = bool(m.group(4)) + + expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5)) + if is_row: + expected_msg = re.sub( + r"Row\[([^\]]+)\]", + lambda m: f"tuple[{m.group(1)}, fallback=s" + f"qlalchemy.engine.row.{m.group(0)}]", + expected_msg, + ) + # For some reason it does not use or syntax (|) + expected_msg = re.sub( + r"Optional\[(.*)\]", + lambda m: f"Union[{m.group(1)}, None]", + expected_msg, + ) - expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4)) if is_type: if not is_re: # the goal here is that we can cut-and-paste @@ -243,7 +260,9 @@ class MypyTest(TestBase): return expected_messages - def _check_output(self, path, expected_messages, stdout, stderr, exitcode): + def _check_output( + self, path, expected_messages, stdout: str, stderr, exitcode + ): not_located = [] filename = os.path.basename(path) if expected_messages: @@ -263,7 +282,8 @@ class MypyTest(TestBase): ): while raw_lines: ol = raw_lines.pop(0) - if not re.match(r".+\.py:\d+: note: +def \[.*", ol): + if not re.match(r".+\.py:\d+: note: +def .*", ol): + raw_lines.insert(0, ol) break elif re.match( r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I diff --git a/test/typing/plain_files/engine/engine_result.py b/test/typing/plain_files/engine/engine_result.py new file mode 100644 index 0000000000..7ff20b7846 --- /dev/null +++ b/test/typing/plain_files/engine/engine_result.py @@ -0,0 +1,94 @@ +from typing import reveal_type +from typing import Tuple + +from sqlalchemy import column +from sqlalchemy.engine import Result +from sqlalchemy.engine import Row + + +def row_one(row: Row[Tuple[int, str, bool]]) -> None: + # EXPECTED_TYPE: Any + reveal_type(row[0]) + # EXPECTED_TYPE: Any + reveal_type(row[1]) + # EXPECTED_TYPE: Any + reveal_type(row[2]) + + # EXPECTED_MYPY: No overload variant of "__getitem__" of "Row" matches argument type "str" # noqa: E501 + row["a"] + + # EXPECTED_TYPE: RowMapping + reveal_type(row._mapping) + rm = row._mapping + # EXPECTED_TYPE: Any + reveal_type(rm["foo"]) + # EXPECTED_TYPE: Any + reveal_type(rm[column("bar")]) + + # EXPECTED_MYPY: Invalid index type "int" for "RowMapping"; expected type "str | SQLCoreOperations[Any]" # noqa: E501 + rm[3] + + +def result_one( + res: Result[Tuple[int, str]], r_single: Result[Tuple[float]] +) -> None: + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(res.one()) + # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None] + reveal_type(res.one_or_none()) + # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None] + reveal_type(res.fetchone()) + # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None] + reveal_type(res.first()) + # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]] + reveal_type(res.all()) + # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]] + reveal_type(res.fetchmany()) + # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]] + reveal_type(res.fetchall()) + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(next(res)) + for rf in res: + # EXPECTED_TYPE: Row[Tuple[int, str]] + reveal_type(rf) + for rp in res.partitions(): + # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]] + reveal_type(rp) + + # EXPECTED_TYPE: ScalarResult[Any] + res_s = reveal_type(res.scalars()) + # EXPECTED_TYPE: ScalarResult[Any] + res_s = reveal_type(res.scalars(0)) + # EXPECTED_TYPE: Any + reveal_type(res_s.one()) + # EXPECTED_TYPE: ScalarResult[Any] + reveal_type(res.scalars(1)) + # EXPECTED_TYPE: MappingResult + reveal_type(res.mappings()) + # EXPECTED_TYPE: FrozenResult[Tuple[int, str]] + reveal_type(res.freeze()) + + # EXPECTED_TYPE: Any + reveal_type(res.scalar_one()) + # EXPECTED_TYPE: Union[Any, None] + reveal_type(res.scalar_one_or_none()) + # EXPECTED_TYPE: Any + reveal_type(res.scalar()) + + # EXPECTED_TYPE: ScalarResult[float] + res_s2 = reveal_type(r_single.scalars()) + # EXPECTED_TYPE: ScalarResult[float] + res_s2 = reveal_type(r_single.scalars(0)) + # EXPECTED_TYPE: float + reveal_type(res_s2.one()) + # EXPECTED_TYPE: ScalarResult[Any] + reveal_type(r_single.scalars(1)) + # EXPECTED_TYPE: MappingResult + reveal_type(r_single.mappings()) + + # EXPECTED_TYPE: float + reveal_type(r_single.scalar_one()) + # EXPECTED_TYPE: Union[float, None] + reveal_type(r_single.scalar_one_or_none()) + # EXPECTED_TYPE: Union[float, None] + reveal_type(r_single.scalar())