]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add correct typing for row getitem
authorFederico Caselli <cfederico87@gmail.com>
Mon, 28 Apr 2025 21:44:50 +0000 (23:44 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 29 Apr 2025 20:41:30 +0000 (22:41 +0200)
The overloads were broken in 8a4c27589500bc57605bb8f28c215f5f0ae5066d

Change-Id: I3736b15e95ead28537e25169a54521e991f763da

lib/sqlalchemy/engine/_row_cy.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/testing/fixtures/mypy.py
test/typing/plain_files/engine/engine_result.py [new file with mode: 0644]

index 4319e05f0bb563437a1832dc21815b05260dd590..76659e193310a7776f6f238f970963517fccab6c 100644 (file)
@@ -112,8 +112,10 @@ class BaseRow:
     def __hash__(self) -> int:
         return hash(self._data)
 
-    def __getitem__(self, key: Any) -> Any:
-        return self._data[key]
+    if not TYPE_CHECKING:
+
+        def __getitem__(self, key: Any) -> Any:
+            return self._data[key]
 
     def _get_by_key_impl_mapping(self, key: _KeyType) -> Any:
         return self._get_by_key_impl(key, False)
index 2aa0aec9cd3e00837e4a06462cf00a45bd72b2a2..46c85d6f6c4f012a350ca234cfbf760fe9f8d6a1 100644 (file)
@@ -724,6 +724,14 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
 
         return manyrows
 
+    @overload
+    def _only_one_row(
+        self: ResultInternal[Row[_T, Unpack[TupleAny]]],
+        raise_for_second_row: bool,
+        raise_for_none: bool,
+        scalar: Literal[True],
+    ) -> _T: ...
+
     @overload
     def _only_one_row(
         self,
@@ -1463,13 +1471,7 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]):
             raise_for_second_row=True, raise_for_none=False, scalar=False
         )
 
-    @overload
-    def scalar_one(self: Result[_T]) -> _T: ...
-
-    @overload
-    def scalar_one(self) -> Any: ...
-
-    def scalar_one(self) -> Any:
+    def scalar_one(self: Result[_T, Unpack[TupleAny]]) -> _T:
         """Return exactly one scalar result or raise an exception.
 
         This is equivalent to calling :meth:`_engine.Result.scalars` and
@@ -1486,13 +1488,7 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]):
             raise_for_second_row=True, raise_for_none=True, scalar=True
         )
 
-    @overload
-    def scalar_one_or_none(self: Result[_T]) -> Optional[_T]: ...
-
-    @overload
-    def scalar_one_or_none(self) -> Optional[Any]: ...
-
-    def scalar_one_or_none(self) -> Optional[Any]:
+    def scalar_one_or_none(self: Result[_T, Unpack[TupleAny]]) -> Optional[_T]:
         """Return exactly one scalar result or ``None``.
 
         This is equivalent to calling :meth:`_engine.Result.scalars` and
@@ -1542,13 +1538,7 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]):
             raise_for_second_row=True, raise_for_none=True, scalar=False
         )
 
-    @overload
-    def scalar(self: Result[_T]) -> Optional[_T]: ...
-
-    @overload
-    def scalar(self) -> Any: ...
-
-    def scalar(self) -> Any:
+    def scalar(self: Result[_T, Unpack[TupleAny]]) -> Optional[_T]:
         """Fetch the first column of the first row, and close the result set.
 
         Returns ``None`` if there are no rows to fetch.
index 3a1ae2e9bda492cf05eaa2c7a951bd1c3e16ec18..4b43225789cddbea07f7ec8aa6aaa40822179f8c 100644 (file)
@@ -129,7 +129,9 @@ class MypyTest(TestBase):
 
     def _collect_messages(self, path):
         expected_messages = []
-        expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
+        expected_re = re.compile(
+            r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)"
+        )
         py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
         with open(path) as file_:
             current_assert_messages = []
@@ -147,9 +149,24 @@ class MypyTest(TestBase):
                 if m:
                     is_mypy = bool(m.group(1))
                     is_re = bool(m.group(2))
-                    is_type = bool(m.group(3))
+                    is_row = bool(m.group(3))
+                    is_type = bool(m.group(4))
+
+                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5))
+                    if is_row:
+                        expected_msg = re.sub(
+                            r"Row\[([^\]]+)\]",
+                            lambda m: f"tuple[{m.group(1)}, fallback=s"
+                            f"qlalchemy.engine.row.{m.group(0)}]",
+                            expected_msg,
+                        )
+                        # For some reason it does not use or syntax (|)
+                        expected_msg = re.sub(
+                            r"Optional\[(.*)\]",
+                            lambda m: f"Union[{m.group(1)}, None]",
+                            expected_msg,
+                        )
 
-                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
                     if is_type:
                         if not is_re:
                             # the goal here is that we can cut-and-paste
@@ -213,7 +230,9 @@ class MypyTest(TestBase):
 
         return expected_messages
 
-    def _check_output(self, path, expected_messages, stdout, stderr, exitcode):
+    def _check_output(
+        self, path, expected_messages, stdout: str, stderr, exitcode
+    ):
         not_located = []
         filename = os.path.basename(path)
         if expected_messages:
@@ -233,7 +252,8 @@ class MypyTest(TestBase):
                 ):
                     while raw_lines:
                         ol = raw_lines.pop(0)
-                        if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
+                        if not re.match(r".+\.py:\d+: note: +def .*", ol):
+                            raw_lines.insert(0, ol)
                             break
                 elif re.match(
                     r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
diff --git a/test/typing/plain_files/engine/engine_result.py b/test/typing/plain_files/engine/engine_result.py
new file mode 100644 (file)
index 0000000..c873161
--- /dev/null
@@ -0,0 +1,75 @@
+from typing import reveal_type
+
+from sqlalchemy import column
+from sqlalchemy.engine import Result
+from sqlalchemy.engine import Row
+
+
+def row_one(row: Row[int, str, bool]) -> None:
+    # EXPECTED_TYPE: int
+    reveal_type(row[0])
+    # EXPECTED_TYPE: str
+    reveal_type(row[1])
+    # EXPECTED_TYPE: bool
+    reveal_type(row[2])
+
+    # EXPECTED_MYPY: Tuple index out of range
+    row[3]
+    # EXPECTED_MYPY: No overload variant of "__getitem__" of "tuple" matches argument type "str"  # noqa: E501
+    row["a"]
+
+    # EXPECTED_TYPE: RowMapping
+    reveal_type(row._mapping)
+    rm = row._mapping
+    # EXPECTED_TYPE: Any
+    reveal_type(rm["foo"])
+    # EXPECTED_TYPE: Any
+    reveal_type(rm[column("bar")])
+
+    # EXPECTED_MYPY: Invalid index type "int" for "RowMapping"; expected type "str | SQLCoreOperations[Any]"  # noqa: E501
+    rm[3]
+
+
+def result_one(res: Result[int, str]) -> None:
+    # EXPECTED_ROW_TYPE: Row[int, str]
+    reveal_type(res.one())
+    # EXPECTED_ROW_TYPE: Optional[Row[int, str]]
+    reveal_type(res.one_or_none())
+    # EXPECTED_ROW_TYPE: Optional[Row[int, str]]
+    reveal_type(res.fetchone())
+    # EXPECTED_ROW_TYPE: Optional[Row[int, str]]
+    reveal_type(res.first())
+    # EXPECTED_ROW_TYPE: Sequence[Row[int, str]]
+    reveal_type(res.all())
+    # EXPECTED_ROW_TYPE: Sequence[Row[int, str]]
+    reveal_type(res.fetchmany())
+    # EXPECTED_ROW_TYPE: Sequence[Row[int, str]]
+    reveal_type(res.fetchall())
+    # EXPECTED_ROW_TYPE: Row[int, str]
+    reveal_type(next(res))
+    for rf in res:
+        # EXPECTED_ROW_TYPE: Row[int, str]
+        reveal_type(rf)
+    for rp in res.partitions():
+        # EXPECTED_ROW_TYPE: Sequence[Row[int, str]]
+        reveal_type(rp)
+
+    # EXPECTED_TYPE: ScalarResult[int]
+    res_s = reveal_type(res.scalars())
+    # EXPECTED_TYPE: ScalarResult[int]
+    res_s = reveal_type(res.scalars(0))
+    # EXPECTED_TYPE: int
+    reveal_type(res_s.one())
+    # EXPECTED_TYPE: ScalarResult[Any]
+    reveal_type(res.scalars(1))
+    # EXPECTED_TYPE: MappingResult
+    reveal_type(res.mappings())
+    # EXPECTED_TYPE: FrozenResult[int, str]
+    reveal_type(res.freeze())
+
+    # EXPECTED_TYPE: int
+    reveal_type(res.scalar_one())
+    # EXPECTED_TYPE: Union[int, None]
+    reveal_type(res.scalar_one_or_none())
+    # EXPECTED_TYPE: Union[int, None]
+    reveal_type(res.scalar())