]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: #6583
authorjason3gb <jason3gb@gmail.com>
Tue, 8 Jun 2021 16:40:03 +0000 (00:40 +0800)
committerjason3gb <jason3gb@gmail.com>
Tue, 15 Jun 2021 15:10:39 +0000 (23:10 +0800)
lib/sqlalchemy/ext/asyncio/__init__.py
lib/sqlalchemy/ext/asyncio/scoping.py [new file with mode: 0644]
lib/sqlalchemy/orm/scoping.py
test/ext/asyncio/test_scoping_py3k.py [new file with mode: 0644]

index 349bc1b753816218ec6a011fe90bc69a1fe55a47..19e6079dc583e5c14724a40a140e229ece7afc41 100644 (file)
@@ -14,6 +14,7 @@ from .events import AsyncSessionEvents
 from .result import AsyncMappingResult
 from .result import AsyncResult
 from .result import AsyncScalarResult
+from .scoping import async_scoped_session
 from .session import async_object_session
 from .session import async_session
 from .session import AsyncSession
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
new file mode 100644 (file)
index 0000000..efbc323
--- /dev/null
@@ -0,0 +1,89 @@
+# ext/asyncio/scoping.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from .session import AsyncSession
+from ...orm.scoping import ScopedSessionMixin
+from ...util import create_proxy_methods
+from ...util import ScopedRegistry
+
+
+@create_proxy_methods(
+    AsyncSession,
+    ":class:`_asyncio.AsyncSession`",
+    ":class:`_asyncio.scoping.async_scoped_session`",
+    classmethods=["close_all", "object_session", "identity_key"],
+    methods=[
+        "__contains__",
+        "__iter__",
+        "add",
+        "add_all",
+        "begin",
+        "begin_nested",
+        "close",
+        "commit",
+        "connection",
+        "delete",
+        "execute",
+        "expire",
+        "expire_all",
+        "expunge",
+        "expunge_all",
+        "flush",
+        "get",
+        "get_bind",
+        "is_modified",
+        "merge",
+        "refresh",
+        "rollback",
+        "scalar",
+    ],
+    attributes=[
+        "bind",
+        "dirty",
+        "deleted",
+        "new",
+        "identity_map",
+        "is_active",
+        "autoflush",
+        "no_autoflush",
+        "info",
+    ],
+)
+class async_scoped_session(ScopedSessionMixin):
+    """Provides scoped management of :class:`.AsyncSession` objects."""
+
+    def __init__(self, session_factory, scopefunc):
+        """Construct a new :class:`.scoped_session`.
+
+        :param session_factory: a factory to create new :class:`.Session`
+         instances. This is usually, but not necessarily, an instance
+         of :class:`.sessionmaker`.
+        :param scopefunc: function which defines
+         the current scope. Different from scoped_session's __init__ signature,
+         which has a default 'thread-local' scope, the scopefunc must be passed
+         in and it should return a hashable token with the same requirement as
+         in scoped_session.
+
+        """
+        self.session_factory = session_factory
+        self.registry = ScopedRegistry(session_factory, scopefunc)
+
+    @property
+    def _proxied(self):
+        return self.registry()
+
+    async def remove(self):
+        """Dispose of the current :class:`.AsyncSession`, if present.
+
+        Different from scoped_session's remove method, this method would use
+        await to wait for the close method of AsyncSession.
+
+        """
+
+        if self.registry.has():
+            await self.registry().close()
+        self.registry.clear()
index 0ba7b12ff24691d4e1cc99f3594887ace6466df5..6bd052dde0a4c5f9375c05746fa86730ff2f4899 100644 (file)
@@ -14,7 +14,54 @@ from ..util import ScopedRegistry
 from ..util import ThreadLocalRegistry
 from ..util import warn
 
-__all__ = ["scoped_session"]
+__all__ = ["scoped_session", "ScopedSessionMixin"]
+
+
+class ScopedSessionMixin(object):
+    @property
+    def _proxied(self):
+        return self.registry()
+
+    def __call__(self, **kw):
+        r"""Return the current :class:`.Session`, creating it
+        using the :attr:`.scoped_session.session_factory` if not present.
+
+        :param \**kw: Keyword arguments will be passed to the
+         :attr:`.scoped_session.session_factory` callable, if an existing
+         :class:`.Session` is not present.  If the :class:`.Session` is present
+         and keyword arguments have been passed,
+         :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+        """
+        if kw:
+            if self.registry.has():
+                raise sa_exc.InvalidRequestError(
+                    "Scoped session is already present; "
+                    "no new arguments may be specified."
+                )
+            else:
+                sess = self.session_factory(**kw)
+                self.registry.set(sess)
+                return sess
+        else:
+            return self.registry()
+
+    def configure(self, **kwargs):
+        """reconfigure the :class:`.sessionmaker` used by this
+        :class:`.scoped_session`.
+
+        See :meth:`.sessionmaker.configure`.
+
+        """
+
+        if self.registry.has():
+            warn(
+                "At least one scoped session is already present. "
+                " configure() can not affect sessions that have "
+                "already been created."
+            )
+
+        self.session_factory.configure(**kwargs)
 
 
 @create_proxy_methods(
@@ -64,7 +111,7 @@ __all__ = ["scoped_session"]
         "autocommit",
     ],
 )
-class scoped_session(object):
+class scoped_session(ScopedSessionMixin):
     """Provides scoped management of :class:`.Session` objects.
 
     See :ref:`unitofwork_contextual` for a tutorial.
@@ -100,34 +147,6 @@ class scoped_session(object):
         else:
             self.registry = ThreadLocalRegistry(session_factory)
 
-    @property
-    def _proxied(self):
-        return self.registry()
-
-    def __call__(self, **kw):
-        r"""Return the current :class:`.Session`, creating it
-        using the :attr:`.scoped_session.session_factory` if not present.
-
-        :param \**kw: Keyword arguments will be passed to the
-         :attr:`.scoped_session.session_factory` callable, if an existing
-         :class:`.Session` is not present.  If the :class:`.Session` is present
-         and keyword arguments have been passed,
-         :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
-
-        """
-        if kw:
-            if self.registry.has():
-                raise sa_exc.InvalidRequestError(
-                    "Scoped session is already present; "
-                    "no new arguments may be specified."
-                )
-            else:
-                sess = self.session_factory(**kw)
-                self.registry.set(sess)
-                return sess
-        else:
-            return self.registry()
-
     def remove(self):
         """Dispose of the current :class:`.Session`, if present.
 
@@ -145,23 +164,6 @@ class scoped_session(object):
             self.registry().close()
         self.registry.clear()
 
-    def configure(self, **kwargs):
-        """reconfigure the :class:`.sessionmaker` used by this
-        :class:`.scoped_session`.
-
-        See :meth:`.sessionmaker.configure`.
-
-        """
-
-        if self.registry.has():
-            warn(
-                "At least one scoped session is already present. "
-                " configure() can not affect sessions that have "
-                "already been created."
-            )
-
-        self.session_factory.configure(**kwargs)
-
     def query_property(self, query_cls=None):
         """return a class property which produces a :class:`_query.Query`
         object
diff --git a/test/ext/asyncio/test_scoping_py3k.py b/test/ext/asyncio/test_scoping_py3k.py
new file mode 100644 (file)
index 0000000..223c7d9
--- /dev/null
@@ -0,0 +1,44 @@
+from asyncio import current_task
+
+import sqlalchemy as sa
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import is_
+from .test_session_py3k import AsyncFixture
+
+
+class AsyncScopedSessionTest(AsyncFixture):
+    @async_test
+    async def test_basic(self, async_engine):
+        AsyncSession = async_scoped_session(
+            sa.orm.sessionmaker(async_engine, class_=_AsyncSession),
+            scopefunc=current_task,
+        )
+
+        some_async_session = AsyncSession()
+        some_other_async_session = AsyncSession()
+
+        is_(some_async_session, some_other_async_session)
+        is_(some_async_session.bind, async_engine)
+
+        User = self.classes.User
+
+        async with AsyncSession.begin():
+            user_name = "scoped_async_session_u1"
+            u1 = User(name=user_name)
+
+            AsyncSession.add(u1)
+
+            await AsyncSession.flush()
+
+            conn = await AsyncSession.connection()
+            stmt = select(func.count(User.id)).where(User.name == user_name)
+            eq_(await conn.scalar(stmt), 1)
+
+            await AsyncSession.delete(u1)
+            await AsyncSession.flush()
+            eq_(await conn.scalar(stmt), 0)