From: Francisco R. Del Roio Date: Sun, 25 Feb 2024 19:37:27 +0000 (-0500) Subject: Fixed typing issues with sync code runners X-Git-Tag: rel_2_0_29~35 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=07b8c4ba5655c5b95bd839732a291707654bf113;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixed typing issues with sync code runners Fixed typing issue allowing asyncio ``run_sync()`` methods to correctly type the parameters according to the callable that was passed, making use of :pep:`612` ``ParamSpec`` variables. Pull request courtesy Francisco R. Del Roio. Closes: #11055 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11055 Pull-request-sha: 712b4382b16e4c07c09ac40a570c4bfb76c28161 Change-Id: I94ec8bbb0688d6c6e1610f8f769abab550179c14 (cherry picked from commit b687624f63b8613f3c487866292fa88f763c79ee) --- diff --git a/doc/build/changelog/unreleased_20/11055.rst b/doc/build/changelog/unreleased_20/11055.rst new file mode 100644 index 0000000000..8784d7aec1 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11055.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, typing + :tickets: 11055 + + Fixed typing issue allowing asyncio ``run_sync()`` methods to correctly + type the parameters according to the callable that was passed, making use + of :pep:`612` ``ParamSpec`` variables. Pull request courtesy Francisco R. + Del Roio. diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index dc6f89d6b5..5d7d7e6b42 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -41,6 +41,8 @@ from ...engine.base import NestedTransaction from ...engine.base import Transaction from ...exc import ArgumentError from ...util.concurrency import greenlet_spawn +from ...util.typing import Concatenate +from ...util.typing import ParamSpec if TYPE_CHECKING: from ...engine.cursor import CursorResult @@ -61,6 +63,7 @@ if TYPE_CHECKING: from ...sql.base import Executable from ...sql.selectable import TypedReturnsRows +_P = ParamSpec("_P") _T = TypeVar("_T", bound=Any) @@ -813,7 +816,10 @@ class AsyncConnection( yield result.scalars() async def run_sync( - self, fn: Callable[..., _T], *arg: Any, **kw: Any + self, + fn: Callable[Concatenate[Connection, _P], _T], + *arg: _P.args, + **kw: _P.kwargs, ) -> _T: """Invoke the given synchronous (i.e. not async) callable, passing a synchronous-style :class:`_engine.Connection` as the first @@ -877,7 +883,9 @@ class AsyncConnection( """ # noqa: E501 - return await greenlet_spawn(fn, self._proxied, *arg, **kw) + return await greenlet_spawn( + fn, self._proxied, *arg, _require_await=False, **kw + ) def __await__(self) -> Generator[Any, None, AsyncConnection]: return self.start().__await__() diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index a9ea55e496..c5fe469a0d 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -38,6 +38,9 @@ from ...orm import Session from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +from ...util.typing import Concatenate +from ...util.typing import ParamSpec + if TYPE_CHECKING: from .engine import AsyncConnection @@ -71,6 +74,7 @@ if TYPE_CHECKING: _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] +_P = ParamSpec("_P") _T = TypeVar("_T", bound=Any) @@ -332,7 +336,10 @@ class AsyncSession(ReversibleProxy[Session]): ) async def run_sync( - self, fn: Callable[..., _T], *arg: Any, **kw: Any + self, + fn: Callable[Concatenate[Session, _P], _T], + *arg: _P.args, + **kw: _P.kwargs, ) -> _T: """Invoke the given synchronous (i.e. not async) callable, passing a synchronous-style :class:`_orm.Session` as the first @@ -386,7 +393,9 @@ class AsyncSession(ReversibleProxy[Session]): :ref:`session_run_sync` """ # noqa: E501 - return await greenlet_spawn(fn, self.sync_session, *arg, **kw) + return await greenlet_spawn( + fn, self.sync_session, *arg, _require_await=False, **kw + ) @overload async def execute( diff --git a/test/typing/plain_files/ext/asyncio/async_sessionmaker.py b/test/typing/plain_files/ext/asyncio/async_sessionmaker.py index d9997141a1..b081aa1b13 100644 --- a/test/typing/plain_files/ext/asyncio/async_sessionmaker.py +++ b/test/typing/plain_files/ext/asyncio/async_sessionmaker.py @@ -52,6 +52,10 @@ def work_with_a_session_two(sess: Session, param: Optional[str] = None) -> Any: pass +def work_with_wrong_parameter(session: Session, foo: int) -> Any: + pass + + async def async_main() -> None: """Main program function.""" @@ -71,6 +75,9 @@ async def async_main() -> None: await session.run_sync(work_with_a_session_one) await session.run_sync(work_with_a_session_two, param="foo") + # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncSession" + await session.run_sync(work_with_wrong_parameter) + session.add_all( [ A(bs=[B(), B()], data="a1"), diff --git a/test/typing/plain_files/ext/asyncio/engines.py b/test/typing/plain_files/ext/asyncio/engines.py index 598d319a77..01475dc71e 100644 --- a/test/typing/plain_files/ext/asyncio/engines.py +++ b/test/typing/plain_files/ext/asyncio/engines.py @@ -1,7 +1,14 @@ +from typing import Any + +from sqlalchemy import Connection from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine +def work_sync(conn: Connection, foo: int) -> Any: + pass + + async def asyncio() -> None: e = create_async_engine("sqlite://") @@ -53,3 +60,8 @@ async def asyncio() -> None: # EXPECTED_TYPE: CursorResult[Any] reveal_type(result) + + await conn.run_sync(work_sync, 1) + + # EXPECTED_MYPY: Missing positional argument "foo" in call to "run_sync" of "AsyncConnection" + await conn.run_sync(work_sync)