]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix AsyncSession.close_all()
authorBryan不可思议 <programripper@foxmail.com>
Wed, 18 Oct 2023 20:31:45 +0000 (16:31 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 18 Oct 2023 21:04:57 +0000 (23:04 +0200)
Fixed bug with method :meth:`_asyncio.AsyncSession.close_all`
that was not working correctly.
Also added function :func:`_asyncio.close_all_sessions` that's
the equivalent of :func:`_orm.close_all_sessions`.

Fixes: #10421
Closes: #10429
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10429
Pull-request-sha: a35c29d41213bd79bfa73ac559a29035ca4e9eb9

Change-Id: If2c3f0130a71b239382c2ea11a3436788ee242be

doc/build/changelog/changelog_20.rst
doc/build/changelog/unreleased_20/10421.rst [new file with mode: 0644]
doc/build/orm/extensions/asyncio.rst
lib/sqlalchemy/ext/asyncio/__init__.py
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/util/langhelpers.py
test/ext/asyncio/test_session_py3k.py

index 2c518321acc0144936b5359f008798bb5c257420..39939ee0d8464d338c6abd8ee0c009a9f1567e17 100644 (file)
         This is a no-argument callable that provides a new asyncio connection,
         using the asyncio database driver directly. The
         :func:`.create_async_engine` function will wrap the driver-level connection
-        in the appropriate structures. Pull request curtesy of Jack Wotherspoon.
+        in the appropriate structures. Pull request courtesy of Jack Wotherspoon.
 
     .. change::
         :tags: bug, orm, regression
         * ``@?`` using :meth:`_postgresql.JSONB.Comparator.path_exists`
         * ``#-`` using :meth:`_postgresql.JSONB.Comparator.delete_path`
 
-        Pull request curtesy of Guilherme Martins Crocetti.
+        Pull request courtesy of Guilherme Martins Crocetti.
 
 .. changelog::
     :version: 2.0.0rc1
diff --git a/doc/build/changelog/unreleased_20/10421.rst b/doc/build/changelog/unreleased_20/10421.rst
new file mode 100644 (file)
index 0000000..c550647
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 10421
+
+    Fixed bug with method :meth:`_asyncio.AsyncSession.close_all`
+    that was not working correctly.
+    Also added function :func:`_asyncio.close_all_sessions` that's
+    the equivalent of :func:`_orm.close_all_sessions`.
+    Pull request courtesy of Bryan不可思议.
index e4782f2bbb24db88247c70e43c626e11c5bdc89d..0815da29affc3d18028cb716222432f103fcb3ca 100644 (file)
@@ -1097,6 +1097,8 @@ ORM Session API Documentation
 
 .. autofunction:: async_session
 
+.. autofunction:: close_all_sessions
+
 .. autoclass:: async_sessionmaker
    :members:
    :inherited-members:
index ad6cd15268fba69103df889990b8d8e16c26bdf7..8564db6f22ee3f981053546b0b47d10d3c4c6e1c 100644 (file)
@@ -22,3 +22,4 @@ from .session import async_sessionmaker as async_sessionmaker
 from .session import AsyncAttrs as AsyncAttrs
 from .session import AsyncSession as AsyncSession
 from .session import AsyncSessionTransaction as AsyncSessionTransaction
+from .session import close_all_sessions as close_all_sessions
index 3e6d4fc71c9f43bd7e27f4bf2bc7030d2c4a47e6..7ac11ffacee48756209bbc4aad0365c012f2400e 100644 (file)
@@ -1544,7 +1544,7 @@ class async_scoped_session(Generic[_AS]):
         return self._proxied.info
 
     @classmethod
-    async def close_all(self) -> None:
+    async def close_all(cls) -> None:
         r"""Close all :class:`_asyncio.AsyncSession` sessions.
 
         .. container:: class_bases
@@ -1552,6 +1552,8 @@ class async_scoped_session(Generic[_AS]):
             Proxied for the :class:`_asyncio.AsyncSession` class on
             behalf of the :class:`_asyncio.scoping.async_scoped_session` class.
 
+        .. deprecated:: 2.0 The :meth:`.AsyncSession.close_all` method is deprecated and will be removed in a future release.  Please refer to :func:`_asyncio.close_all_sessions`.
+
         """  # noqa: E501
 
         return await AsyncSession.close_all()
index dcbf26c6c49364a000cd37b5020c1b9c610b6fb3..b4b5ab3fbecb459c615777713678d1d68f438815 100644 (file)
@@ -32,6 +32,7 @@ from .result import _ensure_sync_result
 from .result import AsyncResult
 from .result import AsyncScalarResult
 from ... import util
+from ...orm import close_all_sessions as _sync_close_all_sessions
 from ...orm import object_session
 from ...orm import Session
 from ...orm import SessionTransaction
@@ -1057,9 +1058,15 @@ class AsyncSession(ReversibleProxy[Session]):
         await greenlet_spawn(self.sync_session.invalidate)
 
     @classmethod
-    async def close_all(self) -> None:
+    @util.deprecated(
+        "2.0",
+        "The :meth:`.AsyncSession.close_all` method is deprecated and will be "
+        "removed in a future release.  Please refer to "
+        ":func:`_asyncio.close_all_sessions`.",
+    )
+    async def close_all(cls) -> None:
         """Close all :class:`_asyncio.AsyncSession` sessions."""
-        await greenlet_spawn(self.sync_session.close_all)
+        await close_all_sessions()
 
     async def __aenter__(self: _AS) -> _AS:
         return self
@@ -1911,4 +1918,17 @@ def async_session(session: Session) -> Optional[AsyncSession]:
     return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
 
 
+async def close_all_sessions() -> None:
+    """Close all :class:`_asyncio.AsyncSession` sessions.
+
+    .. versionadded:: 2.0.23
+
+    .. seealso::
+
+        :func:`.session.close_all_sessions`
+
+    """
+    await greenlet_spawn(_sync_close_all_sessions)
+
+
 _instance_state._async_provider = async_session  # type: ignore
index c4527e123f610d94d0dab8d441d09f510988c017..9c56487c400909e58eb8132d8a387203ea420b6f 100644 (file)
@@ -265,6 +265,13 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
         metadata.update(format_argspec_plus(spec, grouped=False))
         metadata["name"] = fn.__name__
 
+        if inspect.iscoroutinefunction(fn):
+            metadata["prefix"] = "async "
+            metadata["target_prefix"] = "await "
+        else:
+            metadata["prefix"] = ""
+            metadata["target_prefix"] = ""
+
         # look for __ positional arguments.  This is a convention in
         # SQLAlchemy that arguments should be passed positionally
         # rather than as keyword
@@ -276,16 +283,16 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
         if "__" in repr(spec[0]):
             code = (
                 """\
-def %(name)s%(grouped_args)s:
-    return %(target)s(%(fn)s, %(apply_pos)s)
+%(prefix)sdef %(name)s%(grouped_args)s:
+    return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s)
 """
                 % metadata
             )
         else:
             code = (
                 """\
-def %(name)s%(grouped_args)s:
-    return %(target)s(%(fn)s, %(apply_kw)s)
+%(prefix)sdef %(name)s%(grouped_args)s:
+    return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s)
 """
                 % metadata
             )
index 19ce55a2d7c250efbff1971595f9f525bd599372..e38a0cc52a90f7741d67c17999dc5781f7529226 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import async_object_session
 from sqlalchemy.ext.asyncio import async_sessionmaker
 from sqlalchemy.ext.asyncio import AsyncAttrs
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import close_all_sessions
 from sqlalchemy.ext.asyncio import exc as async_exc
 from sqlalchemy.ext.asyncio.base import ReversibleProxy
 from sqlalchemy.orm import DeclarativeBase
@@ -41,7 +42,9 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import expect_deprecated
+from sqlalchemy.testing.assertions import in_
 from sqlalchemy.testing.assertions import is_false
+from sqlalchemy.testing.assertions import not_in
 from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.provision import normalize_sequence
 from .test_engine_py3k import AsyncFixture as _AsyncFixture
@@ -122,6 +125,50 @@ class AsyncSessionTest(AsyncFixture):
                     sync_connection.dialect.default_sequence_base,
                 )
 
+    @async_test
+    async def test_close_all(self, async_engine):
+        User = self.classes.User
+
+        s1 = AsyncSession(async_engine)
+        u1 = User()
+        s1.add(u1)
+
+        s2 = AsyncSession(async_engine)
+        u2 = User()
+        s2.add(u2)
+
+        in_(u1, s1)
+        in_(u2, s2)
+
+        await close_all_sessions()
+
+        not_in(u1, s1)
+        not_in(u2, s2)
+
+    @async_test
+    async def test_session_close_all_deprecated(self, async_engine):
+        User = self.classes.User
+
+        s1 = AsyncSession(async_engine)
+        u1 = User()
+        s1.add(u1)
+
+        s2 = AsyncSession(async_engine)
+        u2 = User()
+        s2.add(u2)
+
+        in_(u1, s1)
+        in_(u2, s2)
+
+        with expect_deprecated(
+            r"The AsyncSession.close_all\(\) method is deprecated and will "
+            "be removed in a future release. "
+        ):
+            await AsyncSession.close_all()
+
+        not_in(u1, s1)
+        not_in(u2, s2)
+
 
 class AsyncSessionQueryTest(AsyncFixture):
     @async_test