]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
🐛 Fix `AsyncSession` type annotations for `exec()` (#58)
authorArseny Boykov <36469655+Bobronium@users.noreply.github.com>
Mon, 23 Oct 2023 14:58:16 +0000 (17:58 +0300)
committerGitHub <noreply@github.com>
Mon, 23 Oct 2023 14:58:16 +0000 (18:58 +0400)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
sqlmodel/ext/asyncio/session.py
sqlmodel/orm/session.py

index 80267b25e5243fa4e0709981f35671bc4ffc1a0a..f500c44dc2192701fb6a8b37f5ed9f195d4b1989 100644 (file)
@@ -1,17 +1,17 @@
-from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
+from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
 
 from sqlalchemy import util
 from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
 from sqlalchemy.ext.asyncio import engine
 from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
 from sqlalchemy.util.concurrency import greenlet_spawn
-from sqlmodel.sql.base import Executable
 
-from ...engine.result import ScalarResult
+from ...engine.result import Result, ScalarResult
 from ...orm.session import Session
-from ...sql.expression import Select
+from ...sql.base import Executable
+from ...sql.expression import Select, SelectOfScalar
 
-_T = TypeVar("_T")
+_TSelectParam = TypeVar("_TSelectParam")
 
 
 class AsyncSession(_AsyncSession):
@@ -40,14 +40,46 @@ class AsyncSession(_AsyncSession):
             Session(bind=bind, binds=binds, **kw)  # type: ignore
         )
 
+    @overload
     async def exec(
         self,
-        statement: Union[Select[_T], Executable[_T]],
+        statement: Select[_TSelectParam],
+        *,
+        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping[str, Any]] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+        **kw: Any,
+    ) -> Result[_TSelectParam]:
+        ...
+
+    @overload
+    async def exec(
+        self,
+        statement: SelectOfScalar[_TSelectParam],
+        *,
+        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
+        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping[str, Any]] = None,
+        _parent_execute_state: Optional[Any] = None,
+        _add_event: Optional[Any] = None,
+        **kw: Any,
+    ) -> ScalarResult[_TSelectParam]:
+        ...
+
+    async def exec(
+        self,
+        statement: Union[
+            Select[_TSelectParam],
+            SelectOfScalar[_TSelectParam],
+            Executable[_TSelectParam],
+        ],
         params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
         execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
         bind_arguments: Optional[Mapping[str, Any]] = None,
         **kw: Any,
-    ) -> ScalarResult[_T]:
+    ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
         # TODO: the documentation says execution_options accepts a dict, but only
         # util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
         execution_options = execution_options.union({"prebuffer_rows": True})  # type: ignore
index 1692fdcbcbca1003b7df52b20a44475714be7616..0c70c290ae9e5b3317bd2ddecf2665bcb83531a0 100644 (file)
@@ -4,11 +4,11 @@ from sqlalchemy import util
 from sqlalchemy.orm import Query as _Query
 from sqlalchemy.orm import Session as _Session
 from sqlalchemy.sql.base import Executable as _Executable
-from sqlmodel.sql.expression import Select, SelectOfScalar
 from typing_extensions import Literal
 
 from ..engine.result import Result, ScalarResult
 from ..sql.base import Executable
+from ..sql.expression import Select, SelectOfScalar
 
 _TSelectParam = TypeVar("_TSelectParam")