]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add correct typing for row getitem
authorFederico Caselli <cfederico87@gmail.com>
Mon, 28 Apr 2025 21:44:50 +0000 (23:44 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 29 Apr 2025 21:29:39 +0000 (23:29 +0200)
The overloads were broken in 8a4c27589500bc57605bb8f28c215f5f0ae5066d

Change-Id: I3736b15e95ead28537e25169a54521e991f763da
(cherry picked from commit 4ac02007e030232f57226aafbb9313c8ff186a62)

lib/sqlalchemy/engine/result.py
lib/sqlalchemy/testing/fixtures/mypy.py
test/typing/plain_files/engine/engine_result.py [new file with mode: 0644]

index 5d597fd5f49af66b16d64eb60fb594d5b6d4d622..b84fb3d1cb5e45ffe789b603a6caf5292a3a1bfc 100644 (file)
@@ -722,6 +722,14 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
 
         return manyrows
 
+    @overload
+    def _only_one_row(
+        self: ResultInternal[Row[Any]],
+        raise_for_second_row: bool,
+        raise_for_none: bool,
+        scalar: Literal[True],
+    ) -> Any: ...
+
     @overload
     def _only_one_row(
         self,
index 0832d89246f4c0954f7f396f64d2c593ee4a191b..849df4dc30a8511161faf25d462a3fc2319c5125 100644 (file)
@@ -143,7 +143,9 @@ class MypyTest(TestBase):
         from sqlalchemy.ext.mypy.util import mypy_14
 
         expected_messages = []
-        expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
+        expected_re = re.compile(
+            r"\s*# EXPECTED(_MYPY)?(_RE)?(_ROW)?(_TYPE)?: (.+)"
+        )
         py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
         with open(path) as file_:
             current_assert_messages = []
@@ -161,9 +163,24 @@ class MypyTest(TestBase):
                 if m:
                     is_mypy = bool(m.group(1))
                     is_re = bool(m.group(2))
-                    is_type = bool(m.group(3))
+                    is_row = bool(m.group(3))
+                    is_type = bool(m.group(4))
+
+                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(5))
+                    if is_row:
+                        expected_msg = re.sub(
+                            r"Row\[([^\]]+)\]",
+                            lambda m: f"tuple[{m.group(1)}, fallback=s"
+                            f"qlalchemy.engine.row.{m.group(0)}]",
+                            expected_msg,
+                        )
+                        # For some reason it does not use or syntax (|)
+                        expected_msg = re.sub(
+                            r"Optional\[(.*)\]",
+                            lambda m: f"Union[{m.group(1)}, None]",
+                            expected_msg,
+                        )
 
-                    expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
                     if is_type:
                         if not is_re:
                             # the goal here is that we can cut-and-paste
@@ -243,7 +260,9 @@ class MypyTest(TestBase):
 
         return expected_messages
 
-    def _check_output(self, path, expected_messages, stdout, stderr, exitcode):
+    def _check_output(
+        self, path, expected_messages, stdout: str, stderr, exitcode
+    ):
         not_located = []
         filename = os.path.basename(path)
         if expected_messages:
@@ -263,7 +282,8 @@ class MypyTest(TestBase):
                 ):
                     while raw_lines:
                         ol = raw_lines.pop(0)
-                        if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
+                        if not re.match(r".+\.py:\d+: note: +def .*", ol):
+                            raw_lines.insert(0, ol)
                             break
                 elif re.match(
                     r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
diff --git a/test/typing/plain_files/engine/engine_result.py b/test/typing/plain_files/engine/engine_result.py
new file mode 100644 (file)
index 0000000..7ff20b7
--- /dev/null
@@ -0,0 +1,94 @@
+from typing import reveal_type
+from typing import Tuple
+
+from sqlalchemy import column
+from sqlalchemy.engine import Result
+from sqlalchemy.engine import Row
+
+
+def row_one(row: Row[Tuple[int, str, bool]]) -> None:
+    # EXPECTED_TYPE: Any
+    reveal_type(row[0])
+    # EXPECTED_TYPE: Any
+    reveal_type(row[1])
+    # EXPECTED_TYPE: Any
+    reveal_type(row[2])
+
+    # EXPECTED_MYPY: No overload variant of "__getitem__" of "Row" matches argument type "str"  # noqa: E501
+    row["a"]
+
+    # EXPECTED_TYPE: RowMapping
+    reveal_type(row._mapping)
+    rm = row._mapping
+    # EXPECTED_TYPE: Any
+    reveal_type(rm["foo"])
+    # EXPECTED_TYPE: Any
+    reveal_type(rm[column("bar")])
+
+    # EXPECTED_MYPY: Invalid index type "int" for "RowMapping"; expected type "str | SQLCoreOperations[Any]"  # noqa: E501
+    rm[3]
+
+
+def result_one(
+    res: Result[Tuple[int, str]], r_single: Result[Tuple[float]]
+) -> None:
+    # EXPECTED_TYPE: Row[Tuple[int, str]]
+    reveal_type(res.one())
+    # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None]
+    reveal_type(res.one_or_none())
+    # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None]
+    reveal_type(res.fetchone())
+    # EXPECTED_TYPE: Union[Row[Tuple[int, str]], None]
+    reveal_type(res.first())
+    # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]]
+    reveal_type(res.all())
+    # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]]
+    reveal_type(res.fetchmany())
+    # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]]
+    reveal_type(res.fetchall())
+    # EXPECTED_TYPE: Row[Tuple[int, str]]
+    reveal_type(next(res))
+    for rf in res:
+        # EXPECTED_TYPE: Row[Tuple[int, str]]
+        reveal_type(rf)
+    for rp in res.partitions():
+        # EXPECTED_TYPE: Sequence[Row[Tuple[int, str]]]
+        reveal_type(rp)
+
+    # EXPECTED_TYPE: ScalarResult[Any]
+    res_s = reveal_type(res.scalars())
+    # EXPECTED_TYPE: ScalarResult[Any]
+    res_s = reveal_type(res.scalars(0))
+    # EXPECTED_TYPE: Any
+    reveal_type(res_s.one())
+    # EXPECTED_TYPE: ScalarResult[Any]
+    reveal_type(res.scalars(1))
+    # EXPECTED_TYPE: MappingResult
+    reveal_type(res.mappings())
+    # EXPECTED_TYPE: FrozenResult[Tuple[int, str]]
+    reveal_type(res.freeze())
+
+    # EXPECTED_TYPE: Any
+    reveal_type(res.scalar_one())
+    # EXPECTED_TYPE: Union[Any, None]
+    reveal_type(res.scalar_one_or_none())
+    # EXPECTED_TYPE: Any
+    reveal_type(res.scalar())
+
+    # EXPECTED_TYPE: ScalarResult[float]
+    res_s2 = reveal_type(r_single.scalars())
+    # EXPECTED_TYPE: ScalarResult[float]
+    res_s2 = reveal_type(r_single.scalars(0))
+    # EXPECTED_TYPE: float
+    reveal_type(res_s2.one())
+    # EXPECTED_TYPE: ScalarResult[Any]
+    reveal_type(r_single.scalars(1))
+    # EXPECTED_TYPE: MappingResult
+    reveal_type(r_single.mappings())
+
+    # EXPECTED_TYPE: float
+    reveal_type(r_single.scalar_one())
+    # EXPECTED_TYPE: Union[float, None]
+    reveal_type(r_single.scalar_one_or_none())
+    # EXPECTED_TYPE: Union[float, None]
+    reveal_type(r_single.scalar())