From: Mike Bayer Date: Sun, 10 Apr 2022 19:42:35 +0000 (-0400) Subject: pep-484: asyncio X-Git-Tag: rel_2_0_0b1~355 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a45e2284dad17fbbba3bea9d5e5304aab21c8c94;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep-484: asyncio in this patch the asyncio/events.py module, which existed only to raise errors when trying to attach event listeners, is removed, as we were already coding an asyncio-specific workaround in upstream Pool / Session to raise this error, just moved the error out to the target and did the same thing for Engine. We also add an async_sessionmaker class. The initial rationale here is because sessionmaker() is hardcoded to Session subclasses, and there's not a way to get the use case of sessionmaker(class_=AsyncSession) to type correctly without changing the sessionmaker() symbol itself to be a function and not a class, which gets too complicated for what this is. Additionally, _SessionClassMethods has only three methods on it, one of which is not usable with asyncio (close_all()), the others not generally used from the session class. Change-Id: I064a5fa5d91cc8d5bbe9597437536e37b4e801fe --- diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index 9badcb4184..82ba7cabb2 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -147,12 +147,12 @@ illustrates a complete example including mapper and session configuration:: from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.future import select from sqlalchemy.orm import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload - from sqlalchemy.orm import sessionmaker Base = declarative_base() @@ -190,9 +190,7 @@ illustrates a complete example including mapper and session configuration:: # expire_on_commit=False will prevent attributes from being expired # after commit. - async_session = sessionmaker( - engine, expire_on_commit=False, class_=AsyncSession - ) + async_session = async_sessionmaker(engine, expire_on_commit=False) async with async_session() as session: async with session.begin(): @@ -234,7 +232,7 @@ illustrates a complete example including mapper and session configuration:: asyncio.run(async_main()) In the example above, the :class:`_asyncio.AsyncSession` is instantiated using -the optional :class:`_orm.sessionmaker` helper, and associated with an +the optional :class:`_asyncio.async_sessionmaker` helper, and associated with an :class:`_asyncio.AsyncEngine` against particular database URL. It is then used in a Python asynchronous context manager (i.e. ``async with:`` statement) so that it is automatically closed at the end of the block; this is @@ -284,8 +282,8 @@ prevent this: async_session = AsyncSession(engine, expire_on_commit=False) # sessionmaker version - async_session = sessionmaker( - engine, expire_on_commit=False, class_=AsyncSession + async_session = async_sessionmaker( + engine, expire_on_commit=False ) async with async_session() as session: @@ -722,11 +720,11 @@ constructor:: from asyncio import current_task - from sqlalchemy.orm import sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import async_scoped_session from sqlalchemy.ext.asyncio import AsyncSession - async_session_factory = sessionmaker(some_async_engine, class_=AsyncSession) + async_session_factory = async_sessionmaker(some_async_engine, expire_on_commit=False) AsyncScopedSession = async_scoped_session(async_session_factory, scopefunc=current_task) some_async_session = AsyncScopedSession() @@ -833,6 +831,10 @@ ORM Session API Documentation .. autofunction:: async_session +.. autoclass:: async_sessionmaker + :members: + :inherited-members: + .. autoclass:: async_scoped_session :members: :inherited-members: diff --git a/examples/asyncio/async_orm.py b/examples/asyncio/async_orm.py index 174ebf30b5..4688911588 100644 --- a/examples/asyncio/async_orm.py +++ b/examples/asyncio/async_orm.py @@ -11,13 +11,12 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import String -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.future import select from sqlalchemy.orm import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload -from sqlalchemy.orm import sessionmaker Base = declarative_base() @@ -58,9 +57,7 @@ async def async_main(): # expire_on_commit=False will prevent attributes from being expired # after commit. - async_session = sessionmaker( - engine, expire_on_commit=False, class_=AsyncSession - ) + async_session = async_sessionmaker(engine, expire_on_commit=False) async with async_session() as session: async with session.begin(): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 8bcc7e2587..594a193446 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -42,7 +42,7 @@ from ..sql import util as sql_util _CompiledCacheType = MutableMapping[Any, "Compiled"] if typing.TYPE_CHECKING: - from . import Result + from . import CursorResult from . import ScalarResult from .interfaces import _AnyExecuteParams from .interfaces import _AnyMultiExecuteParams @@ -472,7 +472,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): else: return self._dbapi_connection - def get_isolation_level(self) -> str: + def get_isolation_level(self) -> _IsolationLevel: """Return the current isolation level assigned to this :class:`_engine.Connection`. @@ -1186,9 +1186,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): statement: Executable, parameters: Optional[_CoreAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> Result: + ) -> CursorResult: r"""Executes a SQL statement construct and returns a - :class:`_engine.Result`. + :class:`_engine.CursorResult`. :param statement: The statement to be executed. This is always an object that is in both the :class:`_expression.ClauseElement` and @@ -1235,7 +1235,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): func: FunctionElement[Any], distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> Result: + ) -> CursorResult: """Execute a sql.FunctionElement object.""" return self._execute_clauseelement( @@ -1306,7 +1306,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ddl: DDLElement, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> Result: + ) -> CursorResult: """Execute a schema.DDL object.""" execution_options = ddl._execution_options.merge_with( @@ -1403,7 +1403,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): elem: Executable, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - ) -> Result: + ) -> CursorResult: """Execute a sql.ClauseElement object.""" execution_options = elem._execution_options.merge_with( @@ -1476,7 +1476,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): compiled: Compiled, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS, - ) -> Result: + ) -> CursorResult: """Execute a sql.Compiled object. TODO: why do we have this? likely deprecate or remove @@ -1526,7 +1526,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): statement: str, parameters: Optional[_DBAPIAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, - ) -> Result: + ) -> CursorResult: r"""Executes a SQL statement construct and returns a :class:`_engine.CursorResult`. @@ -1603,7 +1603,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): execution_options: _ExecuteOptions, *args: Any, **kw: Any, - ) -> Result: + ) -> CursorResult: """Create an :class:`.ExecutionContext` and execute, returning a :class:`_engine.CursorResult`.""" diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index 699faf4897..ef10946a86 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -16,6 +16,7 @@ from typing import Tuple from typing import Type from typing import Union +from .base import Connection from .base import Engine from .interfaces import ConnectionEventsTarget from .interfaces import DBAPIConnection @@ -123,9 +124,23 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]): _dispatch_target = ConnectionEventsTarget @classmethod - def _listen( # type: ignore[override] + def _accept_with( + cls, + target: Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]], + ) -> Optional[Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]]]: + default_dispatch = super()._accept_with(target) + if default_dispatch is None and hasattr( + target, "_no_async_engine_events" + ): + target._no_async_engine_events() # type: ignore + + return default_dispatch + + @classmethod + def _listen( cls, event_key: event._EventKey[ConnectionEventsTarget], + *, retval: bool = False, **kw: Any, ) -> None: @@ -769,7 +784,9 @@ class DialectEvents(event.Events[Dialect]): def _listen( # type: ignore cls, event_key: event._EventKey[Dialect], + *, retval: bool = False, + **kw: Any, ) -> None: target = event_key.dispatch_target @@ -789,10 +806,8 @@ class DialectEvents(event.Events[Dialect]): return target.dialect elif isinstance(target, Dialect): return target - elif hasattr(target, "dispatch") and hasattr( - target.dispatch._events, "_no_async_engine_events" - ): - target.dispatch._events._no_async_engine_events() + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() else: return None diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index aa75da6141..54fe21d747 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -46,7 +46,7 @@ from ..util.typing import TypedDict if TYPE_CHECKING: from .base import Connection from .base import Engine - from .result import Result + from .cursor import CursorResult from .url import URL from ..event import _ListenerFnType from ..event import dispatcher @@ -2422,7 +2422,7 @@ class ExecutionContext: def _get_cache_stats(self) -> str: raise NotImplementedError() - def _setup_result_proxy(self) -> Result: + def _setup_result_proxy(self) -> CursorResult: raise NotImplementedError() def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int: diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 880bd8d4c2..11998e7188 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -536,7 +536,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): return interim_rows @HasMemoized_ro_memoized_attribute - def _onerow_getter(self) -> Callable[..., Union[_NoRow, _R]]: + def _onerow_getter( + self, + ) -> Callable[..., Union[Literal[_NoRow._NO_ROW], _R]]: make_row = self._row_getter post_creational_filter = self._post_creational_filter diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 8ed4c64bac..c16f6870be 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -256,6 +256,7 @@ class _HasEventsDispatch(Generic[_ET]): def _listen( cls, event_key: _EventKey[_ET], + *, propagate: bool = False, insert: bool = False, named: bool = False, @@ -361,6 +362,7 @@ class Events(_HasEventsDispatch[_ET]): def _listen( cls, event_key: _EventKey[_ET], + *, propagate: bool = False, insert: bool = False, named: bool = False, diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index 15b2cb015b..dfe89a154e 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -5,18 +5,17 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -from .engine import async_engine_from_config -from .engine import AsyncConnection -from .engine import AsyncEngine -from .engine import AsyncTransaction -from .engine import create_async_engine -from .events import AsyncConnectionEvents -from .events import AsyncSessionEvents -from .result import AsyncMappingResult -from .result import AsyncResult -from .result import AsyncScalarResult -from .scoping import async_scoped_session -from .session import async_object_session -from .session import async_session -from .session import AsyncSession -from .session import AsyncSessionTransaction +from .engine import async_engine_from_config as async_engine_from_config +from .engine import AsyncConnection as AsyncConnection +from .engine import AsyncEngine as AsyncEngine +from .engine import AsyncTransaction as AsyncTransaction +from .engine import create_async_engine as create_async_engine +from .result import AsyncMappingResult as AsyncMappingResult +from .result import AsyncResult as AsyncResult +from .result import AsyncScalarResult as AsyncScalarResult +from .scoping import async_scoped_session as async_scoped_session +from .session import async_object_session as async_object_session +from .session import async_session as async_session +from .session import async_sessionmaker as async_sessionmaker +from .session import AsyncSession as AsyncSession +from .session import AsyncSessionTransaction as AsyncSessionTransaction diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 3f77f55007..7fdd2d7e06 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -1,36 +1,103 @@ +# ext/asyncio/base.py +# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + import abc import functools +from typing import Any +from typing import ClassVar +from typing import Dict +from typing import Generic +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Type +from typing import TypeVar import weakref from . import exc as async_exc +from ... import util +from ...util.typing import Literal + +_T = TypeVar("_T", bound=Any) + + +_PT = TypeVar("_PT", bound=Any) -class ReversibleProxy: - # weakref.ref(async proxy object) -> weakref.ref(sync proxied object) - _proxy_objects = {} +SelfReversibleProxy = TypeVar( + "SelfReversibleProxy", bound="ReversibleProxy[Any]" +) + + +class ReversibleProxy(Generic[_PT]): + _proxy_objects: ClassVar[ + Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]] + ] = {} __slots__ = ("__weakref__",) - def _assign_proxied(self, target): + @overload + def _assign_proxied(self, target: _PT) -> _PT: + ... + + @overload + def _assign_proxied(self, target: None) -> None: + ... + + def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: if target is not None: - target_ref = weakref.ref(target, ReversibleProxy._target_gced) + target_ref: weakref.ref[_PT] = weakref.ref( + target, ReversibleProxy._target_gced + ) proxy_ref = weakref.ref( self, - functools.partial(ReversibleProxy._target_gced, target_ref), + functools.partial( # type: ignore + ReversibleProxy._target_gced, target_ref + ), ) ReversibleProxy._proxy_objects[target_ref] = proxy_ref return target @classmethod - def _target_gced(cls, ref, proxy_ref=None): + def _target_gced( + cls: Type[SelfReversibleProxy], + ref: weakref.ref[_PT], + proxy_ref: Optional[weakref.ref[SelfReversibleProxy]] = None, + ) -> None: cls._proxy_objects.pop(ref, None) @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT + ) -> SelfReversibleProxy: raise NotImplementedError() + @overload @classmethod - def _retrieve_proxy_for_target(cls, target, regenerate=True): + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], + target: _PT, + regenerate: Literal[True] = ..., + ) -> SelfReversibleProxy: + ... + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True + ) -> Optional[SelfReversibleProxy]: + ... + + @classmethod + def _retrieve_proxy_for_target( + cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True + ) -> Optional[SelfReversibleProxy]: try: proxy_ref = cls._proxy_objects[weakref.ref(target)] except KeyError: @@ -38,7 +105,7 @@ class ReversibleProxy: else: proxy = proxy_ref() if proxy is not None: - return proxy + return proxy # type: ignore if regenerate: return cls._regenerate_proxy_for_target(target) @@ -46,43 +113,54 @@ class ReversibleProxy: return None +SelfStartableContext = TypeVar( + "SelfStartableContext", bound="StartableContext" +) + + class StartableContext(abc.ABC): __slots__ = () @abc.abstractmethod - async def start(self, is_ctxmanager=False): - pass + async def start( + self: SelfStartableContext, is_ctxmanager: bool = False + ) -> Any: + raise NotImplementedError() - def __await__(self): + def __await__(self) -> Any: return self.start().__await__() - async def __aenter__(self): + async def __aenter__(self: SelfStartableContext) -> Any: return await self.start(is_ctxmanager=True) @abc.abstractmethod - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: pass - def _raise_for_not_started(self): + def _raise_for_not_started(self) -> NoReturn: raise async_exc.AsyncContextNotStarted( "%s context has not been started and object has not been awaited." % (self.__class__.__name__) ) -class ProxyComparable(ReversibleProxy): +class ProxyComparable(ReversibleProxy[_PT]): __slots__ = () - def __hash__(self): + @util.ro_non_memoized_property + def _proxied(self) -> _PT: + raise NotImplementedError() + + def __hash__(self) -> int: return id(self) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, self.__class__) and self._proxied == other._proxied ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return ( not isinstance(other, self.__class__) or self._proxied != other._proxied diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 3b54405c15..bb51a4d225 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -7,23 +7,56 @@ from __future__ import annotations from typing import Any +from typing import Dict +from typing import Generator +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Type +from typing import TYPE_CHECKING +from typing import Union from . import exc as async_exc from .base import ProxyComparable from .base import StartableContext from .result import _ensure_sync_result from .result import AsyncResult +from .result import AsyncScalarResult from ... import exc from ... import inspection from ... import util +from ...engine import Connection from ...engine import create_engine as _create_engine +from ...engine import Engine from ...engine.base import NestedTransaction -from ...future import Connection -from ...future import Engine +from ...engine.base import Transaction from ...util.concurrency import greenlet_spawn - - -def create_async_engine(*arg, **kw): +from ...util.typing import Protocol + +if TYPE_CHECKING: + from ...engine import Connection + from ...engine import Engine + from ...engine.cursor import CursorResult + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _DBAPIAnyExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import _ExecuteOptionsParameter + from ...engine.interfaces import _IsolationLevel + from ...engine.interfaces import Dialect + from ...engine.result import ScalarResult + from ...engine.url import URL + from ...pool import Pool + from ...pool import PoolProxiedConnection + from ...sql.base import Executable + + +class _SyncConnectionCallable(Protocol): + def __call__(self, connection: Connection, *arg: Any, **kw: Any) -> Any: + ... + + +def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: """Create a new async engine instance. Arguments passed to :func:`_asyncio.create_async_engine` are mostly @@ -43,11 +76,13 @@ def create_async_engine(*arg, **kw): ) kw["future"] = True kw["_is_async"] = True - sync_engine = _create_engine(*arg, **kw) + sync_engine = _create_engine(url, **kw) return AsyncEngine(sync_engine) -def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): +def async_engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> AsyncEngine: """Create a new AsyncEngine instance using a configuration dictionary. This function is analogous to the :func:`_sa.engine_from_config` function @@ -73,6 +108,14 @@ def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): class AsyncConnectable: __slots__ = "_slots_dispatch", "__weakref__" + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes." + ) + @util.create_proxy_methods( Connection, @@ -87,7 +130,9 @@ class AsyncConnectable: "default_isolation_level", ], ) -class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): +class AsyncConnection( + ProxyComparable[Connection], StartableContext, AsyncConnectable +): """An asyncio proxy for a :class:`_engine.Connection`. :class:`_asyncio.AsyncConnection` is acquired using the @@ -115,12 +160,16 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): "sync_connection", ) - def __init__(self, async_engine, sync_connection=None): + def __init__( + self, + async_engine: AsyncEngine, + sync_connection: Optional[Connection] = None, + ): self.engine = async_engine self.sync_engine = async_engine.sync_engine self.sync_connection = self._assign_proxied(sync_connection) - sync_connection: Connection + sync_connection: Optional[Connection] """Reference to the sync-style :class:`_engine.Connection` this :class:`_asyncio.AsyncConnection` proxies requests towards. @@ -146,12 +195,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target( + cls, target: Connection + ) -> AsyncConnection: return AsyncConnection( AsyncEngine._retrieve_proxy_for_target(target.engine), target ) - async def start(self, is_ctxmanager=False): + async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: """Start this :class:`_asyncio.AsyncConnection` object's context outside of using a Python ``with:`` block. @@ -164,7 +215,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): return self @property - def connection(self): + def connection(self) -> NoReturn: """Not implemented for async; call :meth:`_asyncio.AsyncConnection.get_raw_connection`. """ @@ -174,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): "Use the get_raw_connection() method." ) - async def get_raw_connection(self): + async def get_raw_connection(self) -> PoolProxiedConnection: """Return the pooled DBAPI-level connection in use by this :class:`_asyncio.AsyncConnection`. @@ -187,16 +238,11 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): adapts the driver connection to the DBAPI protocol. """ - conn = self._sync_connection() - - return await greenlet_spawn(getattr, conn, "connection") - @property - def _proxied(self): - return self.sync_connection + return await greenlet_spawn(getattr, self._proxied, "connection") @property - def info(self): + def info(self) -> Dict[str, Any]: """Return the :attr:`_engine.Connection.info` dictionary of the underlying :class:`_engine.Connection`. @@ -211,24 +257,28 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): .. versionadded:: 1.4.0b2 """ - return self.sync_connection.info + return self._proxied.info - def _sync_connection(self): + @util.ro_non_memoized_property + def _proxied(self) -> Connection: if not self.sync_connection: self._raise_for_not_started() return self.sync_connection - def begin(self): + def begin(self) -> AsyncTransaction: """Begin a transaction prior to autobegin occurring.""" - self._sync_connection() + assert self._proxied return AsyncTransaction(self) - def begin_nested(self): + def begin_nested(self) -> AsyncTransaction: """Begin a nested transaction and return a transaction handle.""" - self._sync_connection() + assert self._proxied return AsyncTransaction(self, nested=True) - async def invalidate(self, exception=None): + async def invalidate( + self, exception: Optional[BaseException] = None + ) -> None: + """Invalidate the underlying DBAPI connection associated with this :class:`_engine.Connection`. @@ -237,39 +287,27 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ - conn = self._sync_connection() - return await greenlet_spawn(conn.invalidate, exception=exception) - - async def get_isolation_level(self): - conn = self._sync_connection() - return await greenlet_spawn(conn.get_isolation_level) - - async def set_isolation_level(self): - conn = self._sync_connection() - return await greenlet_spawn(conn.get_isolation_level) - - def in_transaction(self): - """Return True if a transaction is in progress. - - .. versionadded:: 1.4.0b2 + return await greenlet_spawn( + self._proxied.invalidate, exception=exception + ) - """ + async def get_isolation_level(self) -> _IsolationLevel: + return await greenlet_spawn(self._proxied.get_isolation_level) - conn = self._sync_connection() + def in_transaction(self) -> bool: + """Return True if a transaction is in progress.""" - return conn.in_transaction() + return self._proxied.in_transaction() - def in_nested_transaction(self): + def in_nested_transaction(self) -> bool: """Return True if a transaction is in progress. .. versionadded:: 1.4.0b2 """ - conn = self._sync_connection() - - return conn.in_nested_transaction() + return self._proxied.in_nested_transaction() - def get_transaction(self): + def get_transaction(self) -> Optional[AsyncTransaction]: """Return an :class:`.AsyncTransaction` representing the current transaction, if any. @@ -281,15 +319,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): .. versionadded:: 1.4.0b2 """ - conn = self._sync_connection() - trans = conn.get_transaction() + trans = self._proxied.get_transaction() if trans is not None: return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[AsyncTransaction]: """Return an :class:`.AsyncTransaction` representing the current nested (savepoint) transaction, if any. @@ -301,15 +338,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): .. versionadded:: 1.4.0b2 """ - conn = self._sync_connection() - trans = conn.get_nested_transaction() + trans = self._proxied.get_nested_transaction() if trans is not None: return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None - async def execution_options(self, **opt): + async def execution_options(self, **opt: Any) -> AsyncConnection: r"""Set non-SQL options for the connection which take effect during execution. @@ -321,12 +357,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ - conn = self._sync_connection() + conn = self._proxied c2 = await greenlet_spawn(conn.execution_options, **opt) assert c2 is conn return self - async def commit(self): + async def commit(self) -> None: """Commit the transaction that is currently in progress. This method commits the current transaction if one has been started. @@ -338,10 +374,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): :meth:`_future.Connection.begin` method is called. """ - conn = self._sync_connection() - await greenlet_spawn(conn.commit) + await greenlet_spawn(self._proxied.commit) - async def rollback(self): + async def rollback(self) -> None: """Roll back the transaction that is currently in progress. This method rolls back the current transaction if one has been started. @@ -355,34 +390,30 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ - conn = self._sync_connection() - await greenlet_spawn(conn.rollback) + await greenlet_spawn(self._proxied.rollback) - async def close(self): + async def close(self) -> None: """Close this :class:`_asyncio.AsyncConnection`. This has the effect of also rolling back the transaction if one is in place. """ - conn = self._sync_connection() - await greenlet_spawn(conn.close) + await greenlet_spawn(self._proxied.close) async def exec_driver_sql( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult: r"""Executes a driver-level SQL string and return buffered :class:`_engine.Result`. """ - conn = self._sync_connection() - result = await greenlet_spawn( - conn.exec_driver_sql, + self._proxied.exec_driver_sql, statement, parameters, execution_options, @@ -393,17 +424,15 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def stream( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncResult: """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object.""" - conn = self._sync_connection() - result = await greenlet_spawn( - conn.execute, + self._proxied.execute, statement, parameters, util.EMPTY_DICT.merge_with( @@ -418,10 +447,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def execute( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> CursorResult: r"""Executes a SQL statement construct and return a buffered :class:`_engine.Result`. @@ -453,10 +482,8 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): :return: a :class:`_engine.Result` object. """ - conn = self._sync_connection() - result = await greenlet_spawn( - conn.execute, + self._proxied.execute, statement, parameters, execution_options, @@ -466,10 +493,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def scalar( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Any: r"""Executes a SQL statement construct and returns a scalar object. This method is shorthand for invoking the @@ -485,10 +512,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def scalars( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: r"""Executes a SQL statement construct and returns a scalar objects. This method is shorthand for invoking the @@ -505,10 +532,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): async def stream_scalars( self, - statement, - parameters=None, - execution_options=util.EMPTY_DICT, - ): + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> AsyncScalarResult[Any]: r"""Executes a SQL statement and returns a streaming scalar result object. @@ -524,7 +551,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): result = await self.stream(statement, parameters, execution_options) return result.scalars() - async def run_sync(self, fn, *arg, **kw): + async def run_sync( + self, fn: _SyncConnectionCallable, *arg: Any, **kw: Any + ) -> Any: """Invoke the given sync callable passing self as the first argument. This method maintains the asyncio event loop all the way through @@ -548,14 +577,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): :ref:`session_run_sync` """ - conn = self._sync_connection() - - return await greenlet_spawn(fn, conn, *arg, **kw) + return await greenlet_spawn(fn, self._proxied, *arg, **kw) - def __await__(self): + def __await__(self) -> Generator[Any, None, AsyncConnection]: return self.start().__await__() - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() # START PROXY METHODS AsyncConnection @@ -661,7 +688,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): ], attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"], ) -class AsyncEngine(ProxyComparable, AsyncConnectable): +class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): """An asyncio proxy for a :class:`_engine.Engine`. :class:`_asyncio.AsyncEngine` is acquired using the @@ -679,51 +706,60 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): # current transaction, info, etc. It should be possible to # create a new AsyncEngine that matches this one given only the # "sync" elements. - __slots__ = ("sync_engine", "_proxied") + __slots__ = "sync_engine" - _connection_cls = AsyncConnection + _connection_cls: Type[AsyncConnection] = AsyncConnection - _option_cls: type + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncEngine` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + """ class _trans_ctx(StartableContext): - def __init__(self, conn): + __slots__ = ("conn", "transaction") + + conn: AsyncConnection + transaction: AsyncTransaction + + def __init__(self, conn: AsyncConnection): self.conn = conn - async def start(self, is_ctxmanager=False): + async def start(self, is_ctxmanager: bool = False) -> AsyncConnection: await self.conn.start(is_ctxmanager=is_ctxmanager) self.transaction = self.conn.begin() await self.transaction.__aenter__() return self.conn - async def __aexit__(self, type_, value, traceback): + async def __aexit__( + self, type_: Any, value: Any, traceback: Any + ) -> None: await self.transaction.__aexit__(type_, value, traceback) await self.conn.close() - def __init__(self, sync_engine): + def __init__(self, sync_engine: Engine): if not sync_engine.dialect.is_async: raise exc.InvalidRequestError( "The asyncio extension requires an async driver to be used. " f"The loaded {sync_engine.dialect.driver!r} is not async." ) - self.sync_engine = self._proxied = self._assign_proxied(sync_engine) - - sync_engine: Engine - """Reference to the sync-style :class:`_engine.Engine` this - :class:`_asyncio.AsyncEngine` proxies requests towards. + self.sync_engine = self._assign_proxied(sync_engine) - This instance can be used as an event target. - - .. seealso:: - - :ref:`asyncio_events` - """ + @util.ro_non_memoized_property + def _proxied(self) -> Engine: + return self.sync_engine @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: return AsyncEngine(target) - def begin(self): + def begin(self) -> AsyncEngine._trans_ctx: """Return a context manager which when entered will deliver an :class:`_asyncio.AsyncConnection` with an :class:`_asyncio.AsyncTransaction` established. @@ -741,7 +777,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): conn = self.connect() return self._trans_ctx(conn) - def connect(self): + def connect(self) -> AsyncConnection: """Return an :class:`_asyncio.AsyncConnection` object. The :class:`_asyncio.AsyncConnection` will procure a database @@ -759,7 +795,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): return self._connection_cls(self) - async def raw_connection(self): + async def raw_connection(self) -> PoolProxiedConnection: """Return a "raw" DBAPI connection from the connection pool. .. seealso:: @@ -769,7 +805,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): """ return await greenlet_spawn(self.sync_engine.raw_connection) - def execution_options(self, **opt): + def execution_options(self, **opt: Any) -> AsyncEngine: """Return a new :class:`_asyncio.AsyncEngine` that will provide :class:`_asyncio.AsyncConnection` objects with the given execution options. @@ -781,21 +817,31 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): return AsyncEngine(self.sync_engine.execution_options(**opt)) - async def dispose(self): + async def dispose(self, close: bool = True) -> None: + """Dispose of the connection pool used by this :class:`_asyncio.AsyncEngine`. - This will close all connection pool connections that are - **currently checked in**. See the documentation for the underlying - :meth:`_future.Engine.dispose` method for further notes. + :param close: if left at its default of ``True``, has the + effect of fully closing all **currently checked in** + database connections. Connections that are still checked out + will **not** be closed, however they will no longer be associated + with this :class:`_engine.Engine`, + so when they are closed individually, eventually the + :class:`_pool.Pool` which they are associated with will + be garbage collected and they will be closed out fully, if + not already closed on checkin. + + If set to ``False``, the previous connection pool is de-referenced, + and otherwise not touched in any way. .. seealso:: - :meth:`_future.Engine.dispose` + :meth:`_engine.Engine.dispose` """ - return await greenlet_spawn(self.sync_engine.dispose) + return await greenlet_spawn(self.sync_engine.dispose, close=close) # START PROXY METHODS AsyncEngine @@ -973,18 +1019,24 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): # END PROXY METHODS AsyncEngine -class AsyncTransaction(ProxyComparable, StartableContext): +class AsyncTransaction(ProxyComparable[Transaction], StartableContext): """An asyncio proxy for a :class:`_engine.Transaction`.""" __slots__ = ("connection", "sync_transaction", "nested") - def __init__(self, connection, nested=False): - self.connection = connection # AsyncConnection - self.sync_transaction = None # sqlalchemy.engine.Transaction + sync_transaction: Optional[Transaction] + connection: AsyncConnection + nested: bool + + def __init__(self, connection: AsyncConnection, nested: bool = False): + self.connection = connection + self.sync_transaction = None self.nested = nested @classmethod - def _regenerate_proxy_for_target(cls, target): + def _regenerate_proxy_for_target( + cls, target: Transaction + ) -> AsyncTransaction: sync_connection = target.connection sync_transaction = target nested = isinstance(target, NestedTransaction) @@ -1000,25 +1052,22 @@ class AsyncTransaction(ProxyComparable, StartableContext): obj.nested = nested return obj - def _sync_transaction(self): + @util.ro_non_memoized_property + def _proxied(self) -> Transaction: if not self.sync_transaction: self._raise_for_not_started() return self.sync_transaction @property - def _proxied(self): - return self.sync_transaction + def is_valid(self) -> bool: + return self._proxied.is_valid @property - def is_valid(self): - return self._sync_transaction().is_valid + def is_active(self) -> bool: + return self._proxied.is_active - @property - def is_active(self): - return self._sync_transaction().is_active - - async def close(self): - """Close this :class:`.Transaction`. + async def close(self) -> None: + """Close this :class:`.AsyncTransaction`. If this transaction is the base transaction in a begin/commit nesting, the transaction will rollback(). Otherwise, the @@ -1028,18 +1077,18 @@ class AsyncTransaction(ProxyComparable, StartableContext): an enclosing transaction. """ - await greenlet_spawn(self._sync_transaction().close) + await greenlet_spawn(self._proxied.close) - async def rollback(self): - """Roll back this :class:`.Transaction`.""" - await greenlet_spawn(self._sync_transaction().rollback) + async def rollback(self) -> None: + """Roll back this :class:`.AsyncTransaction`.""" + await greenlet_spawn(self._proxied.rollback) - async def commit(self): - """Commit this :class:`.Transaction`.""" + async def commit(self) -> None: + """Commit this :class:`.AsyncTransaction`.""" - await greenlet_spawn(self._sync_transaction().commit) + await greenlet_spawn(self._proxied.commit) - async def start(self, is_ctxmanager=False): + async def start(self, is_ctxmanager: bool = False) -> AsyncTransaction: """Start this :class:`_asyncio.AsyncTransaction` object's context outside of using a Python ``with:`` block. @@ -1047,24 +1096,36 @@ class AsyncTransaction(ProxyComparable, StartableContext): self.sync_transaction = self._assign_proxied( await greenlet_spawn( - self.connection._sync_connection().begin_nested + self.connection._proxied.begin_nested if self.nested - else self.connection._sync_connection().begin + else self.connection._proxied.begin ) ) if is_ctxmanager: self.sync_transaction.__enter__() return self - async def __aexit__(self, type_, value, traceback): - await greenlet_spawn( - self._sync_transaction().__exit__, type_, value, traceback - ) + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + await greenlet_spawn(self._proxied.__exit__, type_, value, traceback) + + +@overload +def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: + ... + + +@overload +def _get_sync_engine_or_connection( + async_engine: AsyncConnection, +) -> Connection: + ... -def _get_sync_engine_or_connection(async_engine): +def _get_sync_engine_or_connection( + async_engine: Union[AsyncEngine, AsyncConnection] +) -> Union[Engine, Connection]: if isinstance(async_engine, AsyncConnection): - return async_engine.sync_connection + return async_engine._proxied try: return async_engine.sync_engine @@ -1075,7 +1136,7 @@ def _get_sync_engine_or_connection(async_engine): @inspection._inspects(AsyncConnection) -def _no_insp_for_async_conn_yet(subject): +def _no_insp_for_async_conn_yet(subject: AsyncConnection) -> NoReturn: raise exc.NoInspectionAvailable( "Inspection on an AsyncConnection is currently not supported. " "Please use ``run_sync`` to pass a callable where it's possible " @@ -1085,7 +1146,7 @@ def _no_insp_for_async_conn_yet(subject): @inspection._inspects(AsyncEngine) -def _no_insp_for_async_engine_xyet(subject): +def _no_insp_for_async_engine_xyet(subject: AsyncEngine) -> NoReturn: raise exc.NoInspectionAvailable( "Inspection on an AsyncEngine is currently not supported. " "Please obtain a connection then use ``conn.run_sync`` to pass a " diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py deleted file mode 100644 index c5d5e0126e..0000000000 --- a/lib/sqlalchemy/ext/asyncio/events.py +++ /dev/null @@ -1,44 +0,0 @@ -# ext/asyncio/events.py -# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php - -from .engine import AsyncConnectable -from .session import AsyncSession -from ...engine import events as engine_event -from ...orm import events as orm_event - - -class AsyncConnectionEvents(engine_event.ConnectionEvents): - _target_class_doc = "SomeEngine" - _dispatch_target = AsyncConnectable - - @classmethod - def _no_async_engine_events(cls): - raise NotImplementedError( - "asynchronous events are not implemented at this time. Apply " - "synchronous listeners to the AsyncEngine.sync_engine or " - "AsyncConnection.sync_connection attributes." - ) - - @classmethod - def _listen(cls, event_key, retval=False): - cls._no_async_engine_events() - - -class AsyncSessionEvents(orm_event.SessionEvents): - _target_class_doc = "SomeSession" - _dispatch_target = AsyncSession - - @classmethod - def _no_async_engine_events(cls): - raise NotImplementedError( - "asynchronous events are not implemented at this time. Apply " - "synchronous listeners to the AsyncSession.sync_session." - ) - - @classmethod - def _listen(cls, event_key, retval=False): - cls._no_async_engine_events() diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index 39718735cc..a9db822a6a 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -4,25 +4,49 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations import operator +from typing import Any +from typing import AsyncIterator +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import TypeVar from . import exc as async_exc from ...engine.result import _NO_ROW +from ...engine.result import _R from ...engine.result import FilterResult from ...engine.result import FrozenResult from ...engine.result import MergedResult +from ...engine.result import ResultMetaData +from ...engine.row import Row +from ...engine.row import RowMapping from ...util.concurrency import greenlet_spawn +if TYPE_CHECKING: + from ...engine import CursorResult + from ...engine import Result + from ...engine.result import _KeyIndexType + from ...engine.result import _UniqueFilterType + from ...engine.result import RMKeyView -class AsyncCommon(FilterResult): - async def close(self): + +class AsyncCommon(FilterResult[_R]): + _real_result: Result + _metadata: ResultMetaData + + async def close(self) -> None: """Close this result.""" await greenlet_spawn(self._real_result.close) -class AsyncResult(AsyncCommon): +SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult") + + +class AsyncResult(AsyncCommon[Row]): """An asyncio wrapper around a :class:`_result.Result` object. The :class:`_asyncio.AsyncResult` only applies to statement executions that @@ -43,7 +67,7 @@ class AsyncResult(AsyncCommon): """ - def __init__(self, real_result): + def __init__(self, real_result: Result): self._real_result = real_result self._metadata = real_result._metadata @@ -56,14 +80,16 @@ class AsyncResult(AsyncCommon): "_row_getter", real_result.__dict__["_row_getter"] ) - def keys(self): + def keys(self) -> RMKeyView: """Return the :meth:`_engine.Result.keys` collection from the underlying :class:`_engine.Result`. """ return self._metadata.keys - def unique(self, strategy=None): + def unique( + self: SelfAsyncResult, strategy: Optional[_UniqueFilterType] = None + ) -> SelfAsyncResult: """Apply unique filtering to the objects returned by this :class:`_asyncio.AsyncResult`. @@ -75,7 +101,9 @@ class AsyncResult(AsyncCommon): self._unique_filter_state = (set(), strategy) return self - def columns(self, *col_expressions): + def columns( + self: SelfAsyncResult, *col_expressions: _KeyIndexType + ) -> SelfAsyncResult: r"""Establish the columns that should be returned in each row. Refer to :meth:`_engine.Result.columns` in the synchronous @@ -85,7 +113,9 @@ class AsyncResult(AsyncCommon): """ return self._column_slices(col_expressions) - async def partitions(self, size=None): + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[List[Row]]: """Iterate through sub-lists of rows of the size given. An async iterator is returned:: @@ -111,7 +141,7 @@ class AsyncResult(AsyncCommon): else: break - async def fetchone(self): + async def fetchone(self) -> Optional[Row]: """Fetch one row. When all rows are exhausted, returns None. @@ -131,9 +161,9 @@ class AsyncResult(AsyncCommon): if row is _NO_ROW: return None else: - return row # type: ignore[return-value] + return row - async def fetchmany(self, size=None): + async def fetchmany(self, size: Optional[int] = None) -> List[Row]: """Fetch many rows. When all rows are exhausted, returns an empty list. @@ -152,11 +182,9 @@ class AsyncResult(AsyncCommon): """ - return await greenlet_spawn( - self._manyrow_getter, self, size # type: ignore - ) + return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self): + async def all(self) -> List[Row]: """Return all rows in a list. Closes the result set after invocation. Subsequent invocations @@ -166,19 +194,19 @@ class AsyncResult(AsyncCommon): """ - return await greenlet_spawn(self._allrows) # type: ignore + return await greenlet_spawn(self._allrows) - def __aiter__(self): + def __aiter__(self) -> AsyncResult: return self - async def __anext__(self): + async def __anext__(self) -> Row: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self): + async def first(self) -> Optional[Row]: """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. @@ -201,7 +229,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self): + async def one_or_none(self) -> Optional[Row]: """Return at most one result or raise an exception. Returns ``None`` if the result has no rows. @@ -223,7 +251,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, False, False) - async def scalar_one(self): + async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and @@ -238,7 +266,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, True, True) - async def scalar_one_or_none(self): + async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and @@ -253,7 +281,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, False, True) - async def one(self): + async def one(self) -> Row: """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no @@ -284,7 +312,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, True, False) - async def scalar(self): + async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. Returns None if there are no rows to fetch. @@ -300,7 +328,7 @@ class AsyncResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, False, False, True) - async def freeze(self): + async def freeze(self) -> FrozenResult: """Return a callable object that will produce copies of this :class:`_asyncio.AsyncResult` when invoked. @@ -323,7 +351,7 @@ class AsyncResult(AsyncCommon): return await greenlet_spawn(FrozenResult, self) - def merge(self, *others): + def merge(self, *others: AsyncResult) -> MergedResult: """Merge this :class:`_asyncio.AsyncResult` with other compatible result objects. @@ -337,9 +365,12 @@ class AsyncResult(AsyncCommon): undefined. """ - return MergedResult(self._metadata, (self,) + others) + return MergedResult( + self._metadata, + (self._real_result,) + tuple(o._real_result for o in others), + ) - def scalars(self, index=0): + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: """Return an :class:`_asyncio.AsyncScalarResult` filtering object which will return single elements rather than :class:`_row.Row` objects. @@ -355,7 +386,7 @@ class AsyncResult(AsyncCommon): """ return AsyncScalarResult(self._real_result, index) - def mappings(self): + def mappings(self) -> AsyncMappingResult: """Apply a mappings filter to returned rows, returning an instance of :class:`_asyncio.AsyncMappingResult`. @@ -373,7 +404,12 @@ class AsyncResult(AsyncCommon): return AsyncMappingResult(self._real_result) -class AsyncScalarResult(AsyncCommon): +SelfAsyncScalarResult = TypeVar( + "SelfAsyncScalarResult", bound="AsyncScalarResult[Any]" +) + + +class AsyncScalarResult(AsyncCommon[_R]): """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values rather than :class:`_row.Row` values. @@ -389,7 +425,7 @@ class AsyncScalarResult(AsyncCommon): _generate_rows = False - def __init__(self, real_result, index): + def __init__(self, real_result: Result, index: _KeyIndexType): self._real_result = real_result if real_result._source_supports_scalars: @@ -401,7 +437,10 @@ class AsyncScalarResult(AsyncCommon): self._unique_filter_state = real_result._unique_filter_state - def unique(self, strategy=None): + def unique( + self: SelfAsyncScalarResult, + strategy: Optional[_UniqueFilterType] = None, + ) -> SelfAsyncScalarResult: """Apply unique filtering to the objects returned by this :class:`_asyncio.AsyncScalarResult`. @@ -411,7 +450,9 @@ class AsyncScalarResult(AsyncCommon): self._unique_filter_state = (set(), strategy) return self - async def partitions(self, size=None): + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[List[_R]]: """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that @@ -429,12 +470,12 @@ class AsyncScalarResult(AsyncCommon): else: break - async def fetchall(self): + async def fetchall(self) -> List[_R]: """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" return await greenlet_spawn(self._allrows) - async def fetchmany(self, size=None): + async def fetchmany(self, size: Optional[int] = None) -> List[_R]: """Fetch many objects. Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that @@ -444,7 +485,7 @@ class AsyncScalarResult(AsyncCommon): """ return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self): + async def all(self) -> List[_R]: """Return all scalar values in a list. Equivalent to :meth:`_asyncio.AsyncResult.all` except that @@ -454,17 +495,17 @@ class AsyncScalarResult(AsyncCommon): """ return await greenlet_spawn(self._allrows) - def __aiter__(self): + def __aiter__(self) -> AsyncScalarResult[_R]: return self - async def __anext__(self): + async def __anext__(self) -> _R: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self): + async def first(self) -> Optional[_R]: """Fetch the first object or None if no object is present. Equivalent to :meth:`_asyncio.AsyncResult.first` except that @@ -474,7 +515,7 @@ class AsyncScalarResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self): + async def one_or_none(self) -> Optional[_R]: """Return at most one object or raise an exception. Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that @@ -484,7 +525,7 @@ class AsyncScalarResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, False, False) - async def one(self): + async def one(self) -> _R: """Return exactly one object or raise an exception. Equivalent to :meth:`_asyncio.AsyncResult.one` except that @@ -495,7 +536,12 @@ class AsyncScalarResult(AsyncCommon): return await greenlet_spawn(self._only_one_row, True, True, False) -class AsyncMappingResult(AsyncCommon): +SelfAsyncMappingResult = TypeVar( + "SelfAsyncMappingResult", bound="AsyncMappingResult" +) + + +class AsyncMappingResult(AsyncCommon[RowMapping]): """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values rather than :class:`_engine.Row` values. @@ -513,14 +559,14 @@ class AsyncMappingResult(AsyncCommon): _post_creational_filter = operator.attrgetter("_mapping") - def __init__(self, result): + def __init__(self, result: Result): self._real_result = result self._unique_filter_state = result._unique_filter_state self._metadata = result._metadata if result._source_supports_scalars: self._metadata = self._metadata._reduce([0]) - def keys(self): + def keys(self) -> RMKeyView: """Return an iterable view which yields the string keys that would be represented by each :class:`.Row`. @@ -535,7 +581,10 @@ class AsyncMappingResult(AsyncCommon): """ return self._metadata.keys - def unique(self, strategy=None): + def unique( + self: SelfAsyncMappingResult, + strategy: Optional[_UniqueFilterType] = None, + ) -> SelfAsyncMappingResult: """Apply unique filtering to the objects returned by this :class:`_asyncio.AsyncMappingResult`. @@ -545,11 +594,16 @@ class AsyncMappingResult(AsyncCommon): self._unique_filter_state = (set(), strategy) return self - def columns(self, *col_expressions): + def columns( + self: SelfAsyncMappingResult, *col_expressions: _KeyIndexType + ) -> SelfAsyncMappingResult: r"""Establish the columns that should be returned in each row.""" return self._column_slices(col_expressions) - async def partitions(self, size=None): + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[List[RowMapping]]: + """Iterate through sub-lists of elements of the size given. Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that @@ -567,12 +621,12 @@ class AsyncMappingResult(AsyncCommon): else: break - async def fetchall(self): + async def fetchall(self) -> List[RowMapping]: """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method.""" return await greenlet_spawn(self._allrows) - async def fetchone(self): + async def fetchone(self) -> Optional[RowMapping]: """Fetch one object. Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that @@ -587,8 +641,8 @@ class AsyncMappingResult(AsyncCommon): else: return row - async def fetchmany(self, size=None): - """Fetch many objects. + async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]: + """Fetch many rows. Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that :class:`_result.RowMapping` values, rather than :class:`_result.Row` @@ -598,8 +652,8 @@ class AsyncMappingResult(AsyncCommon): return await greenlet_spawn(self._manyrow_getter, self, size) - async def all(self): - """Return all scalar values in a list. + async def all(self) -> List[RowMapping]: + """Return all rows in a list. Equivalent to :meth:`_asyncio.AsyncResult.all` except that :class:`_result.RowMapping` values, rather than :class:`_result.Row` @@ -609,17 +663,17 @@ class AsyncMappingResult(AsyncCommon): return await greenlet_spawn(self._allrows) - def __aiter__(self): + def __aiter__(self) -> AsyncMappingResult: return self - async def __anext__(self): + async def __anext__(self) -> RowMapping: row = await greenlet_spawn(self._onerow_getter, self) if row is _NO_ROW: raise StopAsyncIteration() else: return row - async def first(self): + async def first(self) -> Optional[RowMapping]: """Fetch the first object or None if no object is present. Equivalent to :meth:`_asyncio.AsyncResult.first` except that @@ -630,7 +684,7 @@ class AsyncMappingResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, False, False, False) - async def one_or_none(self): + async def one_or_none(self) -> Optional[RowMapping]: """Return at most one object or raise an exception. Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that @@ -640,7 +694,7 @@ class AsyncMappingResult(AsyncCommon): """ return await greenlet_spawn(self._only_one_row, True, False, False) - async def one(self): + async def one(self) -> RowMapping: """Return exactly one object or raise an exception. Equivalent to :meth:`_asyncio.AsyncResult.one` except that @@ -651,11 +705,15 @@ class AsyncMappingResult(AsyncCommon): return await greenlet_spawn(self._only_one_row, True, True, False) -async def _ensure_sync_result(result, calling_method): +_RT = TypeVar("_RT", bound="Result") + + +async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: + cursor_result: CursorResult if not result._is_cursor: - cursor_result = getattr(result, "raw", None) + cursor_result = getattr(result, "raw", None) # type: ignore else: - cursor_result = result + cursor_result = result # type: ignore if cursor_result and cursor_result.context._is_server_side: await greenlet_spawn(cursor_result.close) raise async_exc.AsyncMethodRequired( diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 0503076aaf..0d6ae92b41 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -8,12 +8,50 @@ from __future__ import annotations from typing import Any - +from typing import Callable +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from .session import async_sessionmaker from .session import AsyncSession +from ... import exc as sa_exc from ... import util -from ...orm.scoping import ScopedSessionMixin +from ...orm.session import Session from ...util import create_proxy_methods from ...util import ScopedRegistry +from ...util import warn +from ...util import warn_deprecated + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .result import AsyncResult + from .result import AsyncScalarResult + from .session import AsyncSessionTransaction + from ...engine import Connection + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import _ExecuteOptionsParameter + from ...engine.result import ScalarResult + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...sql.base import Executable + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateArg @create_proxy_methods( @@ -62,7 +100,7 @@ from ...util import ScopedRegistry "info", ], ) -class async_scoped_session(ScopedSessionMixin): +class async_scoped_session: """Provides scoped management of :class:`.AsyncSession` objects. See the section :ref:`asyncio_scoped_session` for usage details. @@ -74,17 +112,23 @@ class async_scoped_session(ScopedSessionMixin): _support_async = True - def __init__(self, session_factory, scopefunc): + session_factory: async_sessionmaker + """The `session_factory` provided to `__init__` is stored in this + attribute and may be accessed at a later time. This can be useful when + a new non-scoped :class:`.AsyncSession` is needed.""" + + registry: ScopedRegistry[AsyncSession] + + def __init__( + self, + session_factory: async_sessionmaker, + scopefunc: Callable[[], Any], + ): """Construct a new :class:`_asyncio.async_scoped_session`. :param session_factory: a factory to create new :class:`_asyncio.AsyncSession` instances. This is usually, but not necessarily, an instance - of :class:`_orm.sessionmaker` which itself was passed the - :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_` - parameter:: - - async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession) - AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task) + of :class:`_asyncio.async_sessionmaker`. :param scopefunc: function which defines the current scope. A function such as ``asyncio.current_task`` @@ -96,10 +140,59 @@ class async_scoped_session(ScopedSessionMixin): self.registry = ScopedRegistry(session_factory, scopefunc) @property - def _proxied(self): + def _proxied(self) -> AsyncSession: return self.registry() - async def remove(self): + def __call__(self, **kw: Any) -> AsyncSession: + r"""Return the current :class:`.AsyncSession`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.AsyncSession` is not present. If the + :class:`.AsyncSession` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + else: + sess = self.registry() + if not self._support_async and sess._is_asyncio: + warn_deprecated( + "Using `scoped_session` with asyncio is deprecated and " + "will raise an error in a future version. " + "Please use `async_scoped_session` instead.", + "1.4.23", + ) + return sess + + def configure(self, **kwargs: Any) -> None: + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) + + async def remove(self) -> None: """Dispose of the current :class:`.AsyncSession`, if present. Different from scoped_session's remove method, this method would use @@ -152,7 +245,9 @@ class async_scoped_session(ScopedSessionMixin): Proxied for the :class:`_orm.Session` class on behalf of the :class:`_asyncio.AsyncSession` class. - """ + + + """ # noqa: E501 return self._proxied.__iter__() @@ -199,7 +294,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.add_all(instances) - def begin(self): + def begin(self) -> AsyncSessionTransaction: r"""Return an :class:`_asyncio.AsyncSessionTransaction` object. .. container:: class_bases @@ -228,7 +323,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.begin() - def begin_nested(self): + def begin_nested(self) -> AsyncSessionTransaction: r"""Return an :class:`_asyncio.AsyncSessionTransaction` object which will begin a "nested" transaction, e.g. SAVEPOINT. @@ -247,7 +342,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.begin_nested() - async def close(self): + async def close(self) -> None: r"""Close out the transactional resources and ORM objects used by this :class:`_asyncio.AsyncSession`. @@ -284,7 +379,7 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.close() - async def commit(self): + async def commit(self) -> None: r"""Commit the current transaction in progress. .. container:: class_bases @@ -296,7 +391,7 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.commit() - async def connection(self, **kw): + async def connection(self, **kw: Any) -> AsyncConnection: r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this :class:`.Session` object's transactional state. @@ -321,7 +416,7 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.connection(**kw) - async def delete(self, instance): + async def delete(self, instance: object) -> None: r"""Mark an instance as deleted. .. container:: class_bases @@ -345,12 +440,12 @@ class async_scoped_session(ScopedSessionMixin): async def execute( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result: r"""Execute a statement and return a buffered :class:`_engine.Result` object. @@ -519,7 +614,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.expunge_all() - async def flush(self, objects=None): + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: r"""Flush all the object changes to the database. .. container:: class_bases @@ -538,13 +633,15 @@ class async_scoped_session(ScopedSessionMixin): async def get( self, - entity, - ident, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - ): + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: r"""Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -568,9 +665,16 @@ class async_scoped_session(ScopedSessionMixin): populate_existing=populate_existing, with_for_update=with_for_update, identity_token=identity_token, + execution_options=execution_options, ) - def get_bind(self, mapper=None, clause=None, bind=None, **kw): + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: r"""Return a "bind" to which the synchronous proxied :class:`_orm.Session` is bound. @@ -724,7 +828,7 @@ class async_scoped_session(ScopedSessionMixin): instance, include_collections=include_collections ) - async def invalidate(self): + async def invalidate(self) -> None: r"""Close this Session, using connection invalidation. .. container:: class_bases @@ -738,7 +842,13 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.invalidate() - async def merge(self, instance, load=True, options=None): + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: r"""Copy the state of a given instance into a corresponding instance within this :class:`_asyncio.AsyncSession`. @@ -757,8 +867,11 @@ class async_scoped_session(ScopedSessionMixin): return await self._proxied.merge(instance, load=load, options=options) async def refresh( - self, instance, attribute_names=None, with_for_update=None - ): + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: r"""Expire and refresh the attributes on the given instance. .. container:: class_bases @@ -785,7 +898,7 @@ class async_scoped_session(ScopedSessionMixin): with_for_update=with_for_update, ) - async def rollback(self): + async def rollback(self) -> None: r"""Rollback the current transaction in progress. .. container:: class_bases @@ -799,12 +912,12 @@ class async_scoped_session(ScopedSessionMixin): async def scalar( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: r"""Execute a statement and return a scalar result. .. container:: class_bases @@ -829,12 +942,12 @@ class async_scoped_session(ScopedSessionMixin): async def scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: r"""Execute a statement and return scalar results. .. container:: class_bases @@ -865,12 +978,12 @@ class async_scoped_session(ScopedSessionMixin): async def stream( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult: r"""Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -892,12 +1005,12 @@ class async_scoped_session(ScopedSessionMixin): async def stream_scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: r"""Execute a statement and return a stream of scalar results. .. container:: class_bases @@ -1159,7 +1272,7 @@ class async_scoped_session(ScopedSessionMixin): return self._proxied.info @classmethod - async def close_all(self): + async def close_all(self) -> None: r"""Close all :class:`_asyncio.AsyncSession` sessions. .. container:: class_bases diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 769fe05bdb..7d63b084c2 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -7,17 +7,65 @@ from __future__ import annotations from typing import Any +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union from . import engine -from . import result as _result from .base import ReversibleProxy from .base import StartableContext from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult from ... import util from ...orm import object_session from ...orm import Session +from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn +from ...util.typing import Protocol + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .engine import AsyncEngine + from ...engine import Connection + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import ScalarResult + from ...engine import Transaction + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import _ExecuteOptionsParameter + from ...event import dispatcher + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm.identity import IdentityMap + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...orm.session import _SessionBindKey + from ...sql.base import Executable + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateArg + +_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] + + +class _SyncSessionCallable(Protocol): + def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any: + ... + _EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) @@ -52,7 +100,7 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) "info", ], ) -class AsyncSession(ReversibleProxy): +class AsyncSession(ReversibleProxy[Session]): """Asyncio version of :class:`_orm.Session`. The :class:`_asyncio.AsyncSession` is a proxy for a traditional @@ -69,9 +117,15 @@ class AsyncSession(ReversibleProxy): _is_asyncio = True - dispatch = None + dispatch: dispatcher[Session] - def __init__(self, bind=None, binds=None, sync_session_class=None, **kw): + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None, + sync_session_class: Optional[Type[Session]] = None, + **kw: Any, + ): r"""Construct a new :class:`_asyncio.AsyncSession`. All parameters other than ``sync_session_class`` are passed to the @@ -90,14 +144,15 @@ class AsyncSession(ReversibleProxy): .. versionadded:: 1.4.24 """ - kw["future"] = True + sync_bind = sync_binds = None + if bind: self.bind = bind - bind = engine._get_sync_engine_or_connection(bind) + sync_bind = engine._get_sync_engine_or_connection(bind) if binds: self.binds = binds - binds = { + sync_binds = { key: engine._get_sync_engine_or_connection(b) for key, b in binds.items() } @@ -106,10 +161,10 @@ class AsyncSession(ReversibleProxy): self.sync_session_class = sync_session_class self.sync_session = self._proxied = self._assign_proxied( - self.sync_session_class(bind=bind, binds=binds, **kw) + self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw) ) - sync_session_class = Session + sync_session_class: Type[Session] = Session """The class or callable that provides the underlying :class:`_orm.Session` instance for a particular :class:`_asyncio.AsyncSession`. @@ -138,9 +193,19 @@ class AsyncSession(ReversibleProxy): """ + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) + async def refresh( - self, instance, attribute_names=None, with_for_update=None - ): + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: """Expire and refresh the attributes on the given instance. A query will be issued to the database and all attributes will be @@ -155,14 +220,16 @@ class AsyncSession(ReversibleProxy): """ - return await greenlet_spawn( + await greenlet_spawn( self.sync_session.refresh, instance, attribute_names=attribute_names, with_for_update=with_for_update, ) - async def run_sync(self, fn, *arg, **kw): + async def run_sync( + self, fn: _SyncSessionCallable, *arg: Any, **kw: Any + ) -> Any: """Invoke the given sync callable passing sync self as the first argument. @@ -191,12 +258,12 @@ class AsyncSession(ReversibleProxy): async def execute( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result: """Execute a statement and return a buffered :class:`_engine.Result` object. @@ -225,12 +292,12 @@ class AsyncSession(ReversibleProxy): async def scalar( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: """Execute a statement and return a scalar result. .. seealso:: @@ -250,12 +317,12 @@ class AsyncSession(ReversibleProxy): async def scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: """Execute a statement and return scalar results. :return: a :class:`_result.ScalarResult` object @@ -281,13 +348,16 @@ class AsyncSession(ReversibleProxy): async def get( self, - entity, - ident, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - ): + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: + """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -297,7 +367,8 @@ class AsyncSession(ReversibleProxy): """ - return await greenlet_spawn( + + result_obj = await greenlet_spawn( self.sync_session.get, entity, ident, @@ -306,15 +377,17 @@ class AsyncSession(ReversibleProxy): with_for_update=with_for_update, identity_token=identity_token, ) + return result_obj async def stream( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult: + """Execute a statement and return a streaming :class:`_asyncio.AsyncResult` object. @@ -335,16 +408,16 @@ class AsyncSession(ReversibleProxy): bind_arguments=bind_arguments, **kw, ) - return _result.AsyncResult(result) + return AsyncResult(result) async def stream_scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: """Execute a statement and return a stream of scalar results. :return: an :class:`_asyncio.AsyncScalarResult` object @@ -368,7 +441,7 @@ class AsyncSession(ReversibleProxy): ) return result.scalars() - async def delete(self, instance): + async def delete(self, instance: object) -> None: """Mark an instance as deleted. The database delete operation occurs upon ``flush()``. @@ -381,9 +454,15 @@ class AsyncSession(ReversibleProxy): :meth:`_orm.Session.delete` - main documentation for delete """ - return await greenlet_spawn(self.sync_session.delete, instance) + await greenlet_spawn(self.sync_session.delete, instance) - async def merge(self, instance, load=True, options=None): + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: """Copy the state of a given instance into a corresponding instance within this :class:`_asyncio.AsyncSession`. @@ -396,7 +475,7 @@ class AsyncSession(ReversibleProxy): self.sync_session.merge, instance, load=load, options=options ) - async def flush(self, objects=None): + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: """Flush all the object changes to the database. .. seealso:: @@ -406,7 +485,7 @@ class AsyncSession(ReversibleProxy): """ await greenlet_spawn(self.sync_session.flush, objects=objects) - def get_transaction(self): + def get_transaction(self) -> Optional[AsyncSessionTransaction]: """Return the current root transaction in progress, if any. :return: an :class:`_asyncio.AsyncSessionTransaction` object, or @@ -421,7 +500,7 @@ class AsyncSession(ReversibleProxy): else: return None - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]: """Return the current nested transaction in progress, if any. :return: an :class:`_asyncio.AsyncSessionTransaction` object, or @@ -437,7 +516,13 @@ class AsyncSession(ReversibleProxy): else: return None - def get_bind(self, mapper=None, clause=None, bind=None, **kw): + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: """Return a "bind" to which the synchronous proxied :class:`_orm.Session` is bound. @@ -515,7 +600,7 @@ class AsyncSession(ReversibleProxy): mapper=mapper, clause=clause, bind=bind, **kw ) - async def connection(self, **kw): + async def connection(self, **kw: Any) -> AsyncConnection: r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this :class:`.Session` object's transactional state. @@ -539,7 +624,7 @@ class AsyncSession(ReversibleProxy): sync_connection ) - def begin(self): + def begin(self) -> AsyncSessionTransaction: """Return an :class:`_asyncio.AsyncSessionTransaction` object. The underlying :class:`_orm.Session` will perform the @@ -562,7 +647,7 @@ class AsyncSession(ReversibleProxy): return AsyncSessionTransaction(self) - def begin_nested(self): + def begin_nested(self) -> AsyncSessionTransaction: """Return an :class:`_asyncio.AsyncSessionTransaction` object which will begin a "nested" transaction, e.g. SAVEPOINT. @@ -575,15 +660,15 @@ class AsyncSession(ReversibleProxy): return AsyncSessionTransaction(self, nested=True) - async def rollback(self): + async def rollback(self) -> None: """Rollback the current transaction in progress.""" - return await greenlet_spawn(self.sync_session.rollback) + await greenlet_spawn(self.sync_session.rollback) - async def commit(self): + async def commit(self) -> None: """Commit the current transaction in progress.""" - return await greenlet_spawn(self.sync_session.commit) + await greenlet_spawn(self.sync_session.commit) - async def close(self): + async def close(self) -> None: """Close out the transactional resources and ORM objects used by this :class:`_asyncio.AsyncSession`. @@ -613,25 +698,25 @@ class AsyncSession(ReversibleProxy): """ return await greenlet_spawn(self.sync_session.close) - async def invalidate(self): + async def invalidate(self) -> None: """Close this Session, using connection invalidation. For a complete description, see :meth:`_orm.Session.invalidate`. """ - return await greenlet_spawn(self.sync_session.invalidate) + await greenlet_spawn(self.sync_session.invalidate) @classmethod - async def close_all(self): + async def close_all(self) -> None: """Close all :class:`_asyncio.AsyncSession` sessions.""" - return await greenlet_spawn(self.sync_session.close_all) + await greenlet_spawn(self.sync_session.close_all) - async def __aenter__(self): + async def __aenter__(self) -> AsyncSession: return self - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() - def _maker_context_manager(self): + def _maker_context_manager(self) -> _AsyncSessionContextManager: # TODO: can this use asynccontextmanager ?? return _AsyncSessionContextManager(self) @@ -1142,21 +1227,159 @@ class AsyncSession(ReversibleProxy): # END PROXY METHODS AsyncSession +class async_sessionmaker: + """A configurable :class:`.AsyncSession` factory. + + The :class:`.async_sessionmaker` factory works in the same way as the + :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession` + objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + + async def main(): + # an AsyncEngine, which the AsyncSession will use for connection + # resources + engine = create_async_engine('postgresql+asycncpg://scott:tiger@localhost/') + + AsyncSession = async_sessionmaker(engine) + + async with async_session() as session: + session.add(some_object) + session.add(some_other_object) + await session.commit() + + .. versionadded:: 2.0 :class:`.asyncio_sessionmaker` provides a + :class:`.sessionmaker` class that's dedicated to the + :class:`.AsyncSession` object, including pep-484 typing support. + + .. seealso:: + + :ref:`asyncio_orm` - shows example use + + :class:`.sessionmaker` - general overview of the + :class:`.sessionmaker` architecture + + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ # noqa E501 + + class_: Type[AsyncSession] + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + class_: Type[AsyncSession] = AsyncSession, + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[Dict[Any, Any]] = None, + **kw: Any, + ): + r"""Construct a new :class:`.async_sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.AsyncSession.__init__` docstring for more details on + parameters. + + + """ + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["expire_on_commit"] = expire_on_commit + if info is not None: + kw["info"] = info + self.kw = kw + self.class_ = class_ + + def begin(self) -> _AsyncSessionContextManager: + """Produce a context manager that both provides a new + :class:`_orm.AsyncSession` as well as a transaction that commits. + + + e.g.:: + + async def main(): + Session = async_sessionmaker(some_engine) + + async with Session.begin() as session: + session.add(some_object) + + # commits transaction, closes session + + + """ + + session = self() + return session._maker_context_manager() + + def __call__(self, **local_kw: Any) -> AsyncSession: + """Produce a new :class:`.AsyncSession` object using the configuration + established in this :class:`.async_sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False) + session = AsyncSession() # invokes sessionmaker.__call__() + + """ # noqa E501 + for k, v in self.kw.items(): + if k == "info" and "info" in local_kw: + d = v.copy() + d.update(local_kw["info"]) + local_kw["info"] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw: Any) -> None: + """(Re)configure the arguments for this async_sessionmaker. + + e.g.:: + + AsyncSession = async_sessionmaker(some_engine) + + AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://')) + """ # noqa E501 + + self.kw.update(new_kw) + + def __repr__(self) -> str: + return "%s(class_=%r, %s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), + ) + + class _AsyncSessionContextManager: - def __init__(self, async_session): + __slots__ = ("async_session", "trans") + + async_session: AsyncSession + trans: AsyncSessionTransaction + + def __init__(self, async_session: AsyncSession): self.async_session = async_session - async def __aenter__(self): + async def __aenter__(self) -> AsyncSession: self.trans = self.async_session.begin() await self.trans.__aenter__() return self.async_session - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.trans.__aexit__(type_, value, traceback) await self.async_session.__aexit__(type_, value, traceback) -class AsyncSessionTransaction(ReversibleProxy, StartableContext): +class AsyncSessionTransaction( + ReversibleProxy[SessionTransaction], StartableContext +): """A wrapper for the ORM :class:`_orm.SessionTransaction` object. This object is provided so that a transaction-holding object @@ -1174,36 +1397,41 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext): __slots__ = ("session", "sync_transaction", "nested") - def __init__(self, session, nested=False): + session: AsyncSession + sync_transaction: Optional[SessionTransaction] + + def __init__(self, session: AsyncSession, nested: bool = False): self.session = session self.nested = nested self.sync_transaction = None @property - def is_active(self): + def is_active(self) -> bool: return ( self._sync_transaction() is not None and self._sync_transaction().is_active ) - def _sync_transaction(self): + def _sync_transaction(self) -> SessionTransaction: if not self.sync_transaction: self._raise_for_not_started() return self.sync_transaction - async def rollback(self): + async def rollback(self) -> None: """Roll back this :class:`_asyncio.AsyncTransaction`.""" await greenlet_spawn(self._sync_transaction().rollback) - async def commit(self): + async def commit(self) -> None: """Commit this :class:`_asyncio.AsyncTransaction`.""" await greenlet_spawn(self._sync_transaction().commit) - async def start(self, is_ctxmanager=False): + async def start( + self, is_ctxmanager: bool = False + ) -> AsyncSessionTransaction: self.sync_transaction = self._assign_proxied( await greenlet_spawn( - self.session.sync_session.begin_nested + self.session.sync_session.begin_nested # type: ignore if self.nested else self.session.sync_session.begin ) @@ -1212,13 +1440,13 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext): self.sync_transaction.__enter__() return self - async def __aexit__(self, type_, value, traceback): + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await greenlet_spawn( self._sync_transaction().__exit__, type_, value, traceback ) -def async_object_session(instance): +def async_object_session(instance: object) -> Optional[AsyncSession]: """Return the :class:`_asyncio.AsyncSession` to which the given instance belongs. @@ -1247,7 +1475,7 @@ def async_object_session(instance): return None -def async_session(session: Session) -> AsyncSession: +def async_session(session: Session) -> Optional[AsyncSession]: """Return the :class:`_asyncio.AsyncSession` which is proxying the given :class:`_orm.Session` object, if any. @@ -1260,4 +1488,4 @@ def async_session(session: Session) -> AsyncSession: return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) -_instance_state._async_provider = async_session +_instance_state._async_provider = async_session # type: ignore diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index d8f57e1498..c5348c2373 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -448,7 +448,13 @@ def _entity_descriptor(entity, key): ) from err -_state_mapper = util.dottedgetter("manager.mapper") +if TYPE_CHECKING: + + def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: + ... + +else: + _state_mapper = util.dottedgetter("manager.mapper") @inspection._inspects(type) diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index e62a833975..c531e7cf19 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -10,6 +10,7 @@ """ from __future__ import annotations +from typing import Any import weakref from . import instrumentation @@ -1324,7 +1325,7 @@ class _MapperEventsHold(_EventsHold): _sessionevents_lifecycle_event_names = set() -class SessionEvents(event.Events): +class SessionEvents(event.Events[Session]): """Define events specific to :class:`.Session` lifecycle. e.g.:: @@ -1396,12 +1397,21 @@ class SessionEvents(event.Events): return target elif isinstance(target, Session): return target + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() else: # allows alternate SessionEvents-like-classes to be consulted return event.Events._accept_with(target) @classmethod - def _listen(cls, event_key, raw=False, restore_load_context=False, **kw): + def _listen( + cls, + event_key: Any, + *, + raw: bool = False, + restore_load_context: bool = False, + **kw: Any, + ) -> None: is_instance_event = ( event_key.identifier in _sessionevents_lifecycle_event_names ) diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 030d1595b2..a5dc305d22 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -35,6 +35,7 @@ from __future__ import annotations from typing import Any from typing import Dict from typing import Generic +from typing import Optional from typing import Set from typing import TYPE_CHECKING from typing import TypeVar @@ -44,6 +45,7 @@ from . import collections from . import exc from . import interfaces from . import state +from ._typing import _O from .. import util from ..event import EventTarget from ..util import HasMemoized @@ -52,6 +54,7 @@ from ..util.typing import Protocol if TYPE_CHECKING: from .attributes import InstrumentedAttribute from .mapper import Mapper + from .state import InstanceState from ..event import dispatcher _T = TypeVar("_T", bound=Any) @@ -71,7 +74,7 @@ class _ExpiredAttributeLoaderProto(Protocol): class ClassManager( HasMemoized, Dict[str, "InstrumentedAttribute[Any]"], - Generic[_T], + Generic[_O], EventTarget, ): """Tracks state information at the class level.""" @@ -230,7 +233,7 @@ class ClassManager( return frozenset([attr.impl for attr in self.values()]) @util.memoized_property - def mapper(self) -> Mapper[_T]: + def mapper(self) -> Mapper[_O]: # raises unless self.mapper has been assigned raise exc.UnmappedClassError(self.class_) @@ -442,7 +445,7 @@ class ClassManager( # InstanceState management - def new_instance(self, state=None): + def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O: instance = self.class_.__new__(self.class_) if state is None: state = self._state_constructor(instance, self) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c85861a594..abe11cc68c 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -132,6 +132,8 @@ class Mapper( _identity_class: Type[_O] always_refresh: bool + allow_partial_pks: bool + version_id_col: Optional[ColumnElement[Any]] @util.deprecated_params( non_primary=( @@ -2931,7 +2933,7 @@ class Mapper( self, state: InstanceState[_O], dict_: _InstanceDict, - column: Column[Any], + column: ColumnElement[Any], passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, ) -> Any: prop = self._columntoproperty[column] diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index e498b17b4d..1dd7a69523 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: from .interfaces import ORMOption from .mapper import Mapper from .query import Query + from .session import _BindArguments from .session import _EntityBindKey from .session import _PKIdentityArgument from .session import _SessionBind @@ -65,65 +66,7 @@ class _QueryDescriptorType(Protocol): _O = TypeVar("_O", bound=object) -__all__ = ["scoped_session", "ScopedSessionMixin"] - - -class ScopedSessionMixin: - session_factory: sessionmaker - _support_async: bool - registry: ScopedRegistry[Session] - - @property - def _proxied(self) -> Session: - return self.registry() # type: ignore - - def __call__(self, **kw: Any) -> Session: - r"""Return the current :class:`.Session`, creating it - using the :attr:`.scoped_session.session_factory` if not present. - - :param \**kw: Keyword arguments will be passed to the - :attr:`.scoped_session.session_factory` callable, if an existing - :class:`.Session` is not present. If the :class:`.Session` is present - and keyword arguments have been passed, - :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. - - """ - if kw: - if self.registry.has(): - raise sa_exc.InvalidRequestError( - "Scoped session is already present; " - "no new arguments may be specified." - ) - else: - sess = self.session_factory(**kw) - self.registry.set(sess) - else: - sess = self.registry() - if not self._support_async and sess._is_asyncio: - warn_deprecated( - "Using `scoped_session` with asyncio is deprecated and " - "will raise an error in a future version. " - "Please use `async_scoped_session` instead.", - "1.4.23", - ) - return sess - - def configure(self, **kwargs: Any) -> None: - """reconfigure the :class:`.sessionmaker` used by this - :class:`.scoped_session`. - - See :meth:`.sessionmaker.configure`. - - """ - - if self.registry.has(): - warn( - "At least one scoped session is already present. " - " configure() can not affect sessions that have " - "already been created." - ) - - self.session_factory.configure(**kwargs) +__all__ = ["scoped_session"] @create_proxy_methods( @@ -173,7 +116,7 @@ class ScopedSessionMixin: "info", ], ) -class scoped_session(ScopedSessionMixin): +class scoped_session: """Provides scoped management of :class:`.Session` objects. See :ref:`unitofwork_contextual` for a tutorial. @@ -191,8 +134,9 @@ class scoped_session(ScopedSessionMixin): session_factory: sessionmaker """The `session_factory` provided to `__init__` is stored in this attribute and may be accessed at a later time. This can be useful when - a new non-scoped :class:`.Session` or :class:`_engine.Connection` to the - database is needed.""" + a new non-scoped :class:`.Session` is needed.""" + + registry: ScopedRegistry[Session] def __init__( self, @@ -222,6 +166,58 @@ class scoped_session(ScopedSessionMixin): else: self.registry = ThreadLocalRegistry(session_factory) + @property + def _proxied(self) -> Session: + return self.registry() + + def __call__(self, **kw: Any) -> Session: + r"""Return the current :class:`.Session`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.Session` is not present. If the :class:`.Session` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + else: + sess = self.registry() + if not self._support_async and sess._is_asyncio: + warn_deprecated( + "Using `scoped_session` with asyncio is deprecated and " + "will raise an error in a future version. " + "Please use `async_scoped_session` instead.", + "1.4.23", + ) + return sess + + def configure(self, **kwargs: Any) -> None: + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) + def remove(self) -> None: """Dispose of the current :class:`.Session`, if present. @@ -494,9 +490,9 @@ class scoped_session(ScopedSessionMixin): def connection( self, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, execution_options: Optional[_ExecuteOptions] = None, - ) -> "Connection": + ) -> Connection: r"""Return a :class:`_engine.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -557,7 +553,7 @@ class scoped_session(ScopedSessionMixin): statement: Executable, params: Optional[_CoreAnyExecuteParams] = None, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, ) -> Result: @@ -1567,7 +1563,7 @@ class scoped_session(ScopedSessionMixin): statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, **kw: Any, ) -> Any: r"""Execute a statement and return a scalar result. @@ -1597,7 +1593,7 @@ class scoped_session(ScopedSessionMixin): statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, **kw: Any, ) -> ScalarResult[Any]: r"""Execute a statement and return the results as scalars. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 55ce73cf54..a26c55a248 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -26,7 +26,6 @@ from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING -from typing import TypeVar from typing import Union import weakref @@ -38,6 +37,7 @@ from . import loading from . import persistence from . import query from . import state as statelib +from ._typing import _O from ._typing import is_composite_class from ._typing import is_user_defined_option from .base import _class_to_mapper @@ -119,11 +119,12 @@ _sessions: weakref.WeakValueDictionary[ """Weak-referencing dictionary of :class:`.Session` objects. """ -_O = TypeVar("_O", bound=object) statelib._sessions = _sessions _PKIdentityArgument = Union[Any, Tuple[Any, ...]] +_BindArguments = Dict[str, Any] + _EntityBindKey = Union[Type[_O], "Mapper[_O]"] _SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"] _SessionBind = Union["Engine", "Connection"] @@ -251,7 +252,7 @@ class ORMExecuteState(util.MemoizedSlots): parameters: Optional[_CoreAnyExecuteParams] execution_options: _ExecuteOptions local_execution_options: _ExecuteOptions - bind_arguments: Dict[str, Any] + bind_arguments: _BindArguments _compile_state_cls: Optional[Type[ORMCompileState]] _starting_event_idx: int _events_todo: List[Any] @@ -263,7 +264,7 @@ class ORMExecuteState(util.MemoizedSlots): statement: Executable, parameters: Optional[_CoreAnyExecuteParams], execution_options: _ExecuteOptions, - bind_arguments: Dict[str, Any], + bind_arguments: _BindArguments, compile_state_cls: Optional[Type[ORMCompileState]], events_todo: List[_InstanceLevelDispatch[Session]], ): @@ -286,7 +287,7 @@ class ORMExecuteState(util.MemoizedSlots): statement: Optional[Executable] = None, params: Optional[_CoreAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, ) -> Result: """Execute the statement represented by this :class:`.ORMExecuteState`, without re-invoking events that have @@ -1626,9 +1627,9 @@ class Session(_SessionClassMethods, EventTarget): def connection( self, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, execution_options: Optional[_ExecuteOptions] = None, - ) -> "Connection": + ) -> Connection: r"""Return a :class:`_engine.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -1690,7 +1691,7 @@ class Session(_SessionClassMethods, EventTarget): statement: Executable, params: Optional[_CoreAnyExecuteParams] = None, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, ) -> Result: @@ -1833,7 +1834,7 @@ class Session(_SessionClassMethods, EventTarget): statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, **kw: Any, ) -> Any: """Execute a statement and return a scalar result. @@ -1857,7 +1858,7 @@ class Session(_SessionClassMethods, EventTarget): statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[_BindArguments] = None, **kw: Any, ) -> ScalarResult[Any]: """Execute a statement and return the results as scalars. @@ -3099,7 +3100,7 @@ class Session(_SessionClassMethods, EventTarget): _recursive: Dict[InstanceState[Any], object], _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], ) -> _O: - mapper = _state_mapper(state) + mapper: Mapper[_O] = _state_mapper(state) if state in _recursive: return cast(_O, _recursive[state]) @@ -3249,6 +3250,7 @@ class Session(_SessionClassMethods, EventTarget): if new_instance: merged_state.manager.dispatch.load(merged_state, None) + return merged def _validate_persistent(self, state: InstanceState[Any]) -> None: @@ -4291,7 +4293,7 @@ class sessionmaker(_SessionClassMethods): In Python, the ``__call__`` method is invoked on an object when it is "called" in the same way as a function:: - Session = sessionmaker() + Session = sessionmaker(some_engine) session = Session() # invokes sessionmaker.__call__() """ diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 7ccda95659..58f141997e 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -23,12 +23,12 @@ from typing import Optional from typing import Set from typing import Tuple from typing import TYPE_CHECKING -from typing import TypeVar import weakref from . import base from . import exc as orm_exc from . import interfaces +from ._typing import _O from ._typing import is_collection_impl from .base import ATTR_WAS_SET from .base import INIT_OK @@ -62,8 +62,6 @@ if TYPE_CHECKING: from ..ext.asyncio.session import async_session as _async_provider from ..ext.asyncio.session import AsyncSession -_T = TypeVar("_T", bound=Any) - if TYPE_CHECKING: _sessions: weakref.WeakValueDictionary[int, Session] else: @@ -83,7 +81,7 @@ class _InstanceDictProto(Protocol): @inspection._self_inspects -class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): +class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): """tracks state information at the instance level. The :class:`.InstanceState` is a key object used by the @@ -119,15 +117,15 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): "expired_attributes", ) - manager: ClassManager[_T] + manager: ClassManager[_O] session_id: Optional[int] = None - key: Optional[_IdentityKeyType[_T]] = None + key: Optional[_IdentityKeyType[_O]] = None runid: Optional[int] = None load_options: Tuple[ORMOption, ...] = () load_path: PathRegistry = PathRegistry.root insert_order: Optional[int] = None _strong_obj: Optional[object] = None - obj: weakref.ref[_T] + obj: weakref.ref[_O] committed_state: Dict[str, Any] @@ -159,7 +157,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): see also the ``unmodified`` collection which is intersected against this set when a refresh operation occurs.""" - callables: Dict[str, Callable[[InstanceState[_T], PassiveFlag], Any]] + callables: Dict[str, Callable[[InstanceState[_O], PassiveFlag], Any]] """A namespace where a per-state loader callable can be associated. In SQLAlchemy 1.0, this is only used for lazy loaders / deferred @@ -174,7 +172,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): if not TYPE_CHECKING: callables = util.EMPTY_DICT - def __init__(self, obj: _T, manager: ClassManager[_T]): + def __init__(self, obj: _O, manager: ClassManager[_O]): self.class_ = obj.__class__ self.manager = manager self.obj = weakref.ref(obj, self._cleanup) @@ -381,7 +379,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): return None @property - def object(self) -> Optional[_T]: + def object(self) -> Optional[_O]: """Return the mapped object represented by this :class:`.InstanceState`. @@ -411,7 +409,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): return self.key[1] @property - def identity_key(self) -> Optional[_IdentityKeyType[_T]]: + def identity_key(self) -> Optional[_IdentityKeyType[_O]]: """Return the identity key for the mapped object. This is the key used to locate the object within @@ -435,7 +433,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): return {} @util.memoized_property - def mapper(self) -> Mapper[_T]: + def mapper(self) -> Mapper[_O]: """Return the :class:`_orm.Mapper` used for this mapped object.""" return self.manager.mapper @@ -452,7 +450,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): @classmethod def _detach_states( self, - states: Iterable[InstanceState[_T]], + states: Iterable[InstanceState[_O]], session: Session, to_transient: bool = False, ) -> None: @@ -497,7 +495,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): # used by the test suite, apparently self._detach() - def _cleanup(self, ref: weakref.ref[_T]) -> None: + def _cleanup(self, ref: weakref.ref[_O]) -> None: """Weakref callback cleanup. This callable cleans out the state when it is being garbage @@ -657,14 +655,14 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): @classmethod def _instance_level_callable_processor( - cls, manager: ClassManager[_T], fn: _LoaderCallable, key: Any - ) -> Callable[[InstanceState[_T], _InstanceDict, Row], None]: + cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any + ) -> Callable[[InstanceState[_O], _InstanceDict, Row], None]: impl = manager[key].impl if is_collection_impl(impl): fixed_impl = impl def _set_callable( - state: InstanceState[_T], dict_: _InstanceDict, row: Row + state: InstanceState[_O], dict_: _InstanceDict, row: Row ) -> None: if "callables" not in state.__dict__: state.callables = {} @@ -676,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): else: def _set_callable( - state: InstanceState[_T], dict_: _InstanceDict, row: Row + state: InstanceState[_O], dict_: _InstanceDict, row: Row ) -> None: if "callables" not in state.__dict__: state.callables = {} @@ -768,7 +766,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]): self.manager.dispatch.expire(self, attribute_names) def _load_expired( - self, state: InstanceState[_T], passive: PassiveFlag + self, state: InstanceState[_O], passive: PassiveFlag ) -> LoaderCallableStatus: """__call__ allows the InstanceState to act as a deferred callable for loading expired attributes, which is also diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py index e961df1a3b..1107c92b5c 100644 --- a/lib/sqlalchemy/pool/events.py +++ b/lib/sqlalchemy/pool/events.py @@ -73,10 +73,8 @@ class PoolEvents(event.Events[Pool]): return target.pool elif isinstance(target, Pool): return target - elif hasattr(target, "dispatch") and hasattr( - target.dispatch._events, "_no_async_engine_events" - ): - target.dispatch._events._no_async_engine_events() + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() else: return None diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index ccd5e8c40e..d587433408 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from .selectable import _SelectIterable from .selectable import FromClause from ..engine import Connection + from ..engine import CursorResult from ..engine import Result from ..engine.base import _CompiledCacheType from ..engine.interfaces import _CoreMultiExecuteParams @@ -983,7 +984,7 @@ class Executable(roles.StatementRole, Generative): distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, _force: bool = False, - ) -> Result: + ) -> CursorResult: ... @util.ro_non_memoized_property diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 28b062d3d9..6ad099eefc 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations import asyncio from contextvars import copy_context as _copy_context @@ -19,6 +20,8 @@ from .langhelpers import memoized_property from .. import exc from ..util.typing import Protocol +_T = TypeVar("_T", bound=Any) + if typing.TYPE_CHECKING: class greenlet(Protocol): @@ -52,8 +55,6 @@ if not typing.TYPE_CHECKING: except (ImportError, AttributeError): _copy_context = None # noqa -_T = TypeVar("_T", bound=Any) - def is_exit_exception(e: BaseException) -> bool: # note asyncio.CancelledError is already BaseException @@ -128,11 +129,11 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T: async def greenlet_spawn( - fn: Callable[..., Any], + fn: Callable[..., _T], *args: Any, _require_await: bool = False, **kwargs: Any, -) -> Any: +) -> _T: """Runs a sync function ``fn`` in a new greenlet. The sync function can then use :func:`await_` to wait for async @@ -143,6 +144,7 @@ async def greenlet_spawn( :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable. """ + result: _T context = _AsyncIoGreenlet(fn, getcurrent()) # runs the function synchronously in gl greenlet. If the execution # is interrupted by await_, context is not dead and result is a diff --git a/pyproject.toml b/pyproject.toml index 8f7f50715a..46a4eb0d85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,6 @@ module = [ "sqlalchemy.engine.reflection", # interim, should be strict # TODO for strict: - "sqlalchemy.ext.asyncio.*", "sqlalchemy.ext.automap", "sqlalchemy.ext.compiler", "sqlalchemy.ext.declarative.*", diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 0fe14dc921..462d0900f9 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -480,6 +480,28 @@ class AsyncEngineTest(EngineFixture): eq_(async_engine.pool.checkedin(), 0) is_not(p1, async_engine.pool) + @testing.requires.queue_pool + @async_test + async def test_dispose_no_close(self, async_engine): + c1 = await async_engine.connect() + c2 = await async_engine.connect() + + await c1.close() + await c2.close() + + p1 = async_engine.pool + + if isinstance(p1, AsyncAdaptedQueuePool): + eq_(async_engine.pool.checkedin(), 2) + + await async_engine.dispose(close=False) + + # TODO: test that DBAPI connection was not closed + + if isinstance(p1, AsyncAdaptedQueuePool): + eq_(async_engine.pool.checkedin(), 0) + is_not(p1, async_engine.pool) + @testing.requires.independent_connections @async_test async def test_init_once_concurrency(self, async_engine): diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index f04b87f371..ce38de5114 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -10,6 +10,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.ext.asyncio import async_object_session +from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import exc as async_exc from sqlalchemy.ext.asyncio.base import ReversibleProxy @@ -202,7 +203,7 @@ class AsyncSessionTransactionTest(AsyncFixture): await fn(async_session, trans_on_subject=True, execute_on_subject=True) @async_test - async def test_sessionmaker_block_one(self, async_engine): + async def test_orm_sessionmaker_block_one(self, async_engine): User = self.classes.User maker = sessionmaker(async_engine, class_=AsyncSession) @@ -226,7 +227,7 @@ class AsyncSessionTransactionTest(AsyncFixture): eq_(u1.name, "u1") @async_test - async def test_sessionmaker_block_two(self, async_engine): + async def test_orm_sessionmaker_block_two(self, async_engine): User = self.classes.User maker = sessionmaker(async_engine, class_=AsyncSession) @@ -247,6 +248,52 @@ class AsyncSessionTransactionTest(AsyncFixture): eq_(u1.name, "u1") + @async_test + async def test_async_sessionmaker_block_one(self, async_engine): + + User = self.classes.User + maker = async_sessionmaker(async_engine) + + session = maker() + + async with session.begin(): + u1 = User(name="u1") + assert session.in_transaction() + session.add(u1) + + assert not session.in_transaction() + + async with maker() as session: + result = await session.execute( + select(User).where(User.name == "u1") + ) + + u1 = result.scalar_one() + + eq_(u1.name, "u1") + + @async_test + async def test_async_sessionmaker_block_two(self, async_engine): + + User = self.classes.User + maker = async_sessionmaker(async_engine) + + async with maker.begin() as session: + u1 = User(name="u1") + assert session.in_transaction() + session.add(u1) + + assert not session.in_transaction() + + async with maker() as session: + result = await session.execute( + select(User).where(User.name == "u1") + ) + + u1 = result.scalar_one() + + eq_(u1.name, "u1") + @async_test async def test_trans(self, async_session, async_engine): async with async_engine.connect() as outer_conn: @@ -882,7 +929,7 @@ class OverrideSyncSession(AsyncFixture): is_true(isinstance(ass.sync_session, _MySession)) is_(ass.sync_session_class, _MySession) - def test_init_sessionmaker(self, async_engine): + def test_init_orm_sessionmaker(self, async_engine): sm = sessionmaker( async_engine, class_=AsyncSession, sync_session_class=_MySession ) @@ -891,6 +938,13 @@ class OverrideSyncSession(AsyncFixture): is_true(isinstance(ass.sync_session, _MySession)) is_(ass.sync_session_class, _MySession) + def test_init_asyncio_sessionmaker(self, async_engine): + sm = async_sessionmaker(async_engine, sync_session_class=_MySession) + ass = sm() + + is_true(isinstance(ass.sync_session, _MySession)) + is_(ass.sync_session_class, _MySession) + def test_subclass(self, async_engine): ass = _MyAS(async_engine) diff --git a/test/ext/mypy/plain_files/async_sessionmaker.py b/test/ext/mypy/plain_files/async_sessionmaker.py new file mode 100644 index 0000000000..01a26d0354 --- /dev/null +++ b/test/ext/mypy/plain_files/async_sessionmaker.py @@ -0,0 +1,79 @@ +"""Illustrates use of the sqlalchemy.ext.asyncio.AsyncSession object +for asynchronous ORM use. + +""" +from __future__ import annotations + +import asyncio +from typing import List +from typing import TYPE_CHECKING + +from sqlalchemy import ForeignKey +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.future import select +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + +if TYPE_CHECKING: + from sqlalchemy import ScalarResult + + +class Base(DeclarativeBase): + pass + + +class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + bs: Mapped[List[B]] = relationship() + + +class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id = mapped_column(ForeignKey("a.id")) + data: Mapped[str] + + +async def async_main() -> None: + """Main program function.""" + + engine = create_async_engine( + "postgresql+asyncpg://scott:tiger@localhost/test", + echo=True, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async_session = async_sessionmaker(engine, expire_on_commit=False) + + async with async_session.begin() as session: + session.add_all( + [ + A(bs=[B(), B()], data="a1"), + A(bs=[B()], data="a2"), + A(bs=[B(), B()], data="a3"), + ] + ) + + async with async_session() as session: + + result = await session.execute(select(A).order_by(A.id)) + + r: ScalarResult[A] = result.scalars() + a1 = r.one() + + a1.data = "new data" + + await session.commit() + + +asyncio.run(async_main()) diff --git a/test/ext/mypy/plain_files/session.py b/test/ext/mypy/plain_files/session.py index 24c685e84b..199d3a804f 100644 --- a/test/ext/mypy/plain_files/session.py +++ b/test/ext/mypy/plain_files/session.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import List -from typing import Sequence from sqlalchemy import create_engine from sqlalchemy import ForeignKey @@ -45,6 +44,6 @@ with Session(e) as sess: sess.commit() with Session(e) as sess: - users: Sequence[User] = sess.scalars( + users: List[User] = sess.scalars( select(User), execution_options={"stream_results": False} ).all()