From: Mike Bayer Date: Mon, 18 Jul 2022 12:48:55 +0000 (-0400) Subject: add contextmanager typing, open run_sync typing X-Git-Tag: rel_2_0_0b1~172^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=10204576215fad27640739c295b9208a0bb3c6ce;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add contextmanager typing, open run_sync typing was missing AsyncConnection type for the async context manager. fixing that revealed that _SyncConnectionCallable and _SyncSessionCallable protocols are infeasible because the given callable can have a lot of different signatures that are compatible. Change-Id: I559aa3dd88a902d0e7681c52223bb4bc0890adc1 --- diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 97d69fcbd2..2418dab884 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -7,6 +7,7 @@ from __future__ import annotations from typing import Any +from typing import Callable from typing import Dict from typing import Generator from typing import NoReturn @@ -33,7 +34,6 @@ from ...engine import Engine from ...engine.base import NestedTransaction from ...engine.base import Transaction from ...util.concurrency import greenlet_spawn -from ...util.typing import Protocol if TYPE_CHECKING: from ...engine.cursor import CursorResult @@ -55,11 +55,6 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) -class _SyncConnectionCallable(Protocol): - def __call__(self, connection: Connection, *arg: Any, **kw: Any) -> Any: - ... - - def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: """Create a new async engine instance. @@ -667,7 +662,7 @@ class AsyncConnection( return result.scalars() async def run_sync( - self, fn: _SyncConnectionCallable, *arg: Any, **kw: Any + self, fn: Callable[..., Any], *arg: Any, **kw: Any ) -> Any: """Invoke the given sync callable passing self as the first argument. @@ -845,6 +840,11 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): def __init__(self, conn: AsyncConnection): self.conn = conn + if TYPE_CHECKING: + + async def __aenter__(self) -> AsyncConnection: + ... + async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: await self.conn.start(is_ctxmanager=is_ctxmanager) self.transaction = self.conn.begin() diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index be3414cef4..ea587f8de4 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -7,6 +7,7 @@ from __future__ import annotations from typing import Any +from typing import Callable from typing import Dict from typing import Generic from typing import Iterable @@ -33,7 +34,6 @@ from ...orm import Session from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn -from ...util.typing import Protocol if TYPE_CHECKING: from .engine import AsyncConnection @@ -68,11 +68,6 @@ _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] _T = TypeVar("_T", bound=Any) -class _SyncSessionCallable(Protocol): - def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any: - ... - - _EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) @@ -234,7 +229,7 @@ class AsyncSession(ReversibleProxy[Session]): ) async def run_sync( - self, fn: _SyncSessionCallable, *arg: Any, **kw: Any + self, fn: Callable[..., Any], *arg: Any, **kw: Any ) -> Any: """Invoke the given sync callable passing sync self as the first argument. diff --git a/test/ext/mypy/plain_files/async_sessionmaker.py b/test/ext/mypy/plain_files/async_sessionmaker.py index 01a26d0354..e28e9499b4 100644 --- a/test/ext/mypy/plain_files/async_sessionmaker.py +++ b/test/ext/mypy/plain_files/async_sessionmaker.py @@ -5,7 +5,9 @@ for asynchronous ORM use. from __future__ import annotations import asyncio +from typing import Any from typing import List +from typing import Optional from typing import TYPE_CHECKING from sqlalchemy import ForeignKey @@ -16,6 +18,7 @@ from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session if TYPE_CHECKING: from sqlalchemy import ScalarResult @@ -40,6 +43,14 @@ class B(Base): data: Mapped[str] +def work_with_a_session_one(sess: Session) -> Any: + pass + + +def work_with_a_session_two(sess: Session, param: Optional[str] = None) -> Any: + pass + + async def async_main() -> None: """Main program function.""" @@ -56,6 +67,9 @@ async def async_main() -> None: async_session = async_sessionmaker(engine, expire_on_commit=False) async with async_session.begin() as session: + await session.run_sync(work_with_a_session_one) + await session.run_sync(work_with_a_session_two, param="foo") + session.add_all( [ A(bs=[B(), B()], data="a1"),