From d06133ba376ba4ab0b7117b2eb72d5fd29a43bb2 Mon Sep 17 00:00:00 2001 From: jason3gb Date: Wed, 16 Jun 2021 10:18:08 -0400 Subject: [PATCH] Implement async_scoped_session Implemented :class:`_asyncio.async_scoped_session` to address some asyncio-related incompatibilities between :class:`_orm.scoped_session` and :class:`_asyncio.AsyncSession`, in which some methods (notably the :meth:`_asyncio.async_scoped_session.remove` method) should be used with the ``await`` keyword. Fixes: #6583 Closes: #6603 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/6603 Pull-request-sha: 0e8ef87dc824dcd83dca01641441afc453c8e07a Change-Id: I9bfe56f8670302ff0015d9dc56c1e3ac5b92b118 --- doc/build/changelog/unreleased_14/6583.rst | 13 +++ doc/build/orm/contextual.rst | 3 +- doc/build/orm/extensions/asyncio.rst | 44 +++++++++ lib/sqlalchemy/ext/asyncio/__init__.py | 1 + lib/sqlalchemy/ext/asyncio/scoping.py | 101 +++++++++++++++++++++ lib/sqlalchemy/orm/scoping.py | 96 ++++++++++---------- test/ext/asyncio/test_scoping_py3k.py | 44 +++++++++ 7 files changed, 254 insertions(+), 48 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6583.rst create mode 100644 lib/sqlalchemy/ext/asyncio/scoping.py create mode 100644 test/ext/asyncio/test_scoping_py3k.py diff --git a/doc/build/changelog/unreleased_14/6583.rst b/doc/build/changelog/unreleased_14/6583.rst new file mode 100644 index 0000000000..2d235c6d48 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6583.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: usecase, asyncio + :tickets: 6583 + + Implemented :class:`_asyncio.async_scoped_session` to address some + asyncio-related incompatibilities between :class:`_orm.scoped_session` and + :class:`_asyncio.AsyncSession`, in which some methods (notably the + :meth:`_asyncio.async_scoped_session.remove` method) should be used with + the ``await`` keyword. + + .. seealso:: + + :ref:`asyncio_scoped_session` \ No newline at end of file diff --git a/doc/build/orm/contextual.rst b/doc/build/orm/contextual.rst index fd55846220..aebcfc3ff3 100644 --- a/doc/build/orm/contextual.rst +++ b/doc/build/orm/contextual.rst @@ -252,7 +252,8 @@ Contextual Session API ---------------------- .. autoclass:: sqlalchemy.orm.scoping.scoped_session - :members: + :members: + :inherited-members: .. autoclass:: sqlalchemy.util.ScopedRegistry :members: diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index 92471d459e..6aca1762df 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -415,8 +415,48 @@ from using any connection more than once:: ) +.. _asyncio_scoped_session: + +Using asyncio scoped session +---------------------------- + +The usage of :class:`_asyncio.async_scoped_session` is mostly similar to +:class:`.scoped_session`. However, since there's no "thread-local" concept in +the asyncio context, the "scopefunc" paramater must be provided to the +constructor:: + + from asyncio import current_task + + from sqlalchemy.orm import sessionmaker + from sqlalchemy.ext.asyncio import async_scoped_session + from sqlalchemy.ext.asyncio import AsyncSession + + async_session_factory = sessionmaker(some_async_engine, class_=_AsyncSession) + AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task) + + some_async_session = AsyncSession() + +:class:`_asyncio.async_scoped_session` also includes **proxy +behavior** similar to that of :class:`.scoped_session`, which means it can be +treated as a :class:`_asyncio.AsyncSession` directly, keeping in mind that +the usual ``await`` keywords are necessary, including for the +:meth:`_asyncio.async_scoped_session.remove` method:: + + async def some_function(some_async_session, some_object): + # use the AsyncSession directly + some_async_session.add(some_object) + + # use the AsyncSession via the context-local proxy + await AsyncSession.commit() + + # "remove" the current proxied AsyncSession for the local context + await AsyncSession.remove() + +.. versionadded:: 1.4.19 + .. currentmodule:: sqlalchemy.ext.asyncio + Engine API Documentation ------------------------- @@ -456,6 +496,10 @@ ORM Session API Documentation .. autofunction:: async_session +.. autoclass:: async_scoped_session + :members: + :inherited-members: + .. autoclass:: AsyncSession :members: diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index 349bc1b753..19e6079dc5 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -14,6 +14,7 @@ from .events import AsyncSessionEvents from .result import AsyncMappingResult from .result import AsyncResult from .result import AsyncScalarResult +from .scoping import async_scoped_session from .session import async_object_session from .session import async_session from .session import AsyncSession diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py new file mode 100644 index 0000000000..4d1cea9f9f --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -0,0 +1,101 @@ +# ext/asyncio/scoping.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .session import AsyncSession +from ...orm.scoping import ScopedSessionMixin +from ...util import create_proxy_methods +from ...util import ScopedRegistry + + +@create_proxy_methods( + AsyncSession, + ":class:`_asyncio.AsyncSession`", + ":class:`_asyncio.scoping.async_scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get", + "get_bind", + "is_modified", + "merge", + "refresh", + "rollback", + "scalar", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class async_scoped_session(ScopedSessionMixin): + """Provides scoped management of :class:`.AsyncSession` objects. + + See the section :ref:`asyncio_scoped_session` for usage details. + + .. versionadded:: 1.4.19 + + + """ + + def __init__(self, session_factory, scopefunc): + """Construct a new :class:`_asyncio.async_scoped_session`. + + :param session_factory: a factory to create new :class:`_asyncio.AsyncSession` + instances. This is usually, but not necessarily, an instance + of :class:`_orm.sessionmaker` which itself was passed the + :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_` + parameter:: + + async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession) + AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task) + + :param scopefunc: function which defines + the current scope. A function such as ``asyncio.current_task`` + may be useful here. + + """ # noqa E501 + + self.session_factory = session_factory + self.registry = ScopedRegistry(session_factory, scopefunc) + + @property + def _proxied(self): + return self.registry() + + async def remove(self): + """Dispose of the current :class:`.AsyncSession`, if present. + + Different from scoped_session's remove method, this method would use + await to wait for the close method of AsyncSession. + + """ + + if self.registry.has(): + await self.registry().close() + self.registry.clear() diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 0ba7b12ff2..6bd052dde0 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -14,7 +14,54 @@ from ..util import ScopedRegistry from ..util import ThreadLocalRegistry from ..util import warn -__all__ = ["scoped_session"] +__all__ = ["scoped_session", "ScopedSessionMixin"] + + +class ScopedSessionMixin(object): + @property + def _proxied(self): + return self.registry() + + def __call__(self, **kw): + r"""Return the current :class:`.Session`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.Session` is not present. If the :class:`.Session` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + return sess + else: + return self.registry() + + def configure(self, **kwargs): + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) @create_proxy_methods( @@ -64,7 +111,7 @@ __all__ = ["scoped_session"] "autocommit", ], ) -class scoped_session(object): +class scoped_session(ScopedSessionMixin): """Provides scoped management of :class:`.Session` objects. See :ref:`unitofwork_contextual` for a tutorial. @@ -100,34 +147,6 @@ class scoped_session(object): else: self.registry = ThreadLocalRegistry(session_factory) - @property - def _proxied(self): - return self.registry() - - def __call__(self, **kw): - r"""Return the current :class:`.Session`, creating it - using the :attr:`.scoped_session.session_factory` if not present. - - :param \**kw: Keyword arguments will be passed to the - :attr:`.scoped_session.session_factory` callable, if an existing - :class:`.Session` is not present. If the :class:`.Session` is present - and keyword arguments have been passed, - :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. - - """ - if kw: - if self.registry.has(): - raise sa_exc.InvalidRequestError( - "Scoped session is already present; " - "no new arguments may be specified." - ) - else: - sess = self.session_factory(**kw) - self.registry.set(sess) - return sess - else: - return self.registry() - def remove(self): """Dispose of the current :class:`.Session`, if present. @@ -145,23 +164,6 @@ class scoped_session(object): self.registry().close() self.registry.clear() - def configure(self, **kwargs): - """reconfigure the :class:`.sessionmaker` used by this - :class:`.scoped_session`. - - See :meth:`.sessionmaker.configure`. - - """ - - if self.registry.has(): - warn( - "At least one scoped session is already present. " - " configure() can not affect sessions that have " - "already been created." - ) - - self.session_factory.configure(**kwargs) - def query_property(self, query_cls=None): """return a class property which produces a :class:`_query.Query` object diff --git a/test/ext/asyncio/test_scoping_py3k.py b/test/ext/asyncio/test_scoping_py3k.py new file mode 100644 index 0000000000..223c7d9031 --- /dev/null +++ b/test/ext/asyncio/test_scoping_py3k.py @@ -0,0 +1,44 @@ +from asyncio import current_task + +import sqlalchemy as sa +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession +from sqlalchemy.testing import async_test +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import is_ +from .test_session_py3k import AsyncFixture + + +class AsyncScopedSessionTest(AsyncFixture): + @async_test + async def test_basic(self, async_engine): + AsyncSession = async_scoped_session( + sa.orm.sessionmaker(async_engine, class_=_AsyncSession), + scopefunc=current_task, + ) + + some_async_session = AsyncSession() + some_other_async_session = AsyncSession() + + is_(some_async_session, some_other_async_session) + is_(some_async_session.bind, async_engine) + + User = self.classes.User + + async with AsyncSession.begin(): + user_name = "scoped_async_session_u1" + u1 = User(name=user_name) + + AsyncSession.add(u1) + + await AsyncSession.flush() + + conn = await AsyncSession.connection() + stmt = select(func.count(User.id)).where(User.name == user_name) + eq_(await conn.scalar(stmt), 1) + + await AsyncSession.delete(u1) + await AsyncSession.flush() + eq_(await conn.scalar(stmt), 0) -- 2.47.2