From 8bb95627717e880c8ef02ac6d21061a9636c4e09 Mon Sep 17 00:00:00 2001 From: jason3gb Date: Wed, 9 Jun 2021 00:40:03 +0800 Subject: [PATCH] Fixes: #6583 --- lib/sqlalchemy/ext/asyncio/__init__.py | 1 + lib/sqlalchemy/ext/asyncio/scoping.py | 89 ++++++++++++++++++++++++ lib/sqlalchemy/orm/scoping.py | 96 +++++++++++++------------- test/ext/asyncio/test_scoping_py3k.py | 44 ++++++++++++ 4 files changed, 183 insertions(+), 47 deletions(-) create mode 100644 lib/sqlalchemy/ext/asyncio/scoping.py create mode 100644 test/ext/asyncio/test_scoping_py3k.py diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index 349bc1b753..19e6079dc5 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -14,6 +14,7 @@ from .events import AsyncSessionEvents from .result import AsyncMappingResult from .result import AsyncResult from .result import AsyncScalarResult +from .scoping import async_scoped_session from .session import async_object_session from .session import async_session from .session import AsyncSession diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py new file mode 100644 index 0000000000..efbc323fa9 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -0,0 +1,89 @@ +# ext/asyncio/scoping.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +from .session import AsyncSession +from ...orm.scoping import ScopedSessionMixin +from ...util import create_proxy_methods +from ...util import ScopedRegistry + + +@create_proxy_methods( + AsyncSession, + ":class:`_asyncio.AsyncSession`", + ":class:`_asyncio.scoping.async_scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get", + "get_bind", + "is_modified", + "merge", + "refresh", + "rollback", + "scalar", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class async_scoped_session(ScopedSessionMixin): + """Provides scoped management of :class:`.AsyncSession` objects.""" + + def __init__(self, session_factory, scopefunc): + """Construct a new :class:`.scoped_session`. + + :param session_factory: a factory to create new :class:`.Session` + instances. This is usually, but not necessarily, an instance + of :class:`.sessionmaker`. + :param scopefunc: function which defines + the current scope. Different from scoped_session's __init__ signature, + which has a default 'thread-local' scope, the scopefunc must be passed + in and it should return a hashable token with the same requirement as + in scoped_session. + + """ + self.session_factory = session_factory + self.registry = ScopedRegistry(session_factory, scopefunc) + + @property + def _proxied(self): + return self.registry() + + async def remove(self): + """Dispose of the current :class:`.AsyncSession`, if present. + + Different from scoped_session's remove method, this method would use + await to wait for the close method of AsyncSession. + + """ + + if self.registry.has(): + await self.registry().close() + self.registry.clear() diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 0ba7b12ff2..6bd052dde0 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -14,7 +14,54 @@ from ..util import ScopedRegistry from ..util import ThreadLocalRegistry from ..util import warn -__all__ = ["scoped_session"] +__all__ = ["scoped_session", "ScopedSessionMixin"] + + +class ScopedSessionMixin(object): + @property + def _proxied(self): + return self.registry() + + def __call__(self, **kw): + r"""Return the current :class:`.Session`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.Session` is not present. If the :class:`.Session` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + return sess + else: + return self.registry() + + def configure(self, **kwargs): + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) @create_proxy_methods( @@ -64,7 +111,7 @@ __all__ = ["scoped_session"] "autocommit", ], ) -class scoped_session(object): +class scoped_session(ScopedSessionMixin): """Provides scoped management of :class:`.Session` objects. See :ref:`unitofwork_contextual` for a tutorial. @@ -100,34 +147,6 @@ class scoped_session(object): else: self.registry = ThreadLocalRegistry(session_factory) - @property - def _proxied(self): - return self.registry() - - def __call__(self, **kw): - r"""Return the current :class:`.Session`, creating it - using the :attr:`.scoped_session.session_factory` if not present. - - :param \**kw: Keyword arguments will be passed to the - :attr:`.scoped_session.session_factory` callable, if an existing - :class:`.Session` is not present. If the :class:`.Session` is present - and keyword arguments have been passed, - :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. - - """ - if kw: - if self.registry.has(): - raise sa_exc.InvalidRequestError( - "Scoped session is already present; " - "no new arguments may be specified." - ) - else: - sess = self.session_factory(**kw) - self.registry.set(sess) - return sess - else: - return self.registry() - def remove(self): """Dispose of the current :class:`.Session`, if present. @@ -145,23 +164,6 @@ class scoped_session(object): self.registry().close() self.registry.clear() - def configure(self, **kwargs): - """reconfigure the :class:`.sessionmaker` used by this - :class:`.scoped_session`. - - See :meth:`.sessionmaker.configure`. - - """ - - if self.registry.has(): - warn( - "At least one scoped session is already present. " - " configure() can not affect sessions that have " - "already been created." - ) - - self.session_factory.configure(**kwargs) - def query_property(self, query_cls=None): """return a class property which produces a :class:`_query.Query` object diff --git a/test/ext/asyncio/test_scoping_py3k.py b/test/ext/asyncio/test_scoping_py3k.py new file mode 100644 index 0000000000..223c7d9031 --- /dev/null +++ b/test/ext/asyncio/test_scoping_py3k.py @@ -0,0 +1,44 @@ +from asyncio import current_task + +import sqlalchemy as sa +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession +from sqlalchemy.testing import async_test +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import is_ +from .test_session_py3k import AsyncFixture + + +class AsyncScopedSessionTest(AsyncFixture): + @async_test + async def test_basic(self, async_engine): + AsyncSession = async_scoped_session( + sa.orm.sessionmaker(async_engine, class_=_AsyncSession), + scopefunc=current_task, + ) + + some_async_session = AsyncSession() + some_other_async_session = AsyncSession() + + is_(some_async_session, some_other_async_session) + is_(some_async_session.bind, async_engine) + + User = self.classes.User + + async with AsyncSession.begin(): + user_name = "scoped_async_session_u1" + u1 = User(name=user_name) + + AsyncSession.add(u1) + + await AsyncSession.flush() + + conn = await AsyncSession.connection() + stmt = select(func.count(User.id)).where(User.name == user_name) + eq_(await conn.scalar(stmt), 1) + + await AsyncSession.delete(u1) + await AsyncSession.flush() + eq_(await conn.scalar(stmt), 0) -- 2.47.3