]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed typing issues with sync code runners
authorFrancisco R. Del Roio <francipvb@hotmail.com>
Sun, 25 Feb 2024 19:37:27 +0000 (14:37 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Mar 2024 15:48:00 +0000 (11:48 -0400)
Fixed typing issue allowing asyncio ``run_sync()`` methods to correctly
type the parameters according to the callable that was passed, making use
of :pep:`612` ``ParamSpec`` variables.  Pull request courtesy Francisco R.
Del Roio.

Closes: #11055
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11055
Pull-request-sha: 712b4382b16e4c07c09ac40a570c4bfb76c28161

Change-Id: I94ec8bbb0688d6c6e1610f8f769abab550179c14
(cherry picked from commit b687624f63b8613f3c487866292fa88f763c79ee)

doc/build/changelog/unreleased_20/11055.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
test/typing/plain_files/ext/asyncio/async_sessionmaker.py
test/typing/plain_files/ext/asyncio/engines.py

diff --git a/doc/build/changelog/unreleased_20/11055.rst b/doc/build/changelog/unreleased_20/11055.rst
new file mode 100644 (file)
index 0000000..8784d7a
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 11055
+
+    Fixed typing issue allowing asyncio ``run_sync()`` methods to correctly
+    type the parameters according to the callable that was passed, making use
+    of :pep:`612` ``ParamSpec`` variables.  Pull request courtesy Francisco R.
+    Del Roio.
index dc6f89d6b5923ed4ffdf42c4e12661c62a54d097..5d7d7e6b4253fb3d4b7c8fca3a3fea4be1422ba4 100644 (file)
@@ -41,6 +41,8 @@ from ...engine.base import NestedTransaction
 from ...engine.base import Transaction
 from ...exc import ArgumentError
 from ...util.concurrency import greenlet_spawn
+from ...util.typing import Concatenate
+from ...util.typing import ParamSpec
 
 if TYPE_CHECKING:
     from ...engine.cursor import CursorResult
@@ -61,6 +63,7 @@ if TYPE_CHECKING:
     from ...sql.base import Executable
     from ...sql.selectable import TypedReturnsRows
 
+_P = ParamSpec("_P")
 _T = TypeVar("_T", bound=Any)
 
 
@@ -813,7 +816,10 @@ class AsyncConnection(
             yield result.scalars()
 
     async def run_sync(
-        self, fn: Callable[..., _T], *arg: Any, **kw: Any
+        self,
+        fn: Callable[Concatenate[Connection, _P], _T],
+        *arg: _P.args,
+        **kw: _P.kwargs,
     ) -> _T:
         """Invoke the given synchronous (i.e. not async) callable,
         passing a synchronous-style :class:`_engine.Connection` as the first
@@ -877,7 +883,9 @@ class AsyncConnection(
 
         """  # noqa: E501
 
-        return await greenlet_spawn(fn, self._proxied, *arg, **kw)
+        return await greenlet_spawn(
+            fn, self._proxied, *arg, _require_await=False, **kw
+        )
 
     def __await__(self) -> Generator[Any, None, AsyncConnection]:
         return self.start().__await__()
index a9ea55e4966d2e2fde6e09c253106a8fe8073697..c5fe469a0d4f25ab05332dfc80fd8316b92a140e 100644 (file)
@@ -38,6 +38,9 @@ from ...orm import Session
 from ...orm import SessionTransaction
 from ...orm import state as _instance_state
 from ...util.concurrency import greenlet_spawn
+from ...util.typing import Concatenate
+from ...util.typing import ParamSpec
+
 
 if TYPE_CHECKING:
     from .engine import AsyncConnection
@@ -71,6 +74,7 @@ if TYPE_CHECKING:
 
 _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
 
+_P = ParamSpec("_P")
 _T = TypeVar("_T", bound=Any)
 
 
@@ -332,7 +336,10 @@ class AsyncSession(ReversibleProxy[Session]):
         )
 
     async def run_sync(
-        self, fn: Callable[..., _T], *arg: Any, **kw: Any
+        self,
+        fn: Callable[Concatenate[Session, _P], _T],
+        *arg: _P.args,
+        **kw: _P.kwargs,
     ) -> _T:
         """Invoke the given synchronous (i.e. not async) callable,
         passing a synchronous-style :class:`_orm.Session` as the first
@@ -386,7 +393,9 @@ class AsyncSession(ReversibleProxy[Session]):
             :ref:`session_run_sync`
         """  # noqa: E501
 
-        return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+        return await greenlet_spawn(
+            fn, self.sync_session, *arg, _require_await=False, **kw
+        )
 
     @overload
     async def execute(
index d9997141a101160ec78479be696d387e0eecadad..b081aa1b130d13d2537c2d91e38a2fd179931ada 100644 (file)
@@ -52,6 +52,10 @@ def work_with_a_session_two(sess: Session, param: Optional[str] = None) -> Any:
     pass
 
 
+def work_with_wrong_parameter(session: Session, foo: int) -> Any:
+    pass
+
+
 async def async_main() -> None:
     """Main program function."""
 
@@ -71,6 +75,9 @@ async def async_main() -> None:
         await session.run_sync(work_with_a_session_one)
         await session.run_sync(work_with_a_session_two, param="foo")
 
+        # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncSession"
+        await session.run_sync(work_with_wrong_parameter)
+
         session.add_all(
             [
                 A(bs=[B(), B()], data="a1"),
index 598d319a7765ef17db5d0360369df1b8717739cf..01475dc71e594fcfc486ce518d6c8d3a1ecdfddd 100644 (file)
@@ -1,7 +1,14 @@
+from typing import Any
+
+from sqlalchemy import Connection
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
 
 
+def work_sync(conn: Connection, foo: int) -> Any:
+    pass
+
+
 async def asyncio() -> None:
     e = create_async_engine("sqlite://")
 
@@ -53,3 +60,8 @@ async def asyncio() -> None:
 
         # EXPECTED_TYPE: CursorResult[Any]
         reveal_type(result)
+
+        await conn.run_sync(work_sync, 1)
+
+        # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncConnection"
+        await conn.run_sync(work_sync)