]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add contextmanager typing, open run_sync typing
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Jul 2022 12:48:55 +0000 (08:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Jul 2022 15:03:31 +0000 (11:03 -0400)
was missing AsyncConnection type for the async
context manager.

fixing that revealed that _SyncConnectionCallable
and _SyncSessionCallable protocols are infeasible because
the given callable can have a lot of different signatures
that are compatible.

Change-Id: I559aa3dd88a902d0e7681c52223bb4bc0890adc1

lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
test/ext/mypy/plain_files/async_sessionmaker.py

index 97d69fcbd29a1eaed4d734052c88af58ec4e609d..2418dab884f99ea55459ec5d0e9b14eae4711052 100644 (file)
@@ -7,6 +7,7 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import Generator
 from typing import NoReturn
@@ -33,7 +34,6 @@ from ...engine import Engine
 from ...engine.base import NestedTransaction
 from ...engine.base import Transaction
 from ...util.concurrency import greenlet_spawn
-from ...util.typing import Protocol
 
 if TYPE_CHECKING:
     from ...engine.cursor import CursorResult
@@ -55,11 +55,6 @@ if TYPE_CHECKING:
 _T = TypeVar("_T", bound=Any)
 
 
-class _SyncConnectionCallable(Protocol):
-    def __call__(self, connection: Connection, *arg: Any, **kw: Any) -> Any:
-        ...
-
-
 def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
     """Create a new async engine instance.
 
@@ -667,7 +662,7 @@ class AsyncConnection(
         return result.scalars()
 
     async def run_sync(
-        self, fn: _SyncConnectionCallable, *arg: Any, **kw: Any
+        self, fn: Callable[..., Any], *arg: Any, **kw: Any
     ) -> Any:
         """Invoke the given sync callable passing self as the first argument.
 
@@ -845,6 +840,11 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
         def __init__(self, conn: AsyncConnection):
             self.conn = conn
 
+        if TYPE_CHECKING:
+
+            async def __aenter__(self) -> AsyncConnection:
+                ...
+
         async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
             await self.conn.start(is_ctxmanager=is_ctxmanager)
             self.transaction = self.conn.begin()
index be3414cef4aa39f13194a3bdf0cc8bd3f6170ee7..ea587f8de46bcb2cb095a2a97f319f176057a06a 100644 (file)
@@ -7,6 +7,7 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import Generic
 from typing import Iterable
@@ -33,7 +34,6 @@ from ...orm import Session
 from ...orm import SessionTransaction
 from ...orm import state as _instance_state
 from ...util.concurrency import greenlet_spawn
-from ...util.typing import Protocol
 
 if TYPE_CHECKING:
     from .engine import AsyncConnection
@@ -68,11 +68,6 @@ _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
 _T = TypeVar("_T", bound=Any)
 
 
-class _SyncSessionCallable(Protocol):
-    def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any:
-        ...
-
-
 _EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
 _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
 
@@ -234,7 +229,7 @@ class AsyncSession(ReversibleProxy[Session]):
         )
 
     async def run_sync(
-        self, fn: _SyncSessionCallable, *arg: Any, **kw: Any
+        self, fn: Callable[..., Any], *arg: Any, **kw: Any
     ) -> Any:
         """Invoke the given sync callable passing sync self as the first
         argument.
index 01a26d0354c71382b55e5280d8a66926c7841697..e28e9499b47c9a4f0c04e3ad066339aab702c5c6 100644 (file)
@@ -5,7 +5,9 @@ for asynchronous ORM use.
 from __future__ import annotations
 
 import asyncio
+from typing import Any
 from typing import List
+from typing import Optional
 from typing import TYPE_CHECKING
 
 from sqlalchemy import ForeignKey
@@ -16,6 +18,7 @@ from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 
 if TYPE_CHECKING:
     from sqlalchemy import ScalarResult
@@ -40,6 +43,14 @@ class B(Base):
     data: Mapped[str]
 
 
+def work_with_a_session_one(sess: Session) -> Any:
+    pass
+
+
+def work_with_a_session_two(sess: Session, param: Optional[str] = None) -> Any:
+    pass
+
+
 async def async_main() -> None:
     """Main program function."""
 
@@ -56,6 +67,9 @@ async def async_main() -> None:
     async_session = async_sessionmaker(engine, expire_on_commit=False)
 
     async with async_session.begin() as session:
+        await session.run_sync(work_with_a_session_one)
+        await session.run_sync(work_with_a_session_two, param="foo")
+
         session.add_all(
             [
                 A(bs=[B(), B()], data="a1"),