pep-484: asyncio
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Sun, 10 Apr 2022 19:42:35 +0000 (15:42 -0400)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 12 Apr 2022 02:11:07 +0000 (22:11 -0400)
In this patch the asyncio/events.py module, which existed
only to raise errors when trying to attach event listeners,
is removed.  We were already coding an asyncio-specific
workaround in the upstream Pool / Session to raise this error,
so the error is now moved out to the target itself, and the
same approach is applied to Engine.
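
A minimal sketch of the resulting listener pattern (not part of the
patch itself; the database URL is hypothetical): listeners are
attached to the sync-style endpoints, while listening on the async
proxy now raises NotImplementedError from the target class directly::

    from sqlalchemy import event
    from sqlalchemy.ext.asyncio import create_async_engine

    # hypothetical async database URL, for illustration only
    engine = create_async_engine("postgresql+asyncpg://user:pass@host/db")

    # listeners go on the proxied sync Engine; passing `engine` itself
    # would now raise NotImplementedError via _no_async_engine_events()
    @event.listens_for(engine.sync_engine, "connect")
    def on_connect(dbapi_connection, connection_record):
        # fires whenever a new DBAPI connection is created by the pool
        print("new DBAPI connection created")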

We also add an async_sessionmaker class.  The initial rationale
is that sessionmaker() is hardcoded to Session subclasses, and
there's no way to get the sessionmaker(class_=AsyncSession)
use case to type correctly without changing the sessionmaker()
symbol itself into a function rather than a class, which gets
too complicated for what this is.  Additionally,
_SessionClassMethods has only three methods, one of which
(close_all()) is not usable with asyncio, and the others are
not generally used from the session class.
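
For illustration, a minimal sketch of the new factory as used in the
documentation updated below (database URL hypothetical); class_=AsyncSession
no longer needs to be passed, since async_sessionmaker produces
AsyncSession instances by default::

    from sqlalchemy.ext.asyncio import async_sessionmaker
    from sqlalchemy.ext.asyncio import create_async_engine

    # hypothetical async database URL, for illustration only
    engine = create_async_engine("postgresql+asyncpg://user:pass@host/db")

    # replaces sessionmaker(engine, class_=AsyncSession, ...) and types
    # correctly as a factory of AsyncSession objects
    async_session = async_sessionmaker(engine, expire_on_commit=False)

    async def main() -> None:
        async with async_session() as session:
            async with session.begin():
                ...  # ORM operations against `session`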

Change-Id: I064a5fa5d91cc8d5bbe9597437536e37b4e801fe

29 files changed:
doc/build/orm/extensions/asyncio.rst
examples/asyncio/async_orm.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/events.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/event/base.py
lib/sqlalchemy/ext/asyncio/__init__.py
lib/sqlalchemy/ext/asyncio/base.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/events.py [deleted file]
lib/sqlalchemy/ext/asyncio/result.py
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/pool/events.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/util/_concurrency_py3k.py
pyproject.toml
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py
test/ext/mypy/plain_files/async_sessionmaker.py [new file with mode: 0644]
test/ext/mypy/plain_files/session.py

diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst
index 9badcb41843028e0fdbd6ad4476adc514a254ffc..82ba7cabb2b22f6be9e90c67ce519b4407026111 100644
@@ -147,12 +147,12 @@ illustrates a complete example including mapper and session configuration::
     from sqlalchemy import Integer
     from sqlalchemy import String
     from sqlalchemy.ext.asyncio import AsyncSession
+    from sqlalchemy.ext.asyncio import async_sessionmaker
     from sqlalchemy.ext.asyncio import create_async_engine
     from sqlalchemy.future import select
     from sqlalchemy.orm import declarative_base
     from sqlalchemy.orm import relationship
     from sqlalchemy.orm import selectinload
-    from sqlalchemy.orm import sessionmaker
 
     Base = declarative_base()
 
@@ -190,9 +190,7 @@ illustrates a complete example including mapper and session configuration::
 
         # expire_on_commit=False will prevent attributes from being expired
         # after commit.
-        async_session = sessionmaker(
-            engine, expire_on_commit=False, class_=AsyncSession
-        )
+        async_session = async_sessionmaker(engine, expire_on_commit=False)
 
         async with async_session() as session:
             async with session.begin():
@@ -234,7 +232,7 @@ illustrates a complete example including mapper and session configuration::
     asyncio.run(async_main())
 
 In the example above, the :class:`_asyncio.AsyncSession` is instantiated using
-the optional :class:`_orm.sessionmaker` helper, and associated with an
+the optional :class:`_asyncio.async_sessionmaker` helper, and associated with an
 :class:`_asyncio.AsyncEngine` against particular database URL. It is
 then used in a Python asynchronous context manager (i.e. ``async with:``
 statement) so that it is automatically closed at the end of the block; this is
@@ -284,8 +282,8 @@ prevent this:
       async_session = AsyncSession(engine, expire_on_commit=False)
 
       # sessionmaker version
-      async_session = sessionmaker(
-          engine, expire_on_commit=False, class_=AsyncSession
+      async_session = async_sessionmaker(
+          engine, expire_on_commit=False
       )
 
       async with async_session() as session:
@@ -722,11 +720,11 @@ constructor::
 
     from asyncio import current_task
 
-    from sqlalchemy.orm import sessionmaker
+    from sqlalchemy.ext.asyncio import async_sessionmaker
     from sqlalchemy.ext.asyncio import async_scoped_session
     from sqlalchemy.ext.asyncio import AsyncSession
 
-    async_session_factory = sessionmaker(some_async_engine, class_=AsyncSession)
+    async_session_factory = async_sessionmaker(some_async_engine, expire_on_commit=False)
     AsyncScopedSession = async_scoped_session(async_session_factory, scopefunc=current_task)
 
     some_async_session = AsyncScopedSession()
@@ -833,6 +831,10 @@ ORM Session API Documentation
 
 .. autofunction:: async_session
 
+.. autoclass:: async_sessionmaker
+   :members:
+   :inherited-members:
+
 .. autoclass:: async_scoped_session
    :members:
    :inherited-members:
diff --git a/examples/asyncio/async_orm.py b/examples/asyncio/async_orm.py
index 174ebf30b5f4f11f795946b3ebb824f771f15898..4688911588a8dfc034812c26aabc83906548f198 100644
@@ -11,13 +11,12 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import String
-from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import async_sessionmaker
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.future import select
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
-from sqlalchemy.orm import sessionmaker
 
 Base = declarative_base()
 
@@ -58,9 +57,7 @@ async def async_main():
 
     # expire_on_commit=False will prevent attributes from being expired
     # after commit.
-    async_session = sessionmaker(
-        engine, expire_on_commit=False, class_=AsyncSession
-    )
+    async_session = async_sessionmaker(engine, expire_on_commit=False)
 
     async with async_session() as session:
         async with session.begin():
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 8bcc7e2587482da37f1f1724ea698b25b48f7b86..594a193446b16b4fc847f01e938d2f81f429848d 100644
@@ -42,7 +42,7 @@ from ..sql import util as sql_util
 _CompiledCacheType = MutableMapping[Any, "Compiled"]
 
 if typing.TYPE_CHECKING:
-    from . import Result
+    from . import CursorResult
     from . import ScalarResult
     from .interfaces import _AnyExecuteParams
     from .interfaces import _AnyMultiExecuteParams
@@ -472,7 +472,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         else:
             return self._dbapi_connection
 
-    def get_isolation_level(self) -> str:
+    def get_isolation_level(self) -> _IsolationLevel:
         """Return the current isolation level assigned to this
         :class:`_engine.Connection`.
 
@@ -1186,9 +1186,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         statement: Executable,
         parameters: Optional[_CoreAnyExecuteParams] = None,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
-    ) -> Result:
+    ) -> CursorResult:
         r"""Executes a SQL statement construct and returns a
-        :class:`_engine.Result`.
+        :class:`_engine.CursorResult`.
 
         :param statement: The statement to be executed.  This is always
          an object that is in both the :class:`_expression.ClauseElement` and
@@ -1235,7 +1235,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         func: FunctionElement[Any],
         distilled_parameters: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter,
-    ) -> Result:
+    ) -> CursorResult:
         """Execute a sql.FunctionElement object."""
 
         return self._execute_clauseelement(
@@ -1306,7 +1306,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         ddl: DDLElement,
         distilled_parameters: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter,
-    ) -> Result:
+    ) -> CursorResult:
         """Execute a schema.DDL object."""
 
         execution_options = ddl._execution_options.merge_with(
@@ -1403,7 +1403,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         elem: Executable,
         distilled_parameters: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter,
-    ) -> Result:
+    ) -> CursorResult:
         """Execute a sql.ClauseElement object."""
 
         execution_options = elem._execution_options.merge_with(
@@ -1476,7 +1476,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         compiled: Compiled,
         distilled_parameters: _CoreMultiExecuteParams,
         execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS,
-    ) -> Result:
+    ) -> CursorResult:
         """Execute a sql.Compiled object.
 
         TODO: why do we have this?   likely deprecate or remove
@@ -1526,7 +1526,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         statement: str,
         parameters: Optional[_DBAPIAnyExecuteParams] = None,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
-    ) -> Result:
+    ) -> CursorResult:
         r"""Executes a SQL statement construct and returns a
         :class:`_engine.CursorResult`.
 
@@ -1603,7 +1603,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         execution_options: _ExecuteOptions,
         *args: Any,
         **kw: Any,
-    ) -> Result:
+    ) -> CursorResult:
         """Create an :class:`.ExecutionContext` and execute, returning
         a :class:`_engine.CursorResult`."""
 
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
index 699faf48975516e70247e2a6e7e20ba3f4cdc940..ef10946a86c077c5d5733783c0b412ddd2f1a126 100644
@@ -16,6 +16,7 @@ from typing import Tuple
 from typing import Type
 from typing import Union
 
+from .base import Connection
 from .base import Engine
 from .interfaces import ConnectionEventsTarget
 from .interfaces import DBAPIConnection
@@ -123,9 +124,23 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
     _dispatch_target = ConnectionEventsTarget
 
     @classmethod
-    def _listen(  # type: ignore[override]
+    def _accept_with(
+        cls,
+        target: Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]],
+    ) -> Optional[Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]]]:
+        default_dispatch = super()._accept_with(target)
+        if default_dispatch is None and hasattr(
+            target, "_no_async_engine_events"
+        ):
+            target._no_async_engine_events()  # type: ignore
+
+        return default_dispatch
+
+    @classmethod
+    def _listen(
         cls,
         event_key: event._EventKey[ConnectionEventsTarget],
+        *,
         retval: bool = False,
         **kw: Any,
     ) -> None:
@@ -769,7 +784,9 @@ class DialectEvents(event.Events[Dialect]):
     def _listen(  # type: ignore
         cls,
         event_key: event._EventKey[Dialect],
+        *,
         retval: bool = False,
+        **kw: Any,
     ) -> None:
         target = event_key.dispatch_target
 
@@ -789,10 +806,8 @@ class DialectEvents(event.Events[Dialect]):
             return target.dialect
         elif isinstance(target, Dialect):
             return target
-        elif hasattr(target, "dispatch") and hasattr(
-            target.dispatch._events, "_no_async_engine_events"
-        ):
-            target.dispatch._events._no_async_engine_events()
+        elif hasattr(target, "_no_async_engine_events"):
+            target._no_async_engine_events()
         else:
             return None
 
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index aa75da61411a0636be262af7db6f3d6877bd1a9b..54fe21d747afed32d64a0636748b59d948fe5ce9 100644
@@ -46,7 +46,7 @@ from ..util.typing import TypedDict
 if TYPE_CHECKING:
     from .base import Connection
     from .base import Engine
-    from .result import Result
+    from .cursor import CursorResult
     from .url import URL
     from ..event import _ListenerFnType
     from ..event import dispatcher
@@ -2422,7 +2422,7 @@ class ExecutionContext:
     def _get_cache_stats(self) -> str:
         raise NotImplementedError()
 
-    def _setup_result_proxy(self) -> Result:
+    def _setup_result_proxy(self) -> CursorResult:
         raise NotImplementedError()
 
     def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int:
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 880bd8d4c2b9d8007afe5f761b8016ecf4830558..11998e7188dcc9c3d83645830b9e8fe8863d3bd3 100644
@@ -536,7 +536,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
         return interim_rows
 
     @HasMemoized_ro_memoized_attribute
-    def _onerow_getter(self) -> Callable[..., Union[_NoRow, _R]]:
+    def _onerow_getter(
+        self,
+    ) -> Callable[..., Union[Literal[_NoRow._NO_ROW], _R]]:
         make_row = self._row_getter
 
         post_creational_filter = self._post_creational_filter
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
index 8ed4c64bac4fc05caf8d646f1bc53b2f6dd29a45..c16f6870be2ba592294fbed01adba7b18e2d9c02 100644
@@ -256,6 +256,7 @@ class _HasEventsDispatch(Generic[_ET]):
     def _listen(
         cls,
         event_key: _EventKey[_ET],
+        *,
         propagate: bool = False,
         insert: bool = False,
         named: bool = False,
@@ -361,6 +362,7 @@ class Events(_HasEventsDispatch[_ET]):
     def _listen(
         cls,
         event_key: _EventKey[_ET],
+        *,
         propagate: bool = False,
         insert: bool = False,
         named: bool = False,
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
index 15b2cb015b7996b25a88cf76e1dc4f9efc5a9311..dfe89a154e5429c499f8fdea6f550b9ad886eeaa 100644
@@ -5,18 +5,17 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
-from .engine import async_engine_from_config
-from .engine import AsyncConnection
-from .engine import AsyncEngine
-from .engine import AsyncTransaction
-from .engine import create_async_engine
-from .events import AsyncConnectionEvents
-from .events import AsyncSessionEvents
-from .result import AsyncMappingResult
-from .result import AsyncResult
-from .result import AsyncScalarResult
-from .scoping import async_scoped_session
-from .session import async_object_session
-from .session import async_session
-from .session import AsyncSession
-from .session import AsyncSessionTransaction
+from .engine import async_engine_from_config as async_engine_from_config
+from .engine import AsyncConnection as AsyncConnection
+from .engine import AsyncEngine as AsyncEngine
+from .engine import AsyncTransaction as AsyncTransaction
+from .engine import create_async_engine as create_async_engine
+from .result import AsyncMappingResult as AsyncMappingResult
+from .result import AsyncResult as AsyncResult
+from .result import AsyncScalarResult as AsyncScalarResult
+from .scoping import async_scoped_session as async_scoped_session
+from .session import async_object_session as async_object_session
+from .session import async_session as async_session
+from .session import async_sessionmaker as async_sessionmaker
+from .session import AsyncSession as AsyncSession
+from .session import AsyncSessionTransaction as AsyncSessionTransaction
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
index 3f77f55007e6a6402552badfddb06803a23be14d..7fdd2d7e064314f1b3b5b55b2304ea28cc1c8d3a 100644
+# ext/asyncio/base.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import annotations
+
 import abc
 import functools
+from typing import Any
+from typing import ClassVar
+from typing import Dict
+from typing import Generic
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TypeVar
 import weakref
 
 from . import exc as async_exc
+from ... import util
+from ...util.typing import Literal
+
+_T = TypeVar("_T", bound=Any)
+
+
+_PT = TypeVar("_PT", bound=Any)
 
 
-class ReversibleProxy:
-    # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
-    _proxy_objects = {}
+SelfReversibleProxy = TypeVar(
+    "SelfReversibleProxy", bound="ReversibleProxy[Any]"
+)
+
+
+class ReversibleProxy(Generic[_PT]):
+    _proxy_objects: ClassVar[
+        Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
+    ] = {}
     __slots__ = ("__weakref__",)
 
-    def _assign_proxied(self, target):
+    @overload
+    def _assign_proxied(self, target: _PT) -> _PT:
+        ...
+
+    @overload
+    def _assign_proxied(self, target: None) -> None:
+        ...
+
+    def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
         if target is not None:
-            target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+            target_ref: weakref.ref[_PT] = weakref.ref(
+                target, ReversibleProxy._target_gced
+            )
             proxy_ref = weakref.ref(
                 self,
-                functools.partial(ReversibleProxy._target_gced, target_ref),
+                functools.partial(  # type: ignore
+                    ReversibleProxy._target_gced, target_ref
+                ),
             )
             ReversibleProxy._proxy_objects[target_ref] = proxy_ref
 
         return target
 
     @classmethod
-    def _target_gced(cls, ref, proxy_ref=None):
+    def _target_gced(
+        cls: Type[SelfReversibleProxy],
+        ref: weakref.ref[_PT],
+        proxy_ref: Optional[weakref.ref[SelfReversibleProxy]] = None,
+    ) -> None:
         cls._proxy_objects.pop(ref, None)
 
     @classmethod
-    def _regenerate_proxy_for_target(cls, target):
+    def _regenerate_proxy_for_target(
+        cls: Type[SelfReversibleProxy], target: _PT
+    ) -> SelfReversibleProxy:
         raise NotImplementedError()
 
+    @overload
     @classmethod
-    def _retrieve_proxy_for_target(cls, target, regenerate=True):
+    def _retrieve_proxy_for_target(
+        cls: Type[SelfReversibleProxy],
+        target: _PT,
+        regenerate: Literal[True] = ...,
+    ) -> SelfReversibleProxy:
+        ...
+
+    @overload
+    @classmethod
+    def _retrieve_proxy_for_target(
+        cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
+    ) -> Optional[SelfReversibleProxy]:
+        ...
+
+    @classmethod
+    def _retrieve_proxy_for_target(
+        cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
+    ) -> Optional[SelfReversibleProxy]:
         try:
             proxy_ref = cls._proxy_objects[weakref.ref(target)]
         except KeyError:
@@ -38,7 +105,7 @@ class ReversibleProxy:
         else:
             proxy = proxy_ref()
             if proxy is not None:
-                return proxy
+                return proxy  # type: ignore
 
         if regenerate:
             return cls._regenerate_proxy_for_target(target)
@@ -46,43 +113,54 @@ class ReversibleProxy:
             return None
 
 
+SelfStartableContext = TypeVar(
+    "SelfStartableContext", bound="StartableContext"
+)
+
+
 class StartableContext(abc.ABC):
     __slots__ = ()
 
     @abc.abstractmethod
-    async def start(self, is_ctxmanager=False):
-        pass
+    async def start(
+        self: SelfStartableContext, is_ctxmanager: bool = False
+    ) -> Any:
+        raise NotImplementedError()
 
-    def __await__(self):
+    def __await__(self) -> Any:
         return self.start().__await__()
 
-    async def __aenter__(self):
+    async def __aenter__(self: SelfStartableContext) -> Any:
         return await self.start(is_ctxmanager=True)
 
     @abc.abstractmethod
-    async def __aexit__(self, type_, value, traceback):
+    async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
         pass
 
-    def _raise_for_not_started(self):
+    def _raise_for_not_started(self) -> NoReturn:
         raise async_exc.AsyncContextNotStarted(
             "%s context has not been started and object has not been awaited."
             % (self.__class__.__name__)
         )
 
 
-class ProxyComparable(ReversibleProxy):
+class ProxyComparable(ReversibleProxy[_PT]):
     __slots__ = ()
 
-    def __hash__(self):
+    @util.ro_non_memoized_property
+    def _proxied(self) -> _PT:
+        raise NotImplementedError()
+
+    def __hash__(self) -> int:
         return id(self)
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return (
             isinstance(other, self.__class__)
             and self._proxied == other._proxied
         )
 
-    def __ne__(self, other):
+    def __ne__(self, other: Any) -> bool:
         return (
             not isinstance(other, self.__class__)
             or self._proxied != other._proxied
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 3b54405c15b2d07c1e76f9f6e051ceeedcea3831..bb51a4d2256a77a359fa197315b016b1101da53e 100644
@@ -7,23 +7,56 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
 
 from . import exc as async_exc
 from .base import ProxyComparable
 from .base import StartableContext
 from .result import _ensure_sync_result
 from .result import AsyncResult
+from .result import AsyncScalarResult
 from ... import exc
 from ... import inspection
 from ... import util
+from ...engine import Connection
 from ...engine import create_engine as _create_engine
+from ...engine import Engine
 from ...engine.base import NestedTransaction
-from ...future import Connection
-from ...future import Engine
+from ...engine.base import Transaction
 from ...util.concurrency import greenlet_spawn
-
-
-def create_async_engine(*arg, **kw):
+from ...util.typing import Protocol
+
+if TYPE_CHECKING:
+    from ...engine import Connection
+    from ...engine import Engine
+    from ...engine.cursor import CursorResult
+    from ...engine.interfaces import _CoreAnyExecuteParams
+    from ...engine.interfaces import _CoreSingleExecuteParams
+    from ...engine.interfaces import _DBAPIAnyExecuteParams
+    from ...engine.interfaces import _ExecuteOptions
+    from ...engine.interfaces import _ExecuteOptionsParameter
+    from ...engine.interfaces import _IsolationLevel
+    from ...engine.interfaces import Dialect
+    from ...engine.result import ScalarResult
+    from ...engine.url import URL
+    from ...pool import Pool
+    from ...pool import PoolProxiedConnection
+    from ...sql.base import Executable
+
+
+class _SyncConnectionCallable(Protocol):
+    def __call__(self, connection: Connection, *arg: Any, **kw: Any) -> Any:
+        ...
+
+
+def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
     """Create a new async engine instance.
 
     Arguments passed to :func:`_asyncio.create_async_engine` are mostly
@@ -43,11 +76,13 @@ def create_async_engine(*arg, **kw):
         )
     kw["future"] = True
     kw["_is_async"] = True
-    sync_engine = _create_engine(*arg, **kw)
+    sync_engine = _create_engine(url, **kw)
     return AsyncEngine(sync_engine)
 
 
-def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+def async_engine_from_config(
+    configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any
+) -> AsyncEngine:
     """Create a new AsyncEngine instance using a configuration dictionary.
 
     This function is analogous to the :func:`_sa.engine_from_config` function
@@ -73,6 +108,14 @@ def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
 class AsyncConnectable:
     __slots__ = "_slots_dispatch", "__weakref__"
 
+    @classmethod
+    def _no_async_engine_events(cls) -> NoReturn:
+        raise NotImplementedError(
+            "asynchronous events are not implemented at this time.  Apply "
+            "synchronous listeners to the AsyncEngine.sync_engine or "
+            "AsyncConnection.sync_connection attributes."
+        )
+
 
 @util.create_proxy_methods(
     Connection,
@@ -87,7 +130,9 @@ class AsyncConnectable:
         "default_isolation_level",
     ],
 )
-class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
+class AsyncConnection(
+    ProxyComparable[Connection], StartableContext, AsyncConnectable
+):
     """An asyncio proxy for a :class:`_engine.Connection`.
 
     :class:`_asyncio.AsyncConnection` is acquired using the
@@ -115,12 +160,16 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         "sync_connection",
     )
 
-    def __init__(self, async_engine, sync_connection=None):
+    def __init__(
+        self,
+        async_engine: AsyncEngine,
+        sync_connection: Optional[Connection] = None,
+    ):
         self.engine = async_engine
         self.sync_engine = async_engine.sync_engine
         self.sync_connection = self._assign_proxied(sync_connection)
 
-    sync_connection: Connection
+    sync_connection: Optional[Connection]
     """Reference to the sync-style :class:`_engine.Connection` this
     :class:`_asyncio.AsyncConnection` proxies requests towards.
 
@@ -146,12 +195,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
     """
 
     @classmethod
-    def _regenerate_proxy_for_target(cls, target):
+    def _regenerate_proxy_for_target(
+        cls, target: Connection
+    ) -> AsyncConnection:
         return AsyncConnection(
             AsyncEngine._retrieve_proxy_for_target(target.engine), target
         )
 
-    async def start(self, is_ctxmanager=False):
+    async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
         """Start this :class:`_asyncio.AsyncConnection` object's context
         outside of using a Python ``with:`` block.
 
@@ -164,7 +215,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         return self
 
     @property
-    def connection(self):
+    def connection(self) -> NoReturn:
         """Not implemented for async; call
         :meth:`_asyncio.AsyncConnection.get_raw_connection`.
         """
@@ -174,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
             "Use the get_raw_connection() method."
         )
 
-    async def get_raw_connection(self):
+    async def get_raw_connection(self) -> PoolProxiedConnection:
         """Return the pooled DBAPI-level connection in use by this
         :class:`_asyncio.AsyncConnection`.
 
@@ -187,16 +238,11 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         adapts the driver connection to the DBAPI protocol.
 
         """
-        conn = self._sync_connection()
-
-        return await greenlet_spawn(getattr, conn, "connection")
 
-    @property
-    def _proxied(self):
-        return self.sync_connection
+        return await greenlet_spawn(getattr, self._proxied, "connection")
 
     @property
-    def info(self):
+    def info(self) -> Dict[str, Any]:
         """Return the :attr:`_engine.Connection.info` dictionary of the
         underlying :class:`_engine.Connection`.
 
@@ -211,24 +257,28 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         .. versionadded:: 1.4.0b2
 
         """
-        return self.sync_connection.info
+        return self._proxied.info
 
-    def _sync_connection(self):
+    @util.ro_non_memoized_property
+    def _proxied(self) -> Connection:
         if not self.sync_connection:
             self._raise_for_not_started()
         return self.sync_connection
 
-    def begin(self):
+    def begin(self) -> AsyncTransaction:
         """Begin a transaction prior to autobegin occurring."""
-        self._sync_connection()
+        assert self._proxied
         return AsyncTransaction(self)
 
-    def begin_nested(self):
+    def begin_nested(self) -> AsyncTransaction:
         """Begin a nested transaction and return a transaction handle."""
-        self._sync_connection()
+        assert self._proxied
         return AsyncTransaction(self, nested=True)
 
-    async def invalidate(self, exception=None):
+    async def invalidate(
+        self, exception: Optional[BaseException] = None
+    ) -> None:
+
         """Invalidate the underlying DBAPI connection associated with
         this :class:`_engine.Connection`.
 
@@ -237,39 +287,27 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
 
         """
 
-        conn = self._sync_connection()
-        return await greenlet_spawn(conn.invalidate, exception=exception)
-
-    async def get_isolation_level(self):
-        conn = self._sync_connection()
-        return await greenlet_spawn(conn.get_isolation_level)
-
-    async def set_isolation_level(self):
-        conn = self._sync_connection()
-        return await greenlet_spawn(conn.get_isolation_level)
-
-    def in_transaction(self):
-        """Return True if a transaction is in progress.
-
-        .. versionadded:: 1.4.0b2
+        return await greenlet_spawn(
+            self._proxied.invalidate, exception=exception
+        )
 
-        """
+    async def get_isolation_level(self) -> _IsolationLevel:
+        return await greenlet_spawn(self._proxied.get_isolation_level)
 
-        conn = self._sync_connection()
+    def in_transaction(self) -> bool:
+        """Return True if a transaction is in progress."""
 
-        return conn.in_transaction()
+        return self._proxied.in_transaction()
 
-    def in_nested_transaction(self):
+    def in_nested_transaction(self) -> bool:
         """Return True if a transaction is in progress.
 
         .. versionadded:: 1.4.0b2
 
         """
-        conn = self._sync_connection()
-
-        return conn.in_nested_transaction()
+        return self._proxied.in_nested_transaction()
 
-    def get_transaction(self):
+    def get_transaction(self) -> Optional[AsyncTransaction]:
         """Return an :class:`.AsyncTransaction` representing the current
         transaction, if any.
 
@@ -281,15 +319,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         .. versionadded:: 1.4.0b2
 
         """
-        conn = self._sync_connection()
 
-        trans = conn.get_transaction()
+        trans = self._proxied.get_transaction()
         if trans is not None:
             return AsyncTransaction._retrieve_proxy_for_target(trans)
         else:
             return None
 
-    def get_nested_transaction(self):
+    def get_nested_transaction(self) -> Optional[AsyncTransaction]:
         """Return an :class:`.AsyncTransaction` representing the current
         nested (savepoint) transaction, if any.
 
@@ -301,15 +338,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         .. versionadded:: 1.4.0b2
 
         """
-        conn = self._sync_connection()
 
-        trans = conn.get_nested_transaction()
+        trans = self._proxied.get_nested_transaction()
         if trans is not None:
             return AsyncTransaction._retrieve_proxy_for_target(trans)
         else:
             return None
 
-    async def execution_options(self, **opt):
+    async def execution_options(self, **opt: Any) -> AsyncConnection:
         r"""Set non-SQL options for the connection which take effect
         during execution.
 
@@ -321,12 +357,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
 
         """
 
-        conn = self._sync_connection()
+        conn = self._proxied
         c2 = await greenlet_spawn(conn.execution_options, **opt)
         assert c2 is conn
         return self
 
-    async def commit(self):
+    async def commit(self) -> None:
         """Commit the transaction that is currently in progress.
 
         This method commits the current transaction if one has been started.
@@ -338,10 +374,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         :meth:`_future.Connection.begin` method is called.
 
         """
-        conn = self._sync_connection()
-        await greenlet_spawn(conn.commit)
+        await greenlet_spawn(self._proxied.commit)
 
-    async def rollback(self):
+    async def rollback(self) -> None:
         """Roll back the transaction that is currently in progress.
 
         This method rolls back the current transaction if one has been started.
@@ -355,34 +390,30 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
 
 
         """
-        conn = self._sync_connection()
-        await greenlet_spawn(conn.rollback)
+        await greenlet_spawn(self._proxied.rollback)
 
-    async def close(self):
+    async def close(self) -> None:
         """Close this :class:`_asyncio.AsyncConnection`.
 
         This has the effect of also rolling back the transaction if one
         is in place.
 
         """
-        conn = self._sync_connection()
-        await greenlet_spawn(conn.close)
+        await greenlet_spawn(self._proxied.close)
 
     async def exec_driver_sql(
         self,
-        statement,
-        parameters=None,
-        execution_options=util.EMPTY_DICT,
-    ):
+        statement: str,
+        parameters: Optional[_DBAPIAnyExecuteParams] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> CursorResult:
         r"""Executes a driver-level SQL string and return buffered
         :class:`_engine.Result`.
 
         """
 
-        conn = self._sync_connection()
-
         result = await greenlet_spawn(
-            conn.exec_driver_sql,
+            self._proxied.exec_driver_sql,
             statement,
             parameters,
             execution_options,
@@ -393,17 +424,15 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
 
     async def stream(
         self,
-        statement,
-        parameters=None,
-        execution_options=util.EMPTY_DICT,
-    ):
+        statement: Executable,
+        parameters: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> AsyncResult:
         """Execute a statement and return a streaming
         :class:`_asyncio.AsyncResult` object."""
 
-        conn = self._sync_connection()
-
         result = await greenlet_spawn(
-            conn.execute,
+            self._proxied.execute,
             statement,
             parameters,
             util.EMPTY_DICT.merge_with(
@@ -418,10 +447,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
 
     async def execute(
         self,
-        statement,
-        parameters=None,
-        execution_options=util.EMPTY_DICT,
-    ):
+        statement: Executable,
+        parameters: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> CursorResult:
         r"""Executes a SQL statement construct and return a buffered
         :class:`_engine.Result`.
 
@@ -453,10 +482,8 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         :return: a :class:`_engine.Result` object.
 
         """
-        conn = self._sync_connection()
-
         result = await greenlet_spawn(
-            conn.execute,
+            self._proxied.execute,
             statement,
             parameters,
             execution_options,
@@ -466,10 +493,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
 
     async def scalar(
         self,
-        statement,
-        parameters=None,
-        execution_options=util.EMPTY_DICT,
-    ):
+        statement: Executable,
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> Any:
         r"""Executes a SQL statement construct and returns a scalar object.
 
         This method is shorthand for invoking the
@@ -485,10 +512,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
 
     async def scalars(
         self,
-        statement,
-        parameters=None,
-        execution_options=util.EMPTY_DICT,
-    ):
+        statement: Executable,
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> ScalarResult[Any]:
         r"""Executes a SQL statement construct and returns a scalar objects.
 
         This method is shorthand for invoking the
@@ -505,10 +532,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
 
     async def stream_scalars(
         self,
-        statement,
-        parameters=None,
-        execution_options=util.EMPTY_DICT,
-    ):
+        statement: Executable,
+        parameters: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: Optional[_ExecuteOptionsParameter] = None,
+    ) -> AsyncScalarResult[Any]:
         r"""Executes a SQL statement and returns a streaming scalar result
         object.
 
@@ -524,7 +551,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
         result = await self.stream(statement, parameters, execution_options)
         return result.scalars()
 
-    async def run_sync(self, fn, *arg, **kw):
+    async def run_sync(
+        self, fn: _SyncConnectionCallable, *arg: Any, **kw: Any
+    ) -> Any:
         """Invoke the given sync callable passing self as the first argument.
 
         This method maintains the asyncio event loop all the way through
@@ -548,14 +577,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
             :ref:`session_run_sync`
         """
 
-        conn = self._sync_connection()
-
-        return await greenlet_spawn(fn, conn, *arg, **kw)
+        return await greenlet_spawn(fn, self._proxied, *arg, **kw)
 
-    def __await__(self):
+    def __await__(self) -> Generator[Any, None, AsyncConnection]:
         return self.start().__await__()
 
-    async def __aexit__(self, type_, value, traceback):
+    async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
         await self.close()
 
     # START PROXY METHODS AsyncConnection
@@ -661,7 +688,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
     ],
     attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
 )
-class AsyncEngine(ProxyComparable, AsyncConnectable):
+class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
     """An asyncio proxy for a :class:`_engine.Engine`.
 
     :class:`_asyncio.AsyncEngine` is acquired using the
@@ -679,51 +706,60 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
     # current transaction, info, etc.   It should be possible to
     # create a new AsyncEngine that matches this one given only the
     # "sync" elements.
-    __slots__ = ("sync_engine", "_proxied")
+    __slots__ = "sync_engine"
 
-    _connection_cls = AsyncConnection
+    _connection_cls: Type[AsyncConnection] = AsyncConnection
 
-    _option_cls: type
+    sync_engine: Engine
+    """Reference to the sync-style :class:`_engine.Engine` this
+    :class:`_asyncio.AsyncEngine` proxies requests towards.
+
+    This instance can be used as an event target.
+
+    .. seealso::
+
+        :ref:`asyncio_events`
+    """
 
     class _trans_ctx(StartableContext):
-        def __init__(self, conn):
+        __slots__ = ("conn", "transaction")
+
+        conn: AsyncConnection
+        transaction: AsyncTransaction
+
+        def __init__(self, conn: AsyncConnection):
             self.conn = conn
 
-        async def start(self, is_ctxmanager=False):
+        async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
             await self.conn.start(is_ctxmanager=is_ctxmanager)
             self.transaction = self.conn.begin()
             await self.transaction.__aenter__()
 
             return self.conn
 
-        async def __aexit__(self, type_, value, traceback):
+        async def __aexit__(
+            self, type_: Any, value: Any, traceback: Any
+        ) -> None:
             await self.transaction.__aexit__(type_, value, traceback)
             await self.conn.close()
 
-    def __init__(self, sync_engine):
+    def __init__(self, sync_engine: Engine):
         if not sync_engine.dialect.is_async:
             raise exc.InvalidRequestError(
                 "The asyncio extension requires an async driver to be used. "
                 f"The loaded {sync_engine.dialect.driver!r} is not async."
             )
-        self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
-
-    sync_engine: Engine
-    """Reference to the sync-style :class:`_engine.Engine` this
-    :class:`_asyncio.AsyncEngine` proxies requests towards.
+        self.sync_engine = self._assign_proxied(sync_engine)
 
-    This instance can be used as an event target.
-
-    .. seealso::
-
-        :ref:`asyncio_events`
-    """
+    @util.ro_non_memoized_property
+    def _proxied(self) -> Engine:
+        return self.sync_engine
 
     @classmethod
-    def _regenerate_proxy_for_target(cls, target):
+    def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
         return AsyncEngine(target)
 
-    def begin(self):
+    def begin(self) -> AsyncEngine._trans_ctx:
         """Return a context manager which when entered will deliver an
         :class:`_asyncio.AsyncConnection` with an
         :class:`_asyncio.AsyncTransaction` established.
@@ -741,7 +777,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
         conn = self.connect()
         return self._trans_ctx(conn)
 
-    def connect(self):
+    def connect(self) -> AsyncConnection:
         """Return an :class:`_asyncio.AsyncConnection` object.
 
         The :class:`_asyncio.AsyncConnection` will procure a database
@@ -759,7 +795,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
 
         return self._connection_cls(self)
 
-    async def raw_connection(self):
+    async def raw_connection(self) -> PoolProxiedConnection:
         """Return a "raw" DBAPI connection from the connection pool.
 
         .. seealso::
@@ -769,7 +805,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
         """
         return await greenlet_spawn(self.sync_engine.raw_connection)
 
-    def execution_options(self, **opt):
+    def execution_options(self, **opt: Any) -> AsyncEngine:
         """Return a new :class:`_asyncio.AsyncEngine` that will provide
         :class:`_asyncio.AsyncConnection` objects with the given execution
         options.
@@ -781,21 +817,31 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
 
         return AsyncEngine(self.sync_engine.execution_options(**opt))
 
-    async def dispose(self):
+    async def dispose(self, close: bool = True) -> None:
+
         """Dispose of the connection pool used by this
         :class:`_asyncio.AsyncEngine`.
 
-        This will close all connection pool connections that are
-        **currently checked in**.  See the documentation for the underlying
-        :meth:`_future.Engine.dispose` method for further notes.
+        :param close: if left at its default of ``True``, has the
+         effect of fully closing all **currently checked in**
+         database connections.  Connections that are still checked out
+         will **not** be closed, however they will no longer be associated
+         with this :class:`_engine.Engine`,
+         so when they are closed individually, eventually the
+         :class:`_pool.Pool` which they are associated with will
+         be garbage collected and they will be closed out fully, if
+         not already closed on checkin.
+
+         If set to ``False``, the previous connection pool is de-referenced,
+         and otherwise not touched in any way.
 
         .. seealso::
 
-            :meth:`_future.Engine.dispose`
+            :meth:`_engine.Engine.dispose`
 
         """
 
-        return await greenlet_spawn(self.sync_engine.dispose)
+        return await greenlet_spawn(self.sync_engine.dispose, close=close)
 
     # START PROXY METHODS AsyncEngine
 
@@ -973,18 +1019,24 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
     # END PROXY METHODS AsyncEngine
 
 
-class AsyncTransaction(ProxyComparable, StartableContext):
+class AsyncTransaction(ProxyComparable[Transaction], StartableContext):
     """An asyncio proxy for a :class:`_engine.Transaction`."""
 
     __slots__ = ("connection", "sync_transaction", "nested")
 
-    def __init__(self, connection, nested=False):
-        self.connection = connection  # AsyncConnection
-        self.sync_transaction = None  # sqlalchemy.engine.Transaction
+    sync_transaction: Optional[Transaction]
+    connection: AsyncConnection
+    nested: bool
+
+    def __init__(self, connection: AsyncConnection, nested: bool = False):
+        self.connection = connection
+        self.sync_transaction = None
         self.nested = nested
 
     @classmethod
-    def _regenerate_proxy_for_target(cls, target):
+    def _regenerate_proxy_for_target(
+        cls, target: Transaction
+    ) -> AsyncTransaction:
         sync_connection = target.connection
         sync_transaction = target
         nested = isinstance(target, NestedTransaction)
@@ -1000,25 +1052,22 @@ class AsyncTransaction(ProxyComparable, StartableContext):
         obj.nested = nested
         return obj
 
-    def _sync_transaction(self):
+    @util.ro_non_memoized_property
+    def _proxied(self) -> Transaction:
         if not self.sync_transaction:
             self._raise_for_not_started()
         return self.sync_transaction
 
     @property
-    def _proxied(self):
-        return self.sync_transaction
+    def is_valid(self) -> bool:
+        return self._proxied.is_valid
 
     @property
-    def is_valid(self):
-        return self._sync_transaction().is_valid
+    def is_active(self) -> bool:
+        return self._proxied.is_active
 
-    @property
-    def is_active(self):
-        return self._sync_transaction().is_active
-
-    async def close(self):
-        """Close this :class:`.Transaction`.
+    async def close(self) -> None:
+        """Close this :class:`.AsyncTransaction`.
 
         If this transaction is the base transaction in a begin/commit
         nesting, the transaction will rollback().  Otherwise, the
@@ -1028,18 +1077,18 @@ class AsyncTransaction(ProxyComparable, StartableContext):
         an enclosing transaction.
 
         """
-        await greenlet_spawn(self._sync_transaction().close)
+        await greenlet_spawn(self._proxied.close)
 
-    async def rollback(self):
-        """Roll back this :class:`.Transaction`."""
-        await greenlet_spawn(self._sync_transaction().rollback)
+    async def rollback(self) -> None:
+        """Roll back this :class:`.AsyncTransaction`."""
+        await greenlet_spawn(self._proxied.rollback)
 
-    async def commit(self):
-        """Commit this :class:`.Transaction`."""
+    async def commit(self) -> None:
+        """Commit this :class:`.AsyncTransaction`."""
 
-        await greenlet_spawn(self._sync_transaction().commit)
+        await greenlet_spawn(self._proxied.commit)
 
-    async def start(self, is_ctxmanager=False):
+    async def start(self, is_ctxmanager: bool = False) -> AsyncTransaction:
         """Start this :class:`_asyncio.AsyncTransaction` object's context
         outside of using a Python ``with:`` block.
 
@@ -1047,24 +1096,36 @@ class AsyncTransaction(ProxyComparable, StartableContext):
 
         self.sync_transaction = self._assign_proxied(
             await greenlet_spawn(
-                self.connection._sync_connection().begin_nested
+                self.connection._proxied.begin_nested
                 if self.nested
-                else self.connection._sync_connection().begin
+                else self.connection._proxied.begin
             )
         )
         if is_ctxmanager:
             self.sync_transaction.__enter__()
         return self
 
-    async def __aexit__(self, type_, value, traceback):
-        await greenlet_spawn(
-            self._sync_transaction().__exit__, type_, value, traceback
-        )
+    async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
+        await greenlet_spawn(self._proxied.__exit__, type_, value, traceback)
+
+
+@overload
+def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine:
+    ...
+
+
+@overload
+def _get_sync_engine_or_connection(
+    async_engine: AsyncConnection,
+) -> Connection:
+    ...
 
 
-def _get_sync_engine_or_connection(async_engine):
+def _get_sync_engine_or_connection(
+    async_engine: Union[AsyncEngine, AsyncConnection]
+) -> Union[Engine, Connection]:
     if isinstance(async_engine, AsyncConnection):
-        return async_engine.sync_connection
+        return async_engine._proxied
 
     try:
         return async_engine.sync_engine
@@ -1075,7 +1136,7 @@ def _get_sync_engine_or_connection(async_engine):
 
 
 @inspection._inspects(AsyncConnection)
-def _no_insp_for_async_conn_yet(subject):
+def _no_insp_for_async_conn_yet(subject: AsyncConnection) -> NoReturn:
     raise exc.NoInspectionAvailable(
         "Inspection on an AsyncConnection is currently not supported. "
         "Please use ``run_sync`` to pass a callable where it's possible "
@@ -1085,7 +1146,7 @@ def _no_insp_for_async_conn_yet(subject):
 
 
 @inspection._inspects(AsyncEngine)
-def _no_insp_for_async_engine_xyet(subject):
+def _no_insp_for_async_engine_xyet(subject: AsyncEngine) -> NoReturn:
     raise exc.NoInspectionAvailable(
         "Inspection on an AsyncEngine is currently not supported. "
         "Please obtain a connection then use ``conn.run_sync`` to pass a "
diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py
deleted file mode 100644
index c5d5e01..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-# ext/asyncio/events.py
-# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: https://www.opensource.org/licenses/mit-license.php
-
-from .engine import AsyncConnectable
-from .session import AsyncSession
-from ...engine import events as engine_event
-from ...orm import events as orm_event
-
-
-class AsyncConnectionEvents(engine_event.ConnectionEvents):
-    _target_class_doc = "SomeEngine"
-    _dispatch_target = AsyncConnectable
-
-    @classmethod
-    def _no_async_engine_events(cls):
-        raise NotImplementedError(
-            "asynchronous events are not implemented at this time.  Apply "
-            "synchronous listeners to the AsyncEngine.sync_engine or "
-            "AsyncConnection.sync_connection attributes."
-        )
-
-    @classmethod
-    def _listen(cls, event_key, retval=False):
-        cls._no_async_engine_events()
-
-
-class AsyncSessionEvents(orm_event.SessionEvents):
-    _target_class_doc = "SomeSession"
-    _dispatch_target = AsyncSession
-
-    @classmethod
-    def _no_async_engine_events(cls):
-        raise NotImplementedError(
-            "asynchronous events are not implemented at this time.  Apply "
-            "synchronous listeners to the AsyncSession.sync_session."
-        )
-
-    @classmethod
-    def _listen(cls, event_key, retval=False):
-        cls._no_async_engine_events()
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index 39718735cc8916ac9f0558467cd1bf1c789035d7..a9db822a6a91fd0e89bd47b055914c687e24671b 100644
@@ -4,25 +4,49 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
 
 import operator
+from typing import Any
+from typing import AsyncIterator
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
 
 from . import exc as async_exc
 from ...engine.result import _NO_ROW
+from ...engine.result import _R
 from ...engine.result import FilterResult
 from ...engine.result import FrozenResult
 from ...engine.result import MergedResult
+from ...engine.result import ResultMetaData
+from ...engine.row import Row
+from ...engine.row import RowMapping
 from ...util.concurrency import greenlet_spawn
 
+if TYPE_CHECKING:
+    from ...engine import CursorResult
+    from ...engine import Result
+    from ...engine.result import _KeyIndexType
+    from ...engine.result import _UniqueFilterType
+    from ...engine.result import RMKeyView
 
-class AsyncCommon(FilterResult):
-    async def close(self):
+
+class AsyncCommon(FilterResult[_R]):
+    _real_result: Result
+    _metadata: ResultMetaData
+
+    async def close(self) -> None:
         """Close this result."""
 
         await greenlet_spawn(self._real_result.close)
 
 
-class AsyncResult(AsyncCommon):
+SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult")
+
+
+class AsyncResult(AsyncCommon[Row]):
     """An asyncio wrapper around a :class:`_result.Result` object.
 
     The :class:`_asyncio.AsyncResult` only applies to statement executions that
@@ -43,7 +67,7 @@ class AsyncResult(AsyncCommon):
 
     """
 
-    def __init__(self, real_result):
+    def __init__(self, real_result: Result):
         self._real_result = real_result
 
         self._metadata = real_result._metadata
@@ -56,14 +80,16 @@ class AsyncResult(AsyncCommon):
                 "_row_getter", real_result.__dict__["_row_getter"]
             )
 
-    def keys(self):
+    def keys(self) -> RMKeyView:
         """Return the :meth:`_engine.Result.keys` collection from the
         underlying :class:`_engine.Result`.
 
         """
         return self._metadata.keys
 
-    def unique(self, strategy=None):
+    def unique(
+        self: SelfAsyncResult, strategy: Optional[_UniqueFilterType] = None
+    ) -> SelfAsyncResult:
         """Apply unique filtering to the objects returned by this
         :class:`_asyncio.AsyncResult`.
 
@@ -75,7 +101,9 @@ class AsyncResult(AsyncCommon):
         self._unique_filter_state = (set(), strategy)
         return self
 
-    def columns(self, *col_expressions):
+    def columns(
+        self: SelfAsyncResult, *col_expressions: _KeyIndexType
+    ) -> SelfAsyncResult:
         r"""Establish the columns that should be returned in each row.
 
         Refer to :meth:`_engine.Result.columns` in the synchronous
@@ -85,7 +113,9 @@ class AsyncResult(AsyncCommon):
         """
         return self._column_slices(col_expressions)
 
-    async def partitions(self, size=None):
+    async def partitions(
+        self, size: Optional[int] = None
+    ) -> AsyncIterator[List[Row]]:
         """Iterate through sub-lists of rows of the size given.
 
         An async iterator is returned::
@@ -111,7 +141,7 @@ class AsyncResult(AsyncCommon):
             else:
                 break
 
-    async def fetchone(self):
+    async def fetchone(self) -> Optional[Row]:
         """Fetch one row.
 
         When all rows are exhausted, returns None.
@@ -131,9 +161,9 @@ class AsyncResult(AsyncCommon):
         if row is _NO_ROW:
             return None
         else:
-            return row  # type: ignore[return-value]
+            return row
 
-    async def fetchmany(self, size=None):
+    async def fetchmany(self, size: Optional[int] = None) -> List[Row]:
         """Fetch many rows.
 
         When all rows are exhausted, returns an empty list.
@@ -152,11 +182,9 @@ class AsyncResult(AsyncCommon):
 
         """
 
-        return await greenlet_spawn(
-            self._manyrow_getter, self, size  # type: ignore
-        )
+        return await greenlet_spawn(self._manyrow_getter, self, size)
 
-    async def all(self):
+    async def all(self) -> List[Row]:
         """Return all rows in a list.
 
         Closes the result set after invocation.   Subsequent invocations
@@ -166,19 +194,19 @@ class AsyncResult(AsyncCommon):
 
         """
 
-        return await greenlet_spawn(self._allrows)  # type: ignore
+        return await greenlet_spawn(self._allrows)
 
-    def __aiter__(self):
+    def __aiter__(self) -> AsyncResult:
         return self
 
-    async def __anext__(self):
+    async def __anext__(self) -> Row:
         row = await greenlet_spawn(self._onerow_getter, self)
         if row is _NO_ROW:
             raise StopAsyncIteration()
         else:
             return row
 
-    async def first(self):
+    async def first(self) -> Optional[Row]:
         """Fetch the first row or None if no row is present.
 
         Closes the result set and discards remaining rows.
@@ -201,7 +229,7 @@ class AsyncResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, False, False, False)
 
-    async def one_or_none(self):
+    async def one_or_none(self) -> Optional[Row]:
         """Return at most one result or raise an exception.
 
         Returns ``None`` if the result has no rows.
@@ -223,7 +251,7 @@ class AsyncResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, True, False, False)
 
-    async def scalar_one(self):
+    async def scalar_one(self) -> Any:
         """Return exactly one scalar result or raise an exception.
 
         This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
@@ -238,7 +266,7 @@ class AsyncResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, True, True, True)
 
-    async def scalar_one_or_none(self):
+    async def scalar_one_or_none(self) -> Optional[Any]:
         """Return exactly one or no scalar result.
 
         This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
@@ -253,7 +281,7 @@ class AsyncResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, True, False, True)
 
-    async def one(self):
+    async def one(self) -> Row:
         """Return exactly one row or raise an exception.
 
         Raises :class:`.NoResultFound` if the result returns no
@@ -284,7 +312,7 @@ class AsyncResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, True, True, False)
 
-    async def scalar(self):
+    async def scalar(self) -> Any:
         """Fetch the first column of the first row, and close the result set.
 
         Returns None if there are no rows to fetch.
@@ -300,7 +328,7 @@ class AsyncResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, False, False, True)
 
-    async def freeze(self):
+    async def freeze(self) -> FrozenResult:
         """Return a callable object that will produce copies of this
         :class:`_asyncio.AsyncResult` when invoked.
 
@@ -323,7 +351,7 @@ class AsyncResult(AsyncCommon):
 
         return await greenlet_spawn(FrozenResult, self)
 
-    def merge(self, *others):
+    def merge(self, *others: AsyncResult) -> MergedResult:
         """Merge this :class:`_asyncio.AsyncResult` with other compatible result
         objects.
 
@@ -337,9 +365,12 @@ class AsyncResult(AsyncCommon):
         undefined.
 
         """
-        return MergedResult(self._metadata, (self,) + others)
+        return MergedResult(
+            self._metadata,
+            (self._real_result,) + tuple(o._real_result for o in others),
+        )
 
-    def scalars(self, index=0):
+    def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
         """Return an :class:`_asyncio.AsyncScalarResult` filtering object which
         will return single elements rather than :class:`_row.Row` objects.
 
@@ -355,7 +386,7 @@ class AsyncResult(AsyncCommon):
         """
         return AsyncScalarResult(self._real_result, index)
 
-    def mappings(self):
+    def mappings(self) -> AsyncMappingResult:
         """Apply a mappings filter to returned rows, returning an instance of
         :class:`_asyncio.AsyncMappingResult`.
 
@@ -373,7 +404,12 @@ class AsyncResult(AsyncCommon):
         return AsyncMappingResult(self._real_result)
 
 
-class AsyncScalarResult(AsyncCommon):
+SelfAsyncScalarResult = TypeVar(
+    "SelfAsyncScalarResult", bound="AsyncScalarResult[Any]"
+)
+
+
+class AsyncScalarResult(AsyncCommon[_R]):
     """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
     rather than :class:`_row.Row` values.
 
@@ -389,7 +425,7 @@ class AsyncScalarResult(AsyncCommon):
 
     _generate_rows = False
 
-    def __init__(self, real_result, index):
+    def __init__(self, real_result: Result, index: _KeyIndexType):
         self._real_result = real_result
 
         if real_result._source_supports_scalars:
@@ -401,7 +437,10 @@ class AsyncScalarResult(AsyncCommon):
 
         self._unique_filter_state = real_result._unique_filter_state
 
-    def unique(self, strategy=None):
+    def unique(
+        self: SelfAsyncScalarResult,
+        strategy: Optional[_UniqueFilterType] = None,
+    ) -> SelfAsyncScalarResult:
         """Apply unique filtering to the objects returned by this
         :class:`_asyncio.AsyncScalarResult`.
 
@@ -411,7 +450,9 @@ class AsyncScalarResult(AsyncCommon):
         self._unique_filter_state = (set(), strategy)
         return self
 
-    async def partitions(self, size=None):
+    async def partitions(
+        self, size: Optional[int] = None
+    ) -> AsyncIterator[List[_R]]:
         """Iterate through sub-lists of elements of the size given.
 
         Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -429,12 +470,12 @@ class AsyncScalarResult(AsyncCommon):
             else:
                 break
 
-    async def fetchall(self):
+    async def fetchall(self) -> List[_R]:
         """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
 
         return await greenlet_spawn(self._allrows)
 
-    async def fetchmany(self, size=None):
+    async def fetchmany(self, size: Optional[int] = None) -> List[_R]:
         """Fetch many objects.
 
         Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -444,7 +485,7 @@ class AsyncScalarResult(AsyncCommon):
         """
         return await greenlet_spawn(self._manyrow_getter, self, size)
 
-    async def all(self):
+    async def all(self) -> List[_R]:
         """Return all scalar values in a list.
 
         Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -454,17 +495,17 @@ class AsyncScalarResult(AsyncCommon):
         """
         return await greenlet_spawn(self._allrows)
 
-    def __aiter__(self):
+    def __aiter__(self) -> AsyncScalarResult[_R]:
         return self
 
-    async def __anext__(self):
+    async def __anext__(self) -> _R:
         row = await greenlet_spawn(self._onerow_getter, self)
         if row is _NO_ROW:
             raise StopAsyncIteration()
         else:
             return row
 
-    async def first(self):
+    async def first(self) -> Optional[_R]:
         """Fetch the first object or None if no object is present.
 
         Equivalent to :meth:`_asyncio.AsyncResult.first` except that
@@ -474,7 +515,7 @@ class AsyncScalarResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, False, False, False)
 
-    async def one_or_none(self):
+    async def one_or_none(self) -> Optional[_R]:
         """Return at most one object or raise an exception.
 
         Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
@@ -484,7 +525,7 @@ class AsyncScalarResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, True, False, False)
 
-    async def one(self):
+    async def one(self) -> _R:
         """Return exactly one object or raise an exception.
 
         Equivalent to :meth:`_asyncio.AsyncResult.one` except that
@@ -495,7 +536,12 @@ class AsyncScalarResult(AsyncCommon):
         return await greenlet_spawn(self._only_one_row, True, True, False)
 
 
-class AsyncMappingResult(AsyncCommon):
+SelfAsyncMappingResult = TypeVar(
+    "SelfAsyncMappingResult", bound="AsyncMappingResult"
+)
+
+
+class AsyncMappingResult(AsyncCommon[RowMapping]):
     """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values
     rather than :class:`_engine.Row` values.
 
@@ -513,14 +559,14 @@ class AsyncMappingResult(AsyncCommon):
 
     _post_creational_filter = operator.attrgetter("_mapping")
 
-    def __init__(self, result):
+    def __init__(self, result: Result):
         self._real_result = result
         self._unique_filter_state = result._unique_filter_state
         self._metadata = result._metadata
         if result._source_supports_scalars:
             self._metadata = self._metadata._reduce([0])
 
-    def keys(self):
+    def keys(self) -> RMKeyView:
         """Return an iterable view which yields the string keys that would
         be represented by each :class:`.Row`.
 
@@ -535,7 +581,10 @@ class AsyncMappingResult(AsyncCommon):
         """
         return self._metadata.keys
 
-    def unique(self, strategy=None):
+    def unique(
+        self: SelfAsyncMappingResult,
+        strategy: Optional[_UniqueFilterType] = None,
+    ) -> SelfAsyncMappingResult:
         """Apply unique filtering to the objects returned by this
         :class:`_asyncio.AsyncMappingResult`.
 
@@ -545,11 +594,16 @@ class AsyncMappingResult(AsyncCommon):
         self._unique_filter_state = (set(), strategy)
         return self
 
-    def columns(self, *col_expressions):
+    def columns(
+        self: SelfAsyncMappingResult, *col_expressions: _KeyIndexType
+    ) -> SelfAsyncMappingResult:
         r"""Establish the columns that should be returned in each row."""
         return self._column_slices(col_expressions)
 
-    async def partitions(self, size=None):
+    async def partitions(
+        self, size: Optional[int] = None
+    ) -> AsyncIterator[List[RowMapping]]:
+
         """Iterate through sub-lists of elements of the size given.
 
         Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -567,12 +621,12 @@ class AsyncMappingResult(AsyncCommon):
             else:
                 break
 
-    async def fetchall(self):
+    async def fetchall(self) -> List[RowMapping]:
         """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
 
         return await greenlet_spawn(self._allrows)
 
-    async def fetchone(self):
+    async def fetchone(self) -> Optional[RowMapping]:
         """Fetch one object.
 
         Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
@@ -587,8 +641,8 @@ class AsyncMappingResult(AsyncCommon):
         else:
             return row
 
-    async def fetchmany(self, size=None):
-        """Fetch many objects.
+    async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+        """Fetch many rows.
 
         Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
         :class:`_result.RowMapping` values, rather than :class:`_result.Row`
@@ -598,8 +652,8 @@ class AsyncMappingResult(AsyncCommon):
 
         return await greenlet_spawn(self._manyrow_getter, self, size)
 
-    async def all(self):
-        """Return all scalar values in a list.
+    async def all(self) -> List[RowMapping]:
+        """Return all rows in a list.
 
         Equivalent to :meth:`_asyncio.AsyncResult.all` except that
         :class:`_result.RowMapping` values, rather than :class:`_result.Row`
@@ -609,17 +663,17 @@ class AsyncMappingResult(AsyncCommon):
 
         return await greenlet_spawn(self._allrows)
 
-    def __aiter__(self):
+    def __aiter__(self) -> AsyncMappingResult:
         return self
 
-    async def __anext__(self):
+    async def __anext__(self) -> RowMapping:
         row = await greenlet_spawn(self._onerow_getter, self)
         if row is _NO_ROW:
             raise StopAsyncIteration()
         else:
             return row
 
-    async def first(self):
+    async def first(self) -> Optional[RowMapping]:
         """Fetch the first object or None if no object is present.
 
         Equivalent to :meth:`_asyncio.AsyncResult.first` except that
@@ -630,7 +684,7 @@ class AsyncMappingResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, False, False, False)
 
-    async def one_or_none(self):
+    async def one_or_none(self) -> Optional[RowMapping]:
         """Return at most one object or raise an exception.
 
         Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
@@ -640,7 +694,7 @@ class AsyncMappingResult(AsyncCommon):
         """
         return await greenlet_spawn(self._only_one_row, True, False, False)
 
-    async def one(self):
+    async def one(self) -> RowMapping:
         """Return exactly one object or raise an exception.
 
         Equivalent to :meth:`_asyncio.AsyncResult.one` except that
@@ -651,11 +705,15 @@ class AsyncMappingResult(AsyncCommon):
         return await greenlet_spawn(self._only_one_row, True, True, False)
 
 
-async def _ensure_sync_result(result, calling_method):
+_RT = TypeVar("_RT", bound="Result")
+
+
+async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
+    cursor_result: CursorResult
     if not result._is_cursor:
-        cursor_result = getattr(result, "raw", None)
+        cursor_result = getattr(result, "raw", None)  # type: ignore
     else:
-        cursor_result = result
+        cursor_result = result  # type: ignore
     if cursor_result and cursor_result.context._is_server_side:
         await greenlet_spawn(cursor_result.close)
         raise async_exc.AsyncMethodRequired(
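
The result.py hunks above hinge on the bound ``Self*`` TypeVars (``SelfAsyncResult``,
``SelfAsyncScalarResult``, ``SelfAsyncMappingResult``): annotating ``self`` with a TypeVar
bound to the class lets chained filter methods such as ``unique()`` and ``columns()``
keep returning the caller's own subtype. A minimal sketch of that pattern, using
hypothetical class names rather than the SQLAlchemy classes themselves::

    from typing import Any, Optional, Set, Tuple, TypeVar

    SelfFilteredResult = TypeVar("SelfFilteredResult", bound="FilteredResult")


    class FilteredResult:
        _unique_filter_state: Optional[Tuple[Set[Any], Any]] = None

        def unique(
            self: SelfFilteredResult, strategy: Optional[Any] = None
        ) -> SelfFilteredResult:
            # annotating ``self`` with the bound TypeVar makes chained calls
            # keep the subclass type: ScalarFiltered().unique() is typed as
            # ScalarFiltered, not as the base class
            self._unique_filter_state = (set(), strategy)
            return self


    class ScalarFiltered(FilteredResult):
        pass


    r: ScalarFiltered = ScalarFiltered().unique()  # type-checks under mypy
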
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
index 0503076aaf7d903fd00c0c461cb409ddc64abb9a..0d6ae92b4126ccaf041097c3b86d1a8c092be6bb 100644 (file)
@@ -8,12 +8,50 @@
 from __future__ import annotations
 
 from typing import Any
-
+from typing import Callable
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
+
+from .session import async_sessionmaker
 from .session import AsyncSession
+from ... import exc as sa_exc
 from ... import util
-from ...orm.scoping import ScopedSessionMixin
+from ...orm.session import Session
 from ...util import create_proxy_methods
 from ...util import ScopedRegistry
+from ...util import warn
+from ...util import warn_deprecated
+
+if TYPE_CHECKING:
+    from .engine import AsyncConnection
+    from .result import AsyncResult
+    from .result import AsyncScalarResult
+    from .session import AsyncSessionTransaction
+    from ...engine import Connection
+    from ...engine import Engine
+    from ...engine import Result
+    from ...engine import Row
+    from ...engine.interfaces import _CoreAnyExecuteParams
+    from ...engine.interfaces import _CoreSingleExecuteParams
+    from ...engine.interfaces import _ExecuteOptions
+    from ...engine.interfaces import _ExecuteOptionsParameter
+    from ...engine.result import ScalarResult
+    from ...orm._typing import _IdentityKeyType
+    from ...orm._typing import _O
+    from ...orm.interfaces import ORMOption
+    from ...orm.session import _BindArguments
+    from ...orm.session import _EntityBindKey
+    from ...orm.session import _PKIdentityArgument
+    from ...orm.session import _SessionBind
+    from ...sql.base import Executable
+    from ...sql.elements import ClauseElement
+    from ...sql.selectable import ForUpdateArg
 
 
 @create_proxy_methods(
@@ -62,7 +100,7 @@ from ...util import ScopedRegistry
         "info",
     ],
 )
-class async_scoped_session(ScopedSessionMixin):
+class async_scoped_session:
     """Provides scoped management of :class:`.AsyncSession` objects.
 
     See the section :ref:`asyncio_scoped_session` for usage details.
@@ -74,17 +112,23 @@ class async_scoped_session(ScopedSessionMixin):
 
     _support_async = True
 
-    def __init__(self, session_factory, scopefunc):
+    session_factory: async_sessionmaker
+    """The `session_factory` provided to `__init__` is stored in this
+    attribute and may be accessed at a later time.  This can be useful when
+    a new non-scoped :class:`.AsyncSession` is needed."""
+
+    registry: ScopedRegistry[AsyncSession]
+
+    def __init__(
+        self,
+        session_factory: async_sessionmaker,
+        scopefunc: Callable[[], Any],
+    ):
         """Construct a new :class:`_asyncio.async_scoped_session`.
 
         :param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
          instances. This is usually, but not necessarily, an instance
-         of :class:`_orm.sessionmaker` which itself was passed the
-         :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
-         parameter::
-
-            async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession)
-            AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)
+         of :class:`_asyncio.async_sessionmaker`.
 
         :param scopefunc: function which defines
          the current scope.   A function such as ``asyncio.current_task``
@@ -96,10 +140,59 @@ class async_scoped_session(ScopedSessionMixin):
         self.registry = ScopedRegistry(session_factory, scopefunc)
 
     @property
-    def _proxied(self):
+    def _proxied(self) -> AsyncSession:
         return self.registry()
 
-    async def remove(self):
+    def __call__(self, **kw: Any) -> AsyncSession:
+        r"""Return the current :class:`.AsyncSession`, creating it
+        using the :attr:`.async_scoped_session.session_factory` if not present.
+
+        :param \**kw: Keyword arguments will be passed to the
+         :attr:`.async_scoped_session.session_factory` callable, if an existing
+         :class:`.AsyncSession` is not present.  If the
+         :class:`.AsyncSession` is present
+         and keyword arguments have been passed,
+         :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+        """
+        if kw:
+            if self.registry.has():
+                raise sa_exc.InvalidRequestError(
+                    "Scoped session is already present; "
+                    "no new arguments may be specified."
+                )
+            else:
+                sess = self.session_factory(**kw)
+                self.registry.set(sess)
+        else:
+            sess = self.registry()
+        if not self._support_async and sess._is_asyncio:
+            warn_deprecated(
+                "Using `scoped_session` with asyncio is deprecated and "
+                "will raise an error in a future version. "
+                "Please use `async_scoped_session` instead.",
+                "1.4.23",
+            )
+        return sess
+
+    def configure(self, **kwargs: Any) -> None:
+        """reconfigure the :class:`.sessionmaker` used by this
+        :class:`.scoped_session`.
+
+        See :meth:`.sessionmaker.configure`.
+
+        """
+
+        if self.registry.has():
+            warn(
+                "At least one scoped session is already present. "
+                " configure() can not affect sessions that have "
+                "already been created."
+            )
+
+        self.session_factory.configure(**kwargs)
+
+    async def remove(self) -> None:
         """Dispose of the current :class:`.AsyncSession`, if present.
 
         Different from scoped_session's remove method, this method would use
@@ -152,7 +245,9 @@ class async_scoped_session(ScopedSessionMixin):
             Proxied for the :class:`_orm.Session` class on
             behalf of the :class:`_asyncio.AsyncSession` class.
 
-        """
+
+
+        """  # noqa: E501
 
         return self._proxied.__iter__()
 
@@ -199,7 +294,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.add_all(instances)
 
-    def begin(self):
+    def begin(self) -> AsyncSessionTransaction:
         r"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
 
         .. container:: class_bases
@@ -228,7 +323,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.begin()
 
-    def begin_nested(self):
+    def begin_nested(self) -> AsyncSessionTransaction:
         r"""Return an :class:`_asyncio.AsyncSessionTransaction` object
         which will begin a "nested" transaction, e.g. SAVEPOINT.
 
@@ -247,7 +342,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.begin_nested()
 
-    async def close(self):
+    async def close(self) -> None:
         r"""Close out the transactional resources and ORM objects used by this
         :class:`_asyncio.AsyncSession`.
 
@@ -284,7 +379,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return await self._proxied.close()
 
-    async def commit(self):
+    async def commit(self) -> None:
         r"""Commit the current transaction in progress.
 
         .. container:: class_bases
@@ -296,7 +391,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return await self._proxied.commit()
 
-    async def connection(self, **kw):
+    async def connection(self, **kw: Any) -> AsyncConnection:
         r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
         this :class:`.Session` object's transactional state.
 
@@ -321,7 +416,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return await self._proxied.connection(**kw)
 
-    async def delete(self, instance):
+    async def delete(self, instance: object) -> None:
         r"""Mark an instance as deleted.
 
         .. container:: class_bases
@@ -345,12 +440,12 @@ class async_scoped_session(ScopedSessionMixin):
 
     async def execute(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Result:
         r"""Execute a statement and return a buffered
         :class:`_engine.Result` object.
 
@@ -519,7 +614,7 @@ class async_scoped_session(ScopedSessionMixin):
 
         return self._proxied.expunge_all()
 
-    async def flush(self, objects=None):
+    async def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
         r"""Flush all the object changes to the database.
 
         .. container:: class_bases
@@ -538,13 +633,15 @@ class async_scoped_session(ScopedSessionMixin):
 
     async def get(
         self,
-        entity,
-        ident,
-        options=None,
-        populate_existing=False,
-        with_for_update=None,
-        identity_token=None,
-    ):
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: Optional[ForUpdateArg] = None,
+        identity_token: Optional[Any] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+    ) -> Optional[_O]:
         r"""Return an instance based on the given primary key identifier,
         or ``None`` if not found.
 
@@ -568,9 +665,16 @@ class async_scoped_session(ScopedSessionMixin):
             populate_existing=populate_existing,
             with_for_update=with_for_update,
             identity_token=identity_token,
+            execution_options=execution_options,
         )
 
-    def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+    def get_bind(
+        self,
+        mapper: Optional[_EntityBindKey[_O]] = None,
+        clause: Optional[ClauseElement] = None,
+        bind: Optional[_SessionBind] = None,
+        **kw: Any,
+    ) -> Union[Engine, Connection]:
         r"""Return a "bind" to which the synchronous proxied :class:`_orm.Session`
         is bound.
 
@@ -724,7 +828,7 @@ class async_scoped_session(ScopedSessionMixin):
             instance, include_collections=include_collections
         )
 
-    async def invalidate(self):
+    async def invalidate(self) -> None:
         r"""Close this Session, using connection invalidation.
 
         .. container:: class_bases
@@ -738,7 +842,13 @@ class async_scoped_session(ScopedSessionMixin):
 
         return await self._proxied.invalidate()
 
-    async def merge(self, instance, load=True, options=None):
+    async def merge(
+        self,
+        instance: _O,
+        *,
+        load: bool = True,
+        options: Optional[Sequence[ORMOption]] = None,
+    ) -> _O:
         r"""Copy the state of a given instance into a corresponding instance
         within this :class:`_asyncio.AsyncSession`.
 
@@ -757,8 +867,11 @@ class async_scoped_session(ScopedSessionMixin):
         return await self._proxied.merge(instance, load=load, options=options)
 
     async def refresh(
-        self, instance, attribute_names=None, with_for_update=None
-    ):
+        self,
+        instance: object,
+        attribute_names: Optional[Iterable[str]] = None,
+        with_for_update: Optional[ForUpdateArg] = None,
+    ) -> None:
         r"""Expire and refresh the attributes on the given instance.
 
         .. container:: class_bases
@@ -785,7 +898,7 @@ class async_scoped_session(ScopedSessionMixin):
             with_for_update=with_for_update,
         )
 
-    async def rollback(self):
+    async def rollback(self) -> None:
         r"""Rollback the current transaction in progress.
 
         .. container:: class_bases
@@ -799,12 +912,12 @@ class async_scoped_session(ScopedSessionMixin):
 
     async def scalar(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Any:
         r"""Execute a statement and return a scalar result.
 
         .. container:: class_bases
@@ -829,12 +942,12 @@ class async_scoped_session(ScopedSessionMixin):
 
     async def scalars(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[Any]:
         r"""Execute a statement and return scalar results.
 
         .. container:: class_bases
@@ -865,12 +978,12 @@ class async_scoped_session(ScopedSessionMixin):
 
     async def stream(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncResult:
         r"""Execute a statement and return a streaming
         :class:`_asyncio.AsyncResult` object.
 
@@ -892,12 +1005,12 @@ class async_scoped_session(ScopedSessionMixin):
 
     async def stream_scalars(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncScalarResult[Any]:
         r"""Execute a statement and return a stream of scalar results.
 
         .. container:: class_bases
@@ -1159,7 +1272,7 @@ class async_scoped_session(ScopedSessionMixin):
         return self._proxied.info
 
     @classmethod
-    async def close_all(self):
+    async def close_all(self) -> None:
         r"""Close all :class:`_asyncio.AsyncSession` sessions.
 
         .. container:: class_bases
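
The scoping.py hunks inline the former ``ScopedSessionMixin`` behavior (``__call__``,
``configure``) into ``async_scoped_session`` and document ``async_sessionmaker`` as the
expected factory. A hedged usage sketch of that documented API; the in-memory SQLite URL
and the aiosqlite driver are placeholder choices, not part of the patch::

    import asyncio
    from asyncio import current_task

    from sqlalchemy.ext.asyncio import (
        async_scoped_session,
        async_sessionmaker,
        create_async_engine,
    )


    async def main() -> None:
        # placeholder in-memory database; any async driver URL would do
        engine = create_async_engine("sqlite+aiosqlite://")

        factory = async_sessionmaker(engine, expire_on_commit=False)
        Session = async_scoped_session(factory, scopefunc=current_task)

        # each call within the same task returns the same AsyncSession
        assert Session() is Session()

        # dispose of the task-local session when the scope ends
        await Session.remove()
        await engine.dispose()


    asyncio.run(main())
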
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 769fe05bdb8d76b69d96c24d73a2f15b45ee337a..7d63b084c200cada046cce89dac8d8f85e4d1ba4 100644 (file)
@@ -7,17 +7,65 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import NoReturn
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
 
 from . import engine
-from . import result as _result
 from .base import ReversibleProxy
 from .base import StartableContext
 from .result import _ensure_sync_result
+from .result import AsyncResult
+from .result import AsyncScalarResult
 from ... import util
 from ...orm import object_session
 from ...orm import Session
+from ...orm import SessionTransaction
 from ...orm import state as _instance_state
 from ...util.concurrency import greenlet_spawn
+from ...util.typing import Protocol
+
+if TYPE_CHECKING:
+    from .engine import AsyncConnection
+    from .engine import AsyncEngine
+    from ...engine import Connection
+    from ...engine import Engine
+    from ...engine import Result
+    from ...engine import Row
+    from ...engine import ScalarResult
+    from ...engine import Transaction
+    from ...engine.interfaces import _CoreAnyExecuteParams
+    from ...engine.interfaces import _CoreSingleExecuteParams
+    from ...engine.interfaces import _ExecuteOptions
+    from ...engine.interfaces import _ExecuteOptionsParameter
+    from ...event import dispatcher
+    from ...orm._typing import _IdentityKeyType
+    from ...orm._typing import _O
+    from ...orm.identity import IdentityMap
+    from ...orm.interfaces import ORMOption
+    from ...orm.session import _BindArguments
+    from ...orm.session import _EntityBindKey
+    from ...orm.session import _PKIdentityArgument
+    from ...orm.session import _SessionBind
+    from ...orm.session import _SessionBindKey
+    from ...sql.base import Executable
+    from ...sql.elements import ClauseElement
+    from ...sql.selectable import ForUpdateArg
+
+_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
+
+
+class _SyncSessionCallable(Protocol):
+    def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any:
+        ...
+
 
 _EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
 _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
@@ -52,7 +100,7 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
         "info",
     ],
 )
-class AsyncSession(ReversibleProxy):
+class AsyncSession(ReversibleProxy[Session]):
     """Asyncio version of :class:`_orm.Session`.
 
     The :class:`_asyncio.AsyncSession` is a proxy for a traditional
@@ -69,9 +117,15 @@ class AsyncSession(ReversibleProxy):
 
     _is_asyncio = True
 
-    dispatch = None
+    dispatch: dispatcher[Session]
 
-    def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+    def __init__(
+        self,
+        bind: Optional[_AsyncSessionBind] = None,
+        binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None,
+        sync_session_class: Optional[Type[Session]] = None,
+        **kw: Any,
+    ):
         r"""Construct a new :class:`_asyncio.AsyncSession`.
 
         All parameters other than ``sync_session_class`` are passed to the
@@ -90,14 +144,15 @@ class AsyncSession(ReversibleProxy):
           .. versionadded:: 1.4.24
 
         """
-        kw["future"] = True
+        sync_bind = sync_binds = None
+
         if bind:
             self.bind = bind
-            bind = engine._get_sync_engine_or_connection(bind)
+            sync_bind = engine._get_sync_engine_or_connection(bind)
 
         if binds:
             self.binds = binds
-            binds = {
+            sync_binds = {
                 key: engine._get_sync_engine_or_connection(b)
                 for key, b in binds.items()
             }
@@ -106,10 +161,10 @@ class AsyncSession(ReversibleProxy):
             self.sync_session_class = sync_session_class
 
         self.sync_session = self._proxied = self._assign_proxied(
-            self.sync_session_class(bind=bind, binds=binds, **kw)
+            self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw)
         )
 
-    sync_session_class = Session
+    sync_session_class: Type[Session] = Session
     """The class or callable that provides the
     underlying :class:`_orm.Session` instance for a particular
     :class:`_asyncio.AsyncSession`.
@@ -138,9 +193,19 @@ class AsyncSession(ReversibleProxy):
 
     """
 
+    @classmethod
+    def _no_async_engine_events(cls) -> NoReturn:
+        raise NotImplementedError(
+            "asynchronous events are not implemented at this time.  Apply "
+            "synchronous listeners to the AsyncSession.sync_session."
+        )
+
     async def refresh(
-        self, instance, attribute_names=None, with_for_update=None
-    ):
+        self,
+        instance: object,
+        attribute_names: Optional[Iterable[str]] = None,
+        with_for_update: Optional[ForUpdateArg] = None,
+    ) -> None:
         """Expire and refresh the attributes on the given instance.
 
         A query will be issued to the database and all attributes will be
@@ -155,14 +220,16 @@ class AsyncSession(ReversibleProxy):
 
         """
 
-        return await greenlet_spawn(
+        await greenlet_spawn(
             self.sync_session.refresh,
             instance,
             attribute_names=attribute_names,
             with_for_update=with_for_update,
         )
 
-    async def run_sync(self, fn, *arg, **kw):
+    async def run_sync(
+        self, fn: _SyncSessionCallable, *arg: Any, **kw: Any
+    ) -> Any:
         """Invoke the given sync callable passing sync self as the first
         argument.
 
@@ -191,12 +258,12 @@ class AsyncSession(ReversibleProxy):
 
     async def execute(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Result:
         """Execute a statement and return a buffered
         :class:`_engine.Result` object.
 
@@ -225,12 +292,12 @@ class AsyncSession(ReversibleProxy):
 
     async def scalar(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> Any:
         """Execute a statement and return a scalar result.
 
         .. seealso::
@@ -250,12 +317,12 @@ class AsyncSession(ReversibleProxy):
 
     async def scalars(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> ScalarResult[Any]:
         """Execute a statement and return scalar results.
 
         :return: a :class:`_result.ScalarResult` object
@@ -281,13 +348,16 @@ class AsyncSession(ReversibleProxy):
 
     async def get(
         self,
-        entity,
-        ident,
-        options=None,
-        populate_existing=False,
-        with_for_update=None,
-        identity_token=None,
-    ):
+        entity: _EntityBindKey[_O],
+        ident: _PKIdentityArgument,
+        *,
+        options: Optional[Sequence[ORMOption]] = None,
+        populate_existing: bool = False,
+        with_for_update: Optional[ForUpdateArg] = None,
+        identity_token: Optional[Any] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+    ) -> Optional[_O]:
+
         """Return an instance based on the given primary key identifier,
         or ``None`` if not found.
 
@@ -297,7 +367,8 @@ class AsyncSession(ReversibleProxy):
 
 
         """
-        return await greenlet_spawn(
+
+        result_obj = await greenlet_spawn(
             self.sync_session.get,
             entity,
             ident,
@@ -306,15 +377,17 @@ class AsyncSession(ReversibleProxy):
             with_for_update=with_for_update,
             identity_token=identity_token,
         )
+        return result_obj
 
     async def stream(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreAnyExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncResult:
+
         """Execute a statement and return a streaming
         :class:`_asyncio.AsyncResult` object.
 
@@ -335,16 +408,16 @@ class AsyncSession(ReversibleProxy):
             bind_arguments=bind_arguments,
             **kw,
         )
-        return _result.AsyncResult(result)
+        return AsyncResult(result)
 
     async def stream_scalars(
         self,
-        statement,
-        params=None,
-        execution_options=util.EMPTY_DICT,
-        bind_arguments=None,
-        **kw,
-    ):
+        statement: Executable,
+        params: Optional[_CoreSingleExecuteParams] = None,
+        execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
+        **kw: Any,
+    ) -> AsyncScalarResult[Any]:
         """Execute a statement and return a stream of scalar results.
 
         :return: an :class:`_asyncio.AsyncScalarResult` object
@@ -368,7 +441,7 @@ class AsyncSession(ReversibleProxy):
         )
         return result.scalars()
 
-    async def delete(self, instance):
+    async def delete(self, instance: object) -> None:
         """Mark an instance as deleted.
 
         The database delete operation occurs upon ``flush()``.
@@ -381,9 +454,15 @@ class AsyncSession(ReversibleProxy):
             :meth:`_orm.Session.delete` - main documentation for delete
 
         """
-        return await greenlet_spawn(self.sync_session.delete, instance)
+        await greenlet_spawn(self.sync_session.delete, instance)
 
-    async def merge(self, instance, load=True, options=None):
+    async def merge(
+        self,
+        instance: _O,
+        *,
+        load: bool = True,
+        options: Optional[Sequence[ORMOption]] = None,
+    ) -> _O:
         """Copy the state of a given instance into a corresponding instance
         within this :class:`_asyncio.AsyncSession`.
 
@@ -396,7 +475,7 @@ class AsyncSession(ReversibleProxy):
             self.sync_session.merge, instance, load=load, options=options
         )
 
-    async def flush(self, objects=None):
+    async def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
         """Flush all the object changes to the database.
 
         .. seealso::
@@ -406,7 +485,7 @@ class AsyncSession(ReversibleProxy):
         """
         await greenlet_spawn(self.sync_session.flush, objects=objects)
 
-    def get_transaction(self):
+    def get_transaction(self) -> Optional[AsyncSessionTransaction]:
         """Return the current root transaction in progress, if any.
 
         :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
@@ -421,7 +500,7 @@ class AsyncSession(ReversibleProxy):
         else:
             return None
 
-    def get_nested_transaction(self):
+    def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]:
         """Return the current nested transaction in progress, if any.
 
         :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
@@ -437,7 +516,13 @@ class AsyncSession(ReversibleProxy):
         else:
             return None
 
-    def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+    def get_bind(
+        self,
+        mapper: Optional[_EntityBindKey[_O]] = None,
+        clause: Optional[ClauseElement] = None,
+        bind: Optional[_SessionBind] = None,
+        **kw: Any,
+    ) -> Union[Engine, Connection]:
         """Return a "bind" to which the synchronous proxied :class:`_orm.Session`
         is bound.
 
@@ -515,7 +600,7 @@ class AsyncSession(ReversibleProxy):
             mapper=mapper, clause=clause, bind=bind, **kw
         )
 
-    async def connection(self, **kw):
+    async def connection(self, **kw: Any) -> AsyncConnection:
         r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
         this :class:`.Session` object's transactional state.
 
@@ -539,7 +624,7 @@ class AsyncSession(ReversibleProxy):
             sync_connection
         )
 
-    def begin(self):
+    def begin(self) -> AsyncSessionTransaction:
         """Return an :class:`_asyncio.AsyncSessionTransaction` object.
 
         The underlying :class:`_orm.Session` will perform the
@@ -562,7 +647,7 @@ class AsyncSession(ReversibleProxy):
 
         return AsyncSessionTransaction(self)
 
-    def begin_nested(self):
+    def begin_nested(self) -> AsyncSessionTransaction:
         """Return an :class:`_asyncio.AsyncSessionTransaction` object
         which will begin a "nested" transaction, e.g. SAVEPOINT.
 
@@ -575,15 +660,15 @@ class AsyncSession(ReversibleProxy):
 
         return AsyncSessionTransaction(self, nested=True)
 
-    async def rollback(self):
+    async def rollback(self) -> None:
         """Rollback the current transaction in progress."""
-        return await greenlet_spawn(self.sync_session.rollback)
+        await greenlet_spawn(self.sync_session.rollback)
 
-    async def commit(self):
+    async def commit(self) -> None:
         """Commit the current transaction in progress."""
-        return await greenlet_spawn(self.sync_session.commit)
+        await greenlet_spawn(self.sync_session.commit)
 
-    async def close(self):
+    async def close(self) -> None:
         """Close out the transactional resources and ORM objects used by this
         :class:`_asyncio.AsyncSession`.
 
@@ -613,25 +698,25 @@ class AsyncSession(ReversibleProxy):
         """
         return await greenlet_spawn(self.sync_session.close)
 
-    async def invalidate(self):
+    async def invalidate(self) -> None:
         """Close this Session, using connection invalidation.
 
         For a complete description, see :meth:`_orm.Session.invalidate`.
         """
-        return await greenlet_spawn(self.sync_session.invalidate)
+        await greenlet_spawn(self.sync_session.invalidate)
 
     @classmethod
-    async def close_all(self):
+    async def close_all(self) -> None:
         """Close all :class:`_asyncio.AsyncSession` sessions."""
-        return await greenlet_spawn(self.sync_session.close_all)
+        await greenlet_spawn(self.sync_session.close_all)
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> AsyncSession:
         return self
 
-    async def __aexit__(self, type_, value, traceback):
+    async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
         await self.close()
 
-    def _maker_context_manager(self):
+    def _maker_context_manager(self) -> _AsyncSessionContextManager:
         # TODO: can this use asynccontextmanager ??
         return _AsyncSessionContextManager(self)
 
@@ -1142,21 +1227,159 @@ class AsyncSession(ReversibleProxy):
     # END PROXY METHODS AsyncSession
 
 
+class async_sessionmaker:
+    """A configurable :class:`.AsyncSession` factory.
+
+    The :class:`.async_sessionmaker` factory works in the same way as the
+    :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession`
+    objects when called, creating them given
+    the configurational arguments established here.
+
+    e.g.::
+
+        from sqlalchemy.ext.asyncio import create_async_engine
+        from sqlalchemy.ext.asyncio import async_sessionmaker
+
+        async def main():
+            # an AsyncEngine, which the AsyncSession will use for connection
+            # resources
+            engine = create_async_engine('postgresql+asyncpg://scott:tiger@localhost/')
+
+            async_session = async_sessionmaker(engine)
+
+            async with async_session() as session:
+                session.add(some_object)
+                session.add(some_other_object)
+                await session.commit()
+
+    .. versionadded:: 2.0  :class:`.async_sessionmaker` provides a
+       :class:`.sessionmaker` class that's dedicated to the
+       :class:`.AsyncSession` object, including pep-484 typing support.
+
+    .. seealso::
+
+        :ref:`asyncio_orm` - shows example use
+
+        :class:`.sessionmaker` - general overview of the
+        :class:`.sessionmaker` architecture
+
+
+        :ref:`session_getting` - introductory text on creating
+        sessions using :class:`.sessionmaker`.
+
+    """  # noqa E501
+
+    class_: Type[AsyncSession]
+
+    def __init__(
+        self,
+        bind: Optional[_AsyncSessionBind] = None,
+        class_: Type[AsyncSession] = AsyncSession,
+        autoflush: bool = True,
+        expire_on_commit: bool = True,
+        info: Optional[Dict[Any, Any]] = None,
+        **kw: Any,
+    ):
+        r"""Construct a new :class:`.async_sessionmaker`.
+
+        All arguments here except for ``class_`` correspond to arguments
+        accepted by :class:`.Session` directly. See the
+        :meth:`.AsyncSession.__init__` docstring for more details on
+        parameters.
+
+
+        """
+        kw["bind"] = bind
+        kw["autoflush"] = autoflush
+        kw["expire_on_commit"] = expire_on_commit
+        if info is not None:
+            kw["info"] = info
+        self.kw = kw
+        self.class_ = class_
+
+    def begin(self) -> _AsyncSessionContextManager:
+        """Produce a context manager that both provides a new
+        :class:`_orm.AsyncSession` as well as a transaction that commits.
+
+
+        e.g.::
+
+            async def main():
+                Session = async_sessionmaker(some_engine)
+
+                async with Session.begin() as session:
+                    session.add(some_object)
+
+                # commits transaction, closes session
+
+
+        """
+
+        session = self()
+        return session._maker_context_manager()
+
+    def __call__(self, **local_kw: Any) -> AsyncSession:
+        """Produce a new :class:`.AsyncSession` object using the configuration
+        established in this :class:`.async_sessionmaker`.
+
+        In Python, the ``__call__`` method is invoked on an object when
+        it is "called" in the same way as a function::
+
+            AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False)
+            session = AsyncSession()  # invokes async_sessionmaker.__call__()
+
+        """  # noqa E501
+        for k, v in self.kw.items():
+            if k == "info" and "info" in local_kw:
+                d = v.copy()
+                d.update(local_kw["info"])
+                local_kw["info"] = d
+            else:
+                local_kw.setdefault(k, v)
+        return self.class_(**local_kw)
+
+    def configure(self, **new_kw: Any) -> None:
+        """(Re)configure the arguments for this async_sessionmaker.
+
+        e.g.::
+
+            AsyncSession = async_sessionmaker(some_engine)
+
+            AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://'))
+        """  # noqa E501
+
+        self.kw.update(new_kw)
+
+    def __repr__(self) -> str:
+        return "%s(class_=%r, %s)" % (
+            self.__class__.__name__,
+            self.class_.__name__,
+            ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()),
+        )
+
+
 class _AsyncSessionContextManager:
-    def __init__(self, async_session):
+    __slots__ = ("async_session", "trans")
+
+    async_session: AsyncSession
+    trans: AsyncSessionTransaction
+
+    def __init__(self, async_session: AsyncSession):
         self.async_session = async_session
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> AsyncSession:
         self.trans = self.async_session.begin()
         await self.trans.__aenter__()
         return self.async_session
 
-    async def __aexit__(self, type_, value, traceback):
+    async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
         await self.trans.__aexit__(type_, value, traceback)
         await self.async_session.__aexit__(type_, value, traceback)
 
 
-class AsyncSessionTransaction(ReversibleProxy, StartableContext):
+class AsyncSessionTransaction(
+    ReversibleProxy[SessionTransaction], StartableContext
+):
     """A wrapper for the ORM :class:`_orm.SessionTransaction` object.
 
     This object is provided so that a transaction-holding object
@@ -1174,36 +1397,41 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext):
 
     __slots__ = ("session", "sync_transaction", "nested")
 
-    def __init__(self, session, nested=False):
+    session: AsyncSession
+    sync_transaction: Optional[SessionTransaction]
+
+    def __init__(self, session: AsyncSession, nested: bool = False):
         self.session = session
         self.nested = nested
         self.sync_transaction = None
 
     @property
-    def is_active(self):
+    def is_active(self) -> bool:
         return (
             self._sync_transaction() is not None
             and self._sync_transaction().is_active
         )
 
-    def _sync_transaction(self):
+    def _sync_transaction(self) -> SessionTransaction:
         if not self.sync_transaction:
             self._raise_for_not_started()
         return self.sync_transaction
 
-    async def rollback(self):
+    async def rollback(self) -> None:
         """Roll back this :class:`_asyncio.AsyncTransaction`."""
         await greenlet_spawn(self._sync_transaction().rollback)
 
-    async def commit(self):
+    async def commit(self) -> None:
         """Commit this :class:`_asyncio.AsyncTransaction`."""
 
         await greenlet_spawn(self._sync_transaction().commit)
 
-    async def start(self, is_ctxmanager=False):
+    async def start(
+        self, is_ctxmanager: bool = False
+    ) -> AsyncSessionTransaction:
         self.sync_transaction = self._assign_proxied(
             await greenlet_spawn(
-                self.session.sync_session.begin_nested
+                self.session.sync_session.begin_nested  # type: ignore
                 if self.nested
                 else self.session.sync_session.begin
             )
@@ -1212,13 +1440,13 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext):
             self.sync_transaction.__enter__()
         return self
 
-    async def __aexit__(self, type_, value, traceback):
+    async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
         await greenlet_spawn(
             self._sync_transaction().__exit__, type_, value, traceback
         )
 
 
-def async_object_session(instance):
+def async_object_session(instance: object) -> Optional[AsyncSession]:
     """Return the :class:`_asyncio.AsyncSession` to which the given instance
     belongs.
 
@@ -1247,7 +1475,7 @@ def async_object_session(instance):
         return None
 
 
-def async_session(session: Session) -> AsyncSession:
+def async_session(session: Session) -> Optional[AsyncSession]:
     """Return the :class:`_asyncio.AsyncSession` which is proxying the given
     :class:`_orm.Session` object, if any.
 
@@ -1260,4 +1488,4 @@ def async_session(session: Session) -> AsyncSession:
     return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
 
 
-_instance_state._async_provider = async_session
+_instance_state._async_provider = async_session  # type: ignore
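
Taken together, the session.py hunks give ``async_sessionmaker`` the same
call/``begin()``/``configure()`` surface as ``sessionmaker``, with ``begin()`` handing
back ``_AsyncSessionContextManager`` so the transaction commits on a clean exit. A hedged
sketch of both entry points; the database URL is a placeholder and assumes the aiosqlite
driver::

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine


    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite://")  # placeholder URL
        Session = async_sessionmaker(engine, expire_on_commit=False)

        # plain factory call: the ``async with`` block only closes the session,
        # so commit() stays explicit
        async with Session() as session:
            await session.execute(text("select 1"))
            await session.commit()

        # Session.begin() returns _AsyncSessionContextManager: the transaction
        # commits on a clean exit and rolls back if the block raises
        async with Session.begin() as session:
            await session.execute(text("select 1"))

        await engine.dispose()


    asyncio.run(main())
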
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index d8f57e14981faa9d55a5c9adefe13384f0bb9830..c5348c2373a245a2f2621f75150c635200022e68 100644 (file)
@@ -448,7 +448,13 @@ def _entity_descriptor(entity, key):
         ) from err
 
 
-_state_mapper = util.dottedgetter("manager.mapper")
+if TYPE_CHECKING:
+
+    def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]:
+        ...
+
+else:
+    _state_mapper = util.dottedgetter("manager.mapper")
 
 
 @inspection._inspects(type)
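
The ``_state_mapper`` hunk uses a small trick for typing helpers built with
``util.dottedgetter``/``attrgetter``: declare a precisely typed stub under
``TYPE_CHECKING`` and keep the dynamic implementation for runtime. A generic sketch with
illustrative names, not the SQLAlchemy ones::

    from operator import attrgetter
    from typing import TYPE_CHECKING, Any

    if TYPE_CHECKING:
        # what the type checker sees: an explicit, precisely typed signature
        def get_manager_name(obj: Any) -> str:
            ...

    else:
        # what actually runs: a dotted attrgetter whose signature mypy can't infer
        get_manager_name = attrgetter("manager.name")


    class _Manager:
        name = "User"


    class _State:
        manager = _Manager()


    print(get_manager_name(_State()))  # "User"
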
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index e62a833975fda877c77a26a2a9d518cbd5ff0a83..c531e7cf19a6081ada3a17e1aa2f973a3fdd364e 100644 (file)
@@ -10,6 +10,7 @@
 """
 from __future__ import annotations
 
+from typing import Any
 import weakref
 
 from . import instrumentation
@@ -1324,7 +1325,7 @@ class _MapperEventsHold(_EventsHold):
 _sessionevents_lifecycle_event_names = set()
 
 
-class SessionEvents(event.Events):
+class SessionEvents(event.Events[Session]):
     """Define events specific to :class:`.Session` lifecycle.
 
     e.g.::
@@ -1396,12 +1397,21 @@ class SessionEvents(event.Events):
                 return target
         elif isinstance(target, Session):
             return target
+        elif hasattr(target, "_no_async_engine_events"):
+            target._no_async_engine_events()
         else:
             # allows alternate SessionEvents-like-classes to be consulted
             return event.Events._accept_with(target)
 
     @classmethod
-    def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
+    def _listen(
+        cls,
+        event_key: Any,
+        *,
+        raw: bool = False,
+        restore_load_context: bool = False,
+        **kw: Any,
+    ) -> None:
         is_instance_event = (
             event_key.identifier in _sessionevents_lifecycle_event_names
         )
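
With the ``_no_async_engine_events`` hook consulted in ``SessionEvents._accept_with``,
attaching a listener directly to ``AsyncSession`` now raises ``NotImplementedError``
pointing at the sync layer. A hedged sketch of the suggested workaround, here listening on
the sync ``Session`` class; the per-instance alternative named by the error message is
``some_async_session.sync_session``::

    from sqlalchemy import event
    from sqlalchemy.orm import Session


    # a class-level listener on the sync Session still fires for sessions that an
    # AsyncSession proxies, since AsyncSession wraps a plain Session underneath
    @event.listens_for(Session, "before_commit")
    def log_commit(session: Session) -> None:
        print("committing", session)
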
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 030d1595b2f5208b3c4d9ea6bdabb9f81d688e43..a5dc305d22472120d490d8cf6db538b6b1a3f591 100644 (file)
@@ -35,6 +35,7 @@ from __future__ import annotations
 from typing import Any
 from typing import Dict
 from typing import Generic
+from typing import Optional
 from typing import Set
 from typing import TYPE_CHECKING
 from typing import TypeVar
@@ -44,6 +45,7 @@ from . import collections
 from . import exc
 from . import interfaces
 from . import state
+from ._typing import _O
 from .. import util
 from ..event import EventTarget
 from ..util import HasMemoized
@@ -52,6 +54,7 @@ from ..util.typing import Protocol
 if TYPE_CHECKING:
     from .attributes import InstrumentedAttribute
     from .mapper import Mapper
+    from .state import InstanceState
     from ..event import dispatcher
 
 _T = TypeVar("_T", bound=Any)
@@ -71,7 +74,7 @@ class _ExpiredAttributeLoaderProto(Protocol):
 class ClassManager(
     HasMemoized,
     Dict[str, "InstrumentedAttribute[Any]"],
-    Generic[_T],
+    Generic[_O],
     EventTarget,
 ):
     """Tracks state information at the class level."""
@@ -230,7 +233,7 @@ class ClassManager(
         return frozenset([attr.impl for attr in self.values()])
 
     @util.memoized_property
-    def mapper(self) -> Mapper[_T]:
+    def mapper(self) -> Mapper[_O]:
         # raises unless self.mapper has been assigned
         raise exc.UnmappedClassError(self.class_)
 
@@ -442,7 +445,7 @@ class ClassManager(
 
     # InstanceState management
 
-    def new_instance(self, state=None):
+    def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O:
         instance = self.class_.__new__(self.class_)
         if state is None:
             state = self._state_constructor(instance, self)
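
The ``ClassManager`` hunks re-key the class on ``_O`` so that methods such as
``new_instance()`` and the ``mapper`` property are typed in terms of the managed class. A
toy sketch of that ``Generic`` pattern with hypothetical names::

    from typing import Generic, Type, TypeVar

    _O = TypeVar("_O", bound=object)


    class InstanceFactory(Generic[_O]):
        # hypothetical stand-in for ClassManager: parameterizing on the managed
        # class lets new_instance() be typed as returning that exact class
        def __init__(self, class_: Type[_O]):
            self.class_ = class_

        def new_instance(self) -> _O:
            return self.class_.__new__(self.class_)


    class User:
        pass


    u: User = InstanceFactory(User).new_instance()
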
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index c85861a594346665fe53bca25bf288aa23fc3e70..abe11cc68cdc9c94c74cd9ee3ce74759b41507b7 100644 (file)
@@ -132,6 +132,8 @@ class Mapper(
     _identity_class: Type[_O]
 
     always_refresh: bool
+    allow_partial_pks: bool
+    version_id_col: Optional[ColumnElement[Any]]
 
     @util.deprecated_params(
         non_primary=(
@@ -2931,7 +2933,7 @@ class Mapper(
         self,
         state: InstanceState[_O],
         dict_: _InstanceDict,
-        column: Column[Any],
+        column: ColumnElement[Any],
         passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE,
     ) -> Any:
         prop = self._columntoproperty[column]
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index e498b17b4d31c60471c528b639eeeccf736bad2f..1dd7a69523d91a9714052922ebc345cc2eb2f200 100644 (file)
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
     from .interfaces import ORMOption
     from .mapper import Mapper
     from .query import Query
+    from .session import _BindArguments
     from .session import _EntityBindKey
     from .session import _PKIdentityArgument
     from .session import _SessionBind
@@ -65,65 +66,7 @@ class _QueryDescriptorType(Protocol):
 
 _O = TypeVar("_O", bound=object)
 
-__all__ = ["scoped_session", "ScopedSessionMixin"]
-
-
-class ScopedSessionMixin:
-    session_factory: sessionmaker
-    _support_async: bool
-    registry: ScopedRegistry[Session]
-
-    @property
-    def _proxied(self) -> Session:
-        return self.registry()  # type: ignore
-
-    def __call__(self, **kw: Any) -> Session:
-        r"""Return the current :class:`.Session`, creating it
-        using the :attr:`.scoped_session.session_factory` if not present.
-
-        :param \**kw: Keyword arguments will be passed to the
-         :attr:`.scoped_session.session_factory` callable, if an existing
-         :class:`.Session` is not present.  If the :class:`.Session` is present
-         and keyword arguments have been passed,
-         :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
-
-        """
-        if kw:
-            if self.registry.has():
-                raise sa_exc.InvalidRequestError(
-                    "Scoped session is already present; "
-                    "no new arguments may be specified."
-                )
-            else:
-                sess = self.session_factory(**kw)
-                self.registry.set(sess)
-        else:
-            sess = self.registry()
-        if not self._support_async and sess._is_asyncio:
-            warn_deprecated(
-                "Using `scoped_session` with asyncio is deprecated and "
-                "will raise an error in a future version. "
-                "Please use `async_scoped_session` instead.",
-                "1.4.23",
-            )
-        return sess
-
-    def configure(self, **kwargs: Any) -> None:
-        """reconfigure the :class:`.sessionmaker` used by this
-        :class:`.scoped_session`.
-
-        See :meth:`.sessionmaker.configure`.
-
-        """
-
-        if self.registry.has():
-            warn(
-                "At least one scoped session is already present. "
-                " configure() can not affect sessions that have "
-                "already been created."
-            )
-
-        self.session_factory.configure(**kwargs)
+__all__ = ["scoped_session"]
 
 
 @create_proxy_methods(
@@ -173,7 +116,7 @@ class ScopedSessionMixin:
         "info",
     ],
 )
-class scoped_session(ScopedSessionMixin):
+class scoped_session:
     """Provides scoped management of :class:`.Session` objects.
 
     See :ref:`unitofwork_contextual` for a tutorial.
@@ -191,8 +134,9 @@ class scoped_session(ScopedSessionMixin):
     session_factory: sessionmaker
     """The `session_factory` provided to `__init__` is stored in this
     attribute and may be accessed at a later time.  This can be useful when
-    a new non-scoped :class:`.Session` or :class:`_engine.Connection` to the
-    database is needed."""
+    a new non-scoped :class:`.Session` is needed."""
+
+    registry: ScopedRegistry[Session]
 
     def __init__(
         self,
@@ -222,6 +166,58 @@ class scoped_session(ScopedSessionMixin):
         else:
             self.registry = ThreadLocalRegistry(session_factory)
 
+    @property
+    def _proxied(self) -> Session:
+        return self.registry()
+
+    def __call__(self, **kw: Any) -> Session:
+        r"""Return the current :class:`.Session`, creating it
+        using the :attr:`.scoped_session.session_factory` if not present.
+
+        :param \**kw: Keyword arguments will be passed to the
+         :attr:`.scoped_session.session_factory` callable, if an existing
+         :class:`.Session` is not present.  If the :class:`.Session` is present
+         and keyword arguments have been passed,
+         :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+        """
+        if kw:
+            if self.registry.has():
+                raise sa_exc.InvalidRequestError(
+                    "Scoped session is already present; "
+                    "no new arguments may be specified."
+                )
+            else:
+                sess = self.session_factory(**kw)
+                self.registry.set(sess)
+        else:
+            sess = self.registry()
+        if not self._support_async and sess._is_asyncio:
+            warn_deprecated(
+                "Using `scoped_session` with asyncio is deprecated and "
+                "will raise an error in a future version. "
+                "Please use `async_scoped_session` instead.",
+                "1.4.23",
+            )
+        return sess
+
+    def configure(self, **kwargs: Any) -> None:
+        """reconfigure the :class:`.sessionmaker` used by this
+        :class:`.scoped_session`.
+
+        See :meth:`.sessionmaker.configure`.
+
+        """
+
+        if self.registry.has():
+            warn(
+                "At least one scoped session is already present. "
+                " configure() can not affect sessions that have "
+                "already been created."
+            )
+
+        self.session_factory.configure(**kwargs)
+
     def remove(self) -> None:
         """Dispose of the current :class:`.Session`, if present.
 
@@ -494,9 +490,9 @@ class scoped_session(ScopedSessionMixin):
 
     def connection(
         self,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
         execution_options: Optional[_ExecuteOptions] = None,
-    ) -> "Connection":
+    ) -> Connection:
         r"""Return a :class:`_engine.Connection` object corresponding to this
         :class:`.Session` object's transactional state.
 
@@ -557,7 +553,7 @@ class scoped_session(ScopedSessionMixin):
         statement: Executable,
         params: Optional[_CoreAnyExecuteParams] = None,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
     ) -> Result:
@@ -1567,7 +1563,7 @@ class scoped_session(ScopedSessionMixin):
         statement: Executable,
         params: Optional[_CoreSingleExecuteParams] = None,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
     ) -> Any:
         r"""Execute a statement and return a scalar result.
@@ -1597,7 +1593,7 @@ class scoped_session(ScopedSessionMixin):
         statement: Executable,
         params: Optional[_CoreSingleExecuteParams] = None,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
     ) -> ScalarResult[Any]:
         r"""Execute a statement and return the results as scalars.
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 55ce73cf54ca190c58c75e2eece549fa9ba2fd43..a26c55a2481ba683c0553a3bca8201ce01ec2b14 100644 (file)
@@ -26,7 +26,6 @@ from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
-from typing import TypeVar
 from typing import Union
 import weakref
 
@@ -38,6 +37,7 @@ from . import loading
 from . import persistence
 from . import query
 from . import state as statelib
+from ._typing import _O
 from ._typing import is_composite_class
 from ._typing import is_user_defined_option
 from .base import _class_to_mapper
@@ -119,11 +119,12 @@ _sessions: weakref.WeakValueDictionary[
 """Weak-referencing dictionary of :class:`.Session` objects.
 """
 
-_O = TypeVar("_O", bound=object)
 statelib._sessions = _sessions
 
 _PKIdentityArgument = Union[Any, Tuple[Any, ...]]
 
+_BindArguments = Dict[str, Any]
+
 _EntityBindKey = Union[Type[_O], "Mapper[_O]"]
 _SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"]
 _SessionBind = Union["Engine", "Connection"]
@@ -251,7 +252,7 @@ class ORMExecuteState(util.MemoizedSlots):
     parameters: Optional[_CoreAnyExecuteParams]
     execution_options: _ExecuteOptions
     local_execution_options: _ExecuteOptions
-    bind_arguments: Dict[str, Any]
+    bind_arguments: _BindArguments
     _compile_state_cls: Optional[Type[ORMCompileState]]
     _starting_event_idx: int
     _events_todo: List[Any]
@@ -263,7 +264,7 @@ class ORMExecuteState(util.MemoizedSlots):
         statement: Executable,
         parameters: Optional[_CoreAnyExecuteParams],
         execution_options: _ExecuteOptions,
-        bind_arguments: Dict[str, Any],
+        bind_arguments: _BindArguments,
         compile_state_cls: Optional[Type[ORMCompileState]],
         events_todo: List[_InstanceLevelDispatch[Session]],
     ):
@@ -286,7 +287,7 @@ class ORMExecuteState(util.MemoizedSlots):
         statement: Optional[Executable] = None,
         params: Optional[_CoreAnyExecuteParams] = None,
         execution_options: Optional[_ExecuteOptionsParameter] = None,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
     ) -> Result:
         """Execute the statement represented by this
         :class:`.ORMExecuteState`, without re-invoking events that have
@@ -1626,9 +1627,9 @@ class Session(_SessionClassMethods, EventTarget):
 
     def connection(
         self,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
         execution_options: Optional[_ExecuteOptions] = None,
-    ) -> "Connection":
+    ) -> Connection:
         r"""Return a :class:`_engine.Connection` object corresponding to this
         :class:`.Session` object's transactional state.
 
@@ -1690,7 +1691,7 @@ class Session(_SessionClassMethods, EventTarget):
         statement: Executable,
         params: Optional[_CoreAnyExecuteParams] = None,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
         _parent_execute_state: Optional[Any] = None,
         _add_event: Optional[Any] = None,
     ) -> Result:
@@ -1833,7 +1834,7 @@ class Session(_SessionClassMethods, EventTarget):
         statement: Executable,
         params: Optional[_CoreSingleExecuteParams] = None,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
     ) -> Any:
         """Execute a statement and return a scalar result.
@@ -1857,7 +1858,7 @@ class Session(_SessionClassMethods, EventTarget):
         statement: Executable,
         params: Optional[_CoreSingleExecuteParams] = None,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
-        bind_arguments: Optional[Dict[str, Any]] = None,
+        bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
     ) -> ScalarResult[Any]:
         """Execute a statement and return the results as scalars.
@@ -3099,7 +3100,7 @@ class Session(_SessionClassMethods, EventTarget):
         _recursive: Dict[InstanceState[Any], object],
         _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
     ) -> _O:
-        mapper = _state_mapper(state)
+        mapper: Mapper[_O] = _state_mapper(state)
         if state in _recursive:
             return cast(_O, _recursive[state])
 
@@ -3249,6 +3250,7 @@ class Session(_SessionClassMethods, EventTarget):
 
         if new_instance:
             merged_state.manager.dispatch.load(merged_state, None)
+
         return merged
 
     def _validate_persistent(self, state: InstanceState[Any]) -> None:
@@ -4291,7 +4293,7 @@ class sessionmaker(_SessionClassMethods):
         In Python, the ``__call__`` method is invoked on an object when
         it is "called" in the same way as a function::
 
-            Session = sessionmaker()
+            Session = sessionmaker(some_engine)
             session = Session()  # invokes sessionmaker.__call__()
 
         """
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 7ccda956598f5645b6ba27d5d6883f56eceb24aa..58f141997e0a2a9e1c2f2f703db2cfb5ee4fc31c 100644 (file)
@@ -23,12 +23,12 @@ from typing import Optional
 from typing import Set
 from typing import Tuple
 from typing import TYPE_CHECKING
-from typing import TypeVar
 import weakref
 
 from . import base
 from . import exc as orm_exc
 from . import interfaces
+from ._typing import _O
 from ._typing import is_collection_impl
 from .base import ATTR_WAS_SET
 from .base import INIT_OK
@@ -62,8 +62,6 @@ if TYPE_CHECKING:
     from ..ext.asyncio.session import async_session as _async_provider
     from ..ext.asyncio.session import AsyncSession
 
-_T = TypeVar("_T", bound=Any)
-
 if TYPE_CHECKING:
     _sessions: weakref.WeakValueDictionary[int, Session]
 else:
@@ -83,7 +81,7 @@ class _InstanceDictProto(Protocol):
 
 
 @inspection._self_inspects
-class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
+class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
     """tracks state information at the instance level.
 
     The :class:`.InstanceState` is a key object used by the
@@ -119,15 +117,15 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
         "expired_attributes",
     )
 
-    manager: ClassManager[_T]
+    manager: ClassManager[_O]
     session_id: Optional[int] = None
-    key: Optional[_IdentityKeyType[_T]] = None
+    key: Optional[_IdentityKeyType[_O]] = None
     runid: Optional[int] = None
     load_options: Tuple[ORMOption, ...] = ()
     load_path: PathRegistry = PathRegistry.root
     insert_order: Optional[int] = None
     _strong_obj: Optional[object] = None
-    obj: weakref.ref[_T]
+    obj: weakref.ref[_O]
 
     committed_state: Dict[str, Any]
 
@@ -159,7 +157,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
        see also the ``unmodified`` collection which is intersected
        against this set when a refresh operation occurs."""
 
-    callables: Dict[str, Callable[[InstanceState[_T], PassiveFlag], Any]]
+    callables: Dict[str, Callable[[InstanceState[_O], PassiveFlag], Any]]
     """A namespace where a per-state loader callable can be associated.
 
     In SQLAlchemy 1.0, this is only used for lazy loaders / deferred
@@ -174,7 +172,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
     if not TYPE_CHECKING:
         callables = util.EMPTY_DICT
 
-    def __init__(self, obj: _T, manager: ClassManager[_T]):
+    def __init__(self, obj: _O, manager: ClassManager[_O]):
         self.class_ = obj.__class__
         self.manager = manager
         self.obj = weakref.ref(obj, self._cleanup)
@@ -381,7 +379,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
             return None
 
     @property
-    def object(self) -> Optional[_T]:
+    def object(self) -> Optional[_O]:
         """Return the mapped object represented by this
         :class:`.InstanceState`.
 
@@ -411,7 +409,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
             return self.key[1]
 
     @property
-    def identity_key(self) -> Optional[_IdentityKeyType[_T]]:
+    def identity_key(self) -> Optional[_IdentityKeyType[_O]]:
         """Return the identity key for the mapped object.
 
         This is the key used to locate the object within
@@ -435,7 +433,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
         return {}
 
     @util.memoized_property
-    def mapper(self) -> Mapper[_T]:
+    def mapper(self) -> Mapper[_O]:
         """Return the :class:`_orm.Mapper` used for this mapped object."""
         return self.manager.mapper
 
@@ -452,7 +450,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
     @classmethod
     def _detach_states(
         self,
-        states: Iterable[InstanceState[_T]],
+        states: Iterable[InstanceState[_O]],
         session: Session,
         to_transient: bool = False,
     ) -> None:
@@ -497,7 +495,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
         # used by the test suite, apparently
         self._detach()
 
-    def _cleanup(self, ref: weakref.ref[_T]) -> None:
+    def _cleanup(self, ref: weakref.ref[_O]) -> None:
         """Weakref callback cleanup.
 
         This callable cleans out the state when it is being garbage
@@ -657,14 +655,14 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
 
     @classmethod
     def _instance_level_callable_processor(
-        cls, manager: ClassManager[_T], fn: _LoaderCallable, key: Any
-    ) -> Callable[[InstanceState[_T], _InstanceDict, Row], None]:
+        cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any
+    ) -> Callable[[InstanceState[_O], _InstanceDict, Row], None]:
         impl = manager[key].impl
         if is_collection_impl(impl):
             fixed_impl = impl
 
             def _set_callable(
-                state: InstanceState[_T], dict_: _InstanceDict, row: Row
+                state: InstanceState[_O], dict_: _InstanceDict, row: Row
             ) -> None:
                 if "callables" not in state.__dict__:
                     state.callables = {}
@@ -676,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
         else:
 
             def _set_callable(
-                state: InstanceState[_T], dict_: _InstanceDict, row: Row
+                state: InstanceState[_O], dict_: _InstanceDict, row: Row
             ) -> None:
                 if "callables" not in state.__dict__:
                     state.callables = {}
@@ -768,7 +766,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
         self.manager.dispatch.expire(self, attribute_names)
 
     def _load_expired(
-        self, state: InstanceState[_T], passive: PassiveFlag
+        self, state: InstanceState[_O], passive: PassiveFlag
     ) -> LoaderCallableStatus:
         """__call__ allows the InstanceState to act as a deferred
         callable for loading expired attributes, which is also
diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py
index e961df1a3b69e83892fb11f1a7d9c3886922451d..1107c92b5cc8bb51c43b2a03621a355e5bd1c583 100644 (file)
@@ -73,10 +73,8 @@ class PoolEvents(event.Events[Pool]):
             return target.pool
         elif isinstance(target, Pool):
             return target
-        elif hasattr(target, "dispatch") and hasattr(
-            target.dispatch._events, "_no_async_engine_events"
-        ):
-            target.dispatch._events._no_async_engine_events()
+        elif hasattr(target, "_no_async_engine_events"):
+            target._no_async_engine_events()
         else:
             return None
 
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index ccd5e8c40e224735a128eae3fc82bf62d1c25d84..d5874334087b8e433051083805af47acb674eb49 100644 (file)
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
     from .selectable import _SelectIterable
     from .selectable import FromClause
     from ..engine import Connection
+    from ..engine import CursorResult
     from ..engine import Result
     from ..engine.base import _CompiledCacheType
     from ..engine.interfaces import _CoreMultiExecuteParams
@@ -983,7 +984,7 @@ class Executable(roles.StatementRole, Generative):
             distilled_params: _CoreMultiExecuteParams,
             execution_options: _ExecuteOptionsParameter,
             _force: bool = False,
-        ) -> Result:
+        ) -> CursorResult:
             ...
 
     @util.ro_non_memoized_property
diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py
index 28b062d3d9f99d2aa7674405182449075232140f..6ad099eefcb6d19b971663ebcfddd4282cebc162 100644 (file)
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
 
 import asyncio
 from contextvars import copy_context as _copy_context
@@ -19,6 +20,8 @@ from .langhelpers import memoized_property
 from .. import exc
 from ..util.typing import Protocol
 
+_T = TypeVar("_T", bound=Any)
+
 if typing.TYPE_CHECKING:
 
     class greenlet(Protocol):
@@ -52,8 +55,6 @@ if not typing.TYPE_CHECKING:
     except (ImportError, AttributeError):
         _copy_context = None  # noqa
 
-_T = TypeVar("_T", bound=Any)
-
 
 def is_exit_exception(e: BaseException) -> bool:
     # note asyncio.CancelledError is already BaseException
@@ -128,11 +129,11 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
 
 
 async def greenlet_spawn(
-    fn: Callable[..., Any],
+    fn: Callable[..., _T],
     *args: Any,
     _require_await: bool = False,
     **kwargs: Any,
-) -> Any:
+) -> _T:
     """Runs a sync function ``fn`` in a new greenlet.
 
     The sync function can then use :func:`await_` to wait for async
@@ -143,6 +144,7 @@ async def greenlet_spawn(
     :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
     """
 
+    result: _T
     context = _AsyncIoGreenlet(fn, getcurrent())
     # runs the function synchronously in gl greenlet. If the execution
     # is interrupted by await_, context is not dead and result is a
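
With ``greenlet_spawn()`` typed as ``Callable[..., _T]`` returning ``_T``,
the value produced by the sync callable keeps its type on the async side.
A minimal sketch of the utility pair (these are internal helpers and require
the greenlet dependency)::

    import asyncio

    from sqlalchemy.util.concurrency import await_only
    from sqlalchemy.util.concurrency import greenlet_spawn


    async def fetch_value() -> int:
        return 42


    def sync_worker() -> int:
        # await_only() may only be called inside a greenlet_spawn() context
        return await_only(fetch_value())


    async def main() -> None:
        # result is typed as int rather than Any under the new signature
        result: int = await greenlet_spawn(sync_worker)
        print(result)


    asyncio.run(main())
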
diff --git a/pyproject.toml b/pyproject.toml
index 8f7f50715a040926de88c977ef415a85dda431a5..46a4eb0d85287aaefee1655447c91497eff8cee8 100644 (file)
@@ -61,7 +61,6 @@ module = [
     "sqlalchemy.engine.reflection",  # interim, should be strict
 
     # TODO for strict:
-    "sqlalchemy.ext.asyncio.*",
     "sqlalchemy.ext.automap",
     "sqlalchemy.ext.compiler",
     "sqlalchemy.ext.declarative.*",
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index 0fe14dc9216929b7dec0f02e11f8a0a81b6c4d3f..462d0900f9308693629f6ddbe168849a834e4aa5 100644 (file)
@@ -480,6 +480,28 @@ class AsyncEngineTest(EngineFixture):
             eq_(async_engine.pool.checkedin(), 0)
         is_not(p1, async_engine.pool)
 
+    @testing.requires.queue_pool
+    @async_test
+    async def test_dispose_no_close(self, async_engine):
+        c1 = await async_engine.connect()
+        c2 = await async_engine.connect()
+
+        await c1.close()
+        await c2.close()
+
+        p1 = async_engine.pool
+
+        if isinstance(p1, AsyncAdaptedQueuePool):
+            eq_(async_engine.pool.checkedin(), 2)
+
+        await async_engine.dispose(close=False)
+
+        # TODO: test that DBAPI connection was not closed
+
+        if isinstance(p1, AsyncAdaptedQueuePool):
+            eq_(async_engine.pool.checkedin(), 0)
+        is_not(p1, async_engine.pool)
+
     @testing.requires.independent_connections
     @async_test
     async def test_init_once_concurrency(self, async_engine):
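
The new test exercises ``AsyncEngine.dispose(close=False)``, which detaches
the current pool without closing its checked-in connections and leaves the
engine usable against a replacement pool; a usage sketch (aiosqlite assumed
again)::

    import asyncio

    from sqlalchemy.ext.asyncio import create_async_engine


    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite://")

        async with engine.connect() as conn:
            pass

        # replace the pool; previously checked-in connections are left open
        await engine.dispose(close=False)

        # subsequent use proceeds against the new pool
        async with engine.connect() as conn:
            pass

        await engine.dispose()


    asyncio.run(main())
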
diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py
index f04b87f3718f1301fd3d6c794c0ef169c68560f5..ce38de51147af8bca6f524653576e174a89c0201 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import update
 from sqlalchemy.ext.asyncio import async_object_session
+from sqlalchemy.ext.asyncio import async_sessionmaker
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.asyncio import exc as async_exc
 from sqlalchemy.ext.asyncio.base import ReversibleProxy
@@ -202,7 +203,7 @@ class AsyncSessionTransactionTest(AsyncFixture):
         await fn(async_session, trans_on_subject=True, execute_on_subject=True)
 
     @async_test
-    async def test_sessionmaker_block_one(self, async_engine):
+    async def test_orm_sessionmaker_block_one(self, async_engine):
 
         User = self.classes.User
         maker = sessionmaker(async_engine, class_=AsyncSession)
@@ -226,7 +227,7 @@ class AsyncSessionTransactionTest(AsyncFixture):
             eq_(u1.name, "u1")
 
     @async_test
-    async def test_sessionmaker_block_two(self, async_engine):
+    async def test_orm_sessionmaker_block_two(self, async_engine):
 
         User = self.classes.User
         maker = sessionmaker(async_engine, class_=AsyncSession)
@@ -247,6 +248,52 @@ class AsyncSessionTransactionTest(AsyncFixture):
 
             eq_(u1.name, "u1")
 
+    @async_test
+    async def test_async_sessionmaker_block_one(self, async_engine):
+
+        User = self.classes.User
+        maker = async_sessionmaker(async_engine)
+
+        session = maker()
+
+        async with session.begin():
+            u1 = User(name="u1")
+            assert session.in_transaction()
+            session.add(u1)
+
+        assert not session.in_transaction()
+
+        async with maker() as session:
+            result = await session.execute(
+                select(User).where(User.name == "u1")
+            )
+
+            u1 = result.scalar_one()
+
+            eq_(u1.name, "u1")
+
+    @async_test
+    async def test_async_sessionmaker_block_two(self, async_engine):
+
+        User = self.classes.User
+        maker = async_sessionmaker(async_engine)
+
+        async with maker.begin() as session:
+            u1 = User(name="u1")
+            assert session.in_transaction()
+            session.add(u1)
+
+        assert not session.in_transaction()
+
+        async with maker() as session:
+            result = await session.execute(
+                select(User).where(User.name == "u1")
+            )
+
+            u1 = result.scalar_one()
+
+            eq_(u1.name, "u1")
+
     @async_test
     async def test_trans(self, async_session, async_engine):
         async with async_engine.connect() as outer_conn:
@@ -882,7 +929,7 @@ class OverrideSyncSession(AsyncFixture):
         is_true(isinstance(ass.sync_session, _MySession))
         is_(ass.sync_session_class, _MySession)
 
-    def test_init_sessionmaker(self, async_engine):
+    def test_init_orm_sessionmaker(self, async_engine):
         sm = sessionmaker(
             async_engine, class_=AsyncSession, sync_session_class=_MySession
         )
@@ -891,6 +938,13 @@ class OverrideSyncSession(AsyncFixture):
         is_true(isinstance(ass.sync_session, _MySession))
         is_(ass.sync_session_class, _MySession)
 
+    def test_init_asyncio_sessionmaker(self, async_engine):
+        sm = async_sessionmaker(async_engine, sync_session_class=_MySession)
+        ass = sm()
+
+        is_true(isinstance(ass.sync_session, _MySession))
+        is_(ass.sync_session_class, _MySession)
+
     def test_subclass(self, async_engine):
         ass = _MyAS(async_engine)
 
diff --git a/test/ext/mypy/plain_files/async_sessionmaker.py b/test/ext/mypy/plain_files/async_sessionmaker.py
new file mode 100644 (file)
index 0000000..01a26d0
--- /dev/null
+++ b/test/ext/mypy/plain_files/async_sessionmaker.py
@@ -0,0 +1,79 @@
+"""Illustrates use of the sqlalchemy.ext.asyncio.AsyncSession object
+for asynchronous ORM use.
+
+"""
+from __future__ import annotations
+
+import asyncio
+from typing import List
+from typing import TYPE_CHECKING
+
+from sqlalchemy import ForeignKey
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.future import select
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+
+if TYPE_CHECKING:
+    from sqlalchemy import ScalarResult
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class A(Base):
+    __tablename__ = "a"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    data: Mapped[str]
+    bs: Mapped[List[B]] = relationship()
+
+
+class B(Base):
+    __tablename__ = "b"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    a_id = mapped_column(ForeignKey("a.id"))
+    data: Mapped[str]
+
+
+async def async_main() -> None:
+    """Main program function."""
+
+    engine = create_async_engine(
+        "postgresql+asyncpg://scott:tiger@localhost/test",
+        echo=True,
+    )
+
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.drop_all)
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
+
+    async_session = async_sessionmaker(engine, expire_on_commit=False)
+
+    async with async_session.begin() as session:
+        session.add_all(
+            [
+                A(bs=[B(), B()], data="a1"),
+                A(bs=[B()], data="a2"),
+                A(bs=[B(), B()], data="a3"),
+            ]
+        )
+
+    async with async_session() as session:
+
+        result = await session.execute(select(A).order_by(A.id))
+
+        r: ScalarResult[A] = result.scalars()
+        a1 = r.one()
+
+        a1.data = "new data"
+
+        await session.commit()
+
+
+asyncio.run(async_main())
diff --git a/test/ext/mypy/plain_files/session.py b/test/ext/mypy/plain_files/session.py
index 24c685e84b3d0bae8fc400fba1534038cc4cbcca..199d3a804fc9cd7a3f283af7d0a16d7f2af2bec3 100644 (file)
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 from typing import List
-from typing import Sequence
 
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
@@ -45,6 +44,6 @@ with Session(e) as sess:
     sess.commit()
 
 with Session(e) as sess:
-    users: Sequence[User] = sess.scalars(
+    users: List[User] = sess.scalars(
         select(User), execution_options={"stream_results": False}
     ).all()