]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement async_scoped_session
authorjason3gb <jason3gb@gmail.com>
Wed, 16 Jun 2021 14:18:08 +0000 (10:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Jun 2021 15:19:50 +0000 (11:19 -0400)
Implemented :class:`_asyncio.async_scoped_session` to address some
asyncio-related incompatibilities between :class:`_orm.scoped_session` and
:class:`_asyncio.AsyncSession`, in which some methods (notably the
:meth:`_asyncio.async_scoped_session.remove` method) should be used with
the ``await`` keyword.

Fixes: #6583
Closes: #6603
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/6603
Pull-request-sha: 0e8ef87dc824dcd83dca01641441afc453c8e07a

Change-Id: I9bfe56f8670302ff0015d9dc56c1e3ac5b92b118

doc/build/changelog/unreleased_14/6583.rst [new file with mode: 0644]
doc/build/orm/contextual.rst
doc/build/orm/extensions/asyncio.rst
lib/sqlalchemy/ext/asyncio/__init__.py
lib/sqlalchemy/ext/asyncio/scoping.py [new file with mode: 0644]
lib/sqlalchemy/orm/scoping.py
test/ext/asyncio/test_scoping_py3k.py [new file with mode: 0644]

diff --git a/doc/build/changelog/unreleased_14/6583.rst b/doc/build/changelog/unreleased_14/6583.rst
new file mode 100644 (file)
index 0000000..2d235c6
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: usecase, asyncio
+    :tickets: 6583
+
+    Implemented :class:`_asyncio.async_scoped_session` to address some
+    asyncio-related incompatibilities between :class:`_orm.scoped_session` and
+    :class:`_asyncio.AsyncSession`, in which some methods (notably the
+    :meth:`_asyncio.async_scoped_session.remove` method) should be used with
+    the ``await`` keyword.
+
+    .. seealso::
+
+        :ref:`asyncio_scoped_session`
\ No newline at end of file
index fd55846220a32a130d7914327730aeb757e501ad..aebcfc3ff307b507d5f7dff1a287decb1206fe13 100644 (file)
@@ -252,7 +252,8 @@ Contextual Session API
 ----------------------
 
 .. autoclass:: sqlalchemy.orm.scoping.scoped_session
-   :members:
+    :members:
+    :inherited-members:
 
 .. autoclass:: sqlalchemy.util.ScopedRegistry
     :members:
index 92471d459ea374985b324c5c3914b159448fcfc4..6aca1762df9af5ce30b9cccfe8ff9521c7ab86dc 100644 (file)
@@ -415,8 +415,48 @@ from using any connection more than once::
     )
 
 
+.. _asyncio_scoped_session:
+
+Using asyncio scoped session
+----------------------------
+
+The usage of :class:`_asyncio.async_scoped_session` is mostly similar to
+:class:`.scoped_session`. However, since there's no "thread-local" concept in
+the asyncio context, the "scopefunc" paramater must be provided to the
+constructor::
+
+    from asyncio import current_task
+
+    from sqlalchemy.orm import sessionmaker
+    from sqlalchemy.ext.asyncio import async_scoped_session
+    from sqlalchemy.ext.asyncio import AsyncSession
+
+    async_session_factory = sessionmaker(some_async_engine, class_=_AsyncSession)
+    AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)
+
+    some_async_session = AsyncSession()
+
+:class:`_asyncio.async_scoped_session` also includes **proxy
+behavior** similar to that of :class:`.scoped_session`, which means it can be
+treated as a :class:`_asyncio.AsyncSession` directly, keeping in mind that
+the usual ``await`` keywords are necessary, including for the
+:meth:`_asyncio.async_scoped_session.remove` method::
+
+    async def some_function(some_async_session, some_object):
+       # use the AsyncSession directly
+       some_async_session.add(some_object)
+
+       # use the AsyncSession via the context-local proxy
+       await AsyncSession.commit()
+
+       # "remove" the current proxied AsyncSession for the local context
+       await AsyncSession.remove()
+
+.. versionadded:: 1.4.19
+
 .. currentmodule:: sqlalchemy.ext.asyncio
 
+
 Engine API Documentation
 -------------------------
 
@@ -456,6 +496,10 @@ ORM Session API Documentation
 
 .. autofunction:: async_session
 
+.. autoclass:: async_scoped_session
+   :members:
+   :inherited-members:
+
 .. autoclass:: AsyncSession
    :members:
 
index 349bc1b753816218ec6a011fe90bc69a1fe55a47..19e6079dc583e5c14724a40a140e229ece7afc41 100644 (file)
@@ -14,6 +14,7 @@ from .events import AsyncSessionEvents
 from .result import AsyncMappingResult
 from .result import AsyncResult
 from .result import AsyncScalarResult
+from .scoping import async_scoped_session
 from .session import async_object_session
 from .session import async_session
 from .session import AsyncSession
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
new file mode 100644 (file)
index 0000000..4d1cea9
--- /dev/null
@@ -0,0 +1,101 @@
+# ext/asyncio/scoping.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from .session import AsyncSession
+from ...orm.scoping import ScopedSessionMixin
+from ...util import create_proxy_methods
+from ...util import ScopedRegistry
+
+
+@create_proxy_methods(
+    AsyncSession,
+    ":class:`_asyncio.AsyncSession`",
+    ":class:`_asyncio.scoping.async_scoped_session`",
+    classmethods=["close_all", "object_session", "identity_key"],
+    methods=[
+        "__contains__",
+        "__iter__",
+        "add",
+        "add_all",
+        "begin",
+        "begin_nested",
+        "close",
+        "commit",
+        "connection",
+        "delete",
+        "execute",
+        "expire",
+        "expire_all",
+        "expunge",
+        "expunge_all",
+        "flush",
+        "get",
+        "get_bind",
+        "is_modified",
+        "merge",
+        "refresh",
+        "rollback",
+        "scalar",
+    ],
+    attributes=[
+        "bind",
+        "dirty",
+        "deleted",
+        "new",
+        "identity_map",
+        "is_active",
+        "autoflush",
+        "no_autoflush",
+        "info",
+    ],
+)
+class async_scoped_session(ScopedSessionMixin):
+    """Provides scoped management of :class:`.AsyncSession` objects.
+
+    See the section :ref:`asyncio_scoped_session` for usage details.
+
+    .. versionadded:: 1.4.19
+
+
+    """
+
+    def __init__(self, session_factory, scopefunc):
+        """Construct a new :class:`_asyncio.async_scoped_session`.
+
+        :param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
+         instances. This is usually, but not necessarily, an instance
+         of :class:`_orm.sessionmaker` which itself was passed the
+         :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
+         parameter::
+
+            async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession)
+            AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)
+
+        :param scopefunc: function which defines
+         the current scope.   A function such as ``asyncio.current_task``
+         may be useful here.
+
+        """  # noqa E501
+
+        self.session_factory = session_factory
+        self.registry = ScopedRegistry(session_factory, scopefunc)
+
+    @property
+    def _proxied(self):
+        return self.registry()
+
+    async def remove(self):
+        """Dispose of the current :class:`.AsyncSession`, if present.
+
+        Different from scoped_session's remove method, this method would use
+        await to wait for the close method of AsyncSession.
+
+        """
+
+        if self.registry.has():
+            await self.registry().close()
+        self.registry.clear()
index 0ba7b12ff24691d4e1cc99f3594887ace6466df5..6bd052dde0a4c5f9375c05746fa86730ff2f4899 100644 (file)
@@ -14,7 +14,54 @@ from ..util import ScopedRegistry
 from ..util import ThreadLocalRegistry
 from ..util import warn
 
-__all__ = ["scoped_session"]
+__all__ = ["scoped_session", "ScopedSessionMixin"]
+
+
+class ScopedSessionMixin(object):
+    @property
+    def _proxied(self):
+        return self.registry()
+
+    def __call__(self, **kw):
+        r"""Return the current :class:`.Session`, creating it
+        using the :attr:`.scoped_session.session_factory` if not present.
+
+        :param \**kw: Keyword arguments will be passed to the
+         :attr:`.scoped_session.session_factory` callable, if an existing
+         :class:`.Session` is not present.  If the :class:`.Session` is present
+         and keyword arguments have been passed,
+         :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+        """
+        if kw:
+            if self.registry.has():
+                raise sa_exc.InvalidRequestError(
+                    "Scoped session is already present; "
+                    "no new arguments may be specified."
+                )
+            else:
+                sess = self.session_factory(**kw)
+                self.registry.set(sess)
+                return sess
+        else:
+            return self.registry()
+
+    def configure(self, **kwargs):
+        """reconfigure the :class:`.sessionmaker` used by this
+        :class:`.scoped_session`.
+
+        See :meth:`.sessionmaker.configure`.
+
+        """
+
+        if self.registry.has():
+            warn(
+                "At least one scoped session is already present. "
+                " configure() can not affect sessions that have "
+                "already been created."
+            )
+
+        self.session_factory.configure(**kwargs)
 
 
 @create_proxy_methods(
@@ -64,7 +111,7 @@ __all__ = ["scoped_session"]
         "autocommit",
     ],
 )
-class scoped_session(object):
+class scoped_session(ScopedSessionMixin):
     """Provides scoped management of :class:`.Session` objects.
 
     See :ref:`unitofwork_contextual` for a tutorial.
@@ -100,34 +147,6 @@ class scoped_session(object):
         else:
             self.registry = ThreadLocalRegistry(session_factory)
 
-    @property
-    def _proxied(self):
-        return self.registry()
-
-    def __call__(self, **kw):
-        r"""Return the current :class:`.Session`, creating it
-        using the :attr:`.scoped_session.session_factory` if not present.
-
-        :param \**kw: Keyword arguments will be passed to the
-         :attr:`.scoped_session.session_factory` callable, if an existing
-         :class:`.Session` is not present.  If the :class:`.Session` is present
-         and keyword arguments have been passed,
-         :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
-
-        """
-        if kw:
-            if self.registry.has():
-                raise sa_exc.InvalidRequestError(
-                    "Scoped session is already present; "
-                    "no new arguments may be specified."
-                )
-            else:
-                sess = self.session_factory(**kw)
-                self.registry.set(sess)
-                return sess
-        else:
-            return self.registry()
-
     def remove(self):
         """Dispose of the current :class:`.Session`, if present.
 
@@ -145,23 +164,6 @@ class scoped_session(object):
             self.registry().close()
         self.registry.clear()
 
-    def configure(self, **kwargs):
-        """reconfigure the :class:`.sessionmaker` used by this
-        :class:`.scoped_session`.
-
-        See :meth:`.sessionmaker.configure`.
-
-        """
-
-        if self.registry.has():
-            warn(
-                "At least one scoped session is already present. "
-                " configure() can not affect sessions that have "
-                "already been created."
-            )
-
-        self.session_factory.configure(**kwargs)
-
     def query_property(self, query_cls=None):
         """return a class property which produces a :class:`_query.Query`
         object
diff --git a/test/ext/asyncio/test_scoping_py3k.py b/test/ext/asyncio/test_scoping_py3k.py
new file mode 100644 (file)
index 0000000..223c7d9
--- /dev/null
@@ -0,0 +1,44 @@
+from asyncio import current_task
+
+import sqlalchemy as sa
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import is_
+from .test_session_py3k import AsyncFixture
+
+
+class AsyncScopedSessionTest(AsyncFixture):
+    @async_test
+    async def test_basic(self, async_engine):
+        AsyncSession = async_scoped_session(
+            sa.orm.sessionmaker(async_engine, class_=_AsyncSession),
+            scopefunc=current_task,
+        )
+
+        some_async_session = AsyncSession()
+        some_other_async_session = AsyncSession()
+
+        is_(some_async_session, some_other_async_session)
+        is_(some_async_session.bind, async_engine)
+
+        User = self.classes.User
+
+        async with AsyncSession.begin():
+            user_name = "scoped_async_session_u1"
+            u1 = User(name=user_name)
+
+            AsyncSession.add(u1)
+
+            await AsyncSession.flush()
+
+            conn = await AsyncSession.connection()
+            stmt = select(func.count(User.id)).where(User.name == user_name)
+            eq_(await conn.scalar(stmt), 1)
+
+            await AsyncSession.delete(u1)
+            await AsyncSession.flush()
+            eq_(await conn.scalar(stmt), 0)