]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed typing issues with sync code runners
authorFrancisco R. Del Roio <francipvb@hotmail.com>
Sun, 25 Feb 2024 19:37:27 +0000 (14:37 -0500)
committerMichael Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Mar 2024 17:06:41 +0000 (17:06 +0000)
Fixed typing issue allowing asyncio ``run_sync()`` methods to correctly
type the parameters according to the callable that was passed, making use
of :pep:`612` ``ParamSpec`` variables.  Pull request courtesy Francisco R.
Del Roio.

Closes: #11055
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11055
Pull-request-sha: 712b4382b16e4c07c09ac40a570c4bfb76c28161

Change-Id: I94ec8bbb0688d6c6e1610f8f769abab550179c14

doc/build/changelog/unreleased_20/11055.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
test/typing/plain_files/ext/asyncio/async_sessionmaker.py
test/typing/plain_files/ext/asyncio/engines.py

diff --git a/doc/build/changelog/unreleased_20/11055.rst b/doc/build/changelog/unreleased_20/11055.rst
new file mode 100644 (file)
index 0000000..8784d7a
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 11055
+
+    Fixed typing issue allowing asyncio ``run_sync()`` methods to correctly
+    type the parameters according to the callable that was passed, making use
+    of :pep:`612` ``ParamSpec`` variables.  Pull request courtesy Francisco R.
+    Del Roio.
index ae04833ad60cb2e4ae701fd20b56518788897807..2be452747edf97630c4bba93544cc5e5880451a3 100644 (file)
@@ -40,6 +40,8 @@ from ...engine.base import NestedTransaction
 from ...engine.base import Transaction
 from ...exc import ArgumentError
 from ...util.concurrency import greenlet_spawn
+from ...util.typing import Concatenate
+from ...util.typing import ParamSpec
 from ...util.typing import TupleAny
 from ...util.typing import TypeVarTuple
 from ...util.typing import Unpack
@@ -63,6 +65,7 @@ if TYPE_CHECKING:
     from ...sql.base import Executable
     from ...sql.selectable import TypedReturnsRows
 
+_P = ParamSpec("_P")
 _T = TypeVar("_T", bound=Any)
 _Ts = TypeVarTuple("_Ts")
 
@@ -816,7 +819,10 @@ class AsyncConnection(
             yield result.scalars()
 
     async def run_sync(
-        self, fn: Callable[..., _T], *arg: Any, **kw: Any
+        self,
+        fn: Callable[Concatenate[Connection, _P], _T],
+        *arg: _P.args,
+        **kw: _P.kwargs,
     ) -> _T:
         """Invoke the given synchronous (i.e. not async) callable,
         passing a synchronous-style :class:`_engine.Connection` as the first
@@ -880,7 +886,9 @@ class AsyncConnection(
 
         """  # noqa: E501
 
-        return await greenlet_spawn(fn, self._proxied, *arg, **kw)
+        return await greenlet_spawn(
+            fn, self._proxied, *arg, _require_await=False, **kw
+        )
 
     def __await__(self) -> Generator[Any, None, AsyncConnection]:
         return self.start().__await__()
index f8c823cff0637a89dc774fe976c45a03971b6085..87f1a8c977137fcd7511b31bccbce0b4737170f0 100644 (file)
@@ -38,6 +38,8 @@ from ...orm import Session
 from ...orm import SessionTransaction
 from ...orm import state as _instance_state
 from ...util.concurrency import greenlet_spawn
+from ...util.typing import Concatenate
+from ...util.typing import ParamSpec
 from ...util.typing import TupleAny
 from ...util.typing import TypeVarTuple
 from ...util.typing import Unpack
@@ -75,6 +77,7 @@ if TYPE_CHECKING:
 
 _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
 
+_P = ParamSpec("_P")
 _T = TypeVar("_T", bound=Any)
 _Ts = TypeVarTuple("_Ts")
 
@@ -336,7 +339,10 @@ class AsyncSession(ReversibleProxy[Session]):
         )
 
     async def run_sync(
-        self, fn: Callable[..., _T], *arg: Any, **kw: Any
+        self,
+        fn: Callable[Concatenate[Session, _P], _T],
+        *arg: _P.args,
+        **kw: _P.kwargs,
     ) -> _T:
         """Invoke the given synchronous (i.e. not async) callable,
         passing a synchronous-style :class:`_orm.Session` as the first
@@ -390,7 +396,9 @@ class AsyncSession(ReversibleProxy[Session]):
             :ref:`session_run_sync`
         """  # noqa: E501
 
-        return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+        return await greenlet_spawn(
+            fn, self.sync_session, *arg, _require_await=False, **kw
+        )
 
     @overload
     async def execute(
index d9997141a101160ec78479be696d387e0eecadad..b081aa1b130d13d2537c2d91e38a2fd179931ada 100644 (file)
@@ -52,6 +52,10 @@ def work_with_a_session_two(sess: Session, param: Optional[str] = None) -> Any:
     pass
 
 
+def work_with_wrong_parameter(session: Session, foo: int) -> Any:
+    pass
+
+
 async def async_main() -> None:
     """Main program function."""
 
@@ -71,6 +75,9 @@ async def async_main() -> None:
         await session.run_sync(work_with_a_session_one)
         await session.run_sync(work_with_a_session_two, param="foo")
 
+        # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncSession"
+        await session.run_sync(work_with_wrong_parameter)
+
         session.add_all(
             [
                 A(bs=[B(), B()], data="a1"),
index ae7880f58490cc0427aa3744970709a86d1d60e3..1b13ff1e9524ad1216b3a8ab37a5d3968bca8855 100644 (file)
@@ -1,7 +1,14 @@
+from typing import Any
+
+from sqlalchemy import Connection
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
 
 
+def work_sync(conn: Connection, foo: int) -> Any:
+    pass
+
+
 async def asyncio() -> None:
     e = create_async_engine("sqlite://")
 
@@ -53,3 +60,8 @@ async def asyncio() -> None:
 
         # EXPECTED_TYPE: CursorResult[Unpack[.*tuple[Any, ...]]]
         reveal_type(result)
+
+        await conn.run_sync(work_sync, 1)
+
+        # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncConnection"
+        await conn.run_sync(work_sync)