]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Result.__enter__ annotation
authorMartin Baláž <embeembe@gmail.com>
Sun, 22 Jan 2023 16:16:56 +0000 (11:16 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jan 2023 16:00:02 +0000 (11:00 -0500)
Fixed typing issue where the object type when using :class:`_engine.Result`
as a context manager were not preserved, indicating :class:`_engine.Result`
in all cases rather than the specific :class:`_engine.Result` sub-type.
Pull request courtesy Martin Baláž.

Fixes: #9136
Closes: #9135
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9135
Pull-request-sha: 97a9829db59db359fbb400ec0d913bdf8954f00a

Change-Id: I60a7f89ba39bf0f9fc5e6e7bf09f642167fe476f

doc/build/changelog/unreleased_20/more_typing.rst
lib/sqlalchemy/engine/result.py
test/ext/mypy/plain_files/typed_results.py

index b958d0d91f055939150264572e0b2e2e26341213..62cd04e8b4d506377105955fc8681046a44fad4f 100644 (file)
     :tickets: 9125
 
     Fixed typing issue where iterating over a :class:`_orm.Query` object
-    was not correctly typed. 
+    was not correctly typed.
+
+.. change::
+    :tags: typing, bug
+    :tickets: 9136
+
+    Fixed typing issue where the object type when using :class:`_engine.Result`
+    as a context manager were not preserved, indicating :class:`_engine.Result`
+    in all cases rather than the specific :class:`_engine.Result` sub-type.
+    Pull request courtesy Martin Baláž.
index 4bf03ae6963b79878acc7459d46909b269ff8ccd..67151913e7221b2f8d103e664177caae37050b8d 100644 (file)
@@ -929,7 +929,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
     def __init__(self, cursor_metadata: ResultMetaData):
         self._metadata = cursor_metadata
 
-    def __enter__(self) -> Result[_TP]:
+    def __enter__(self: SelfResult) -> SelfResult:
         return self
 
     def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
index 262e5b5ff2cc04229371703a94e91145beb8ad76..8fd9e5cd13bc883366e3a0bbbdbe19ebfc211e2b 100644 (file)
@@ -77,6 +77,28 @@ multi_stmt = select(User.id, User.name).where(User.name == "foo")
 reveal_type(multi_stmt)
 
 
+def t_result_ctxmanager() -> None:
+    with connection.execute(select(column("q", Integer))) as r1:
+        # EXPECTED_TYPE: CursorResult[Tuple[int]]
+        reveal_type(r1)
+
+        with r1.mappings() as r1m:
+            # EXPECTED_TYPE: MappingResult
+            reveal_type(r1m)
+
+    with connection.scalars(select(column("q", Integer))) as r2:
+        # EXPECTED_TYPE: ScalarResult[int]
+        reveal_type(r2)
+
+    with session.execute(select(User.id)) as r3:
+        # EXPECTED_TYPE: Result[Tuple[int]]
+        reveal_type(r3)
+
+    with session.scalars(select(User.id)) as r4:
+        # EXPECTED_TYPE: ScalarResult[int]
+        reveal_type(r4)
+
+
 def t_entity_varieties() -> None:
 
     a1 = aliased(User)