]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
feat: `close_all_sessions`
authorBryan不可思议 <programripper@foxmail.com>
Wed, 18 Oct 2023 20:15:30 +0000 (20:15 +0000)
committerGitHub <noreply@github.com>
Wed, 18 Oct 2023 20:15:30 +0000 (20:15 +0000)
lib/sqlalchemy/ext/asyncio/__init__.py
lib/sqlalchemy/ext/asyncio/session.py
test/ext/asyncio/test_session_py3k.py

index ad6cd15268fba69103df889990b8d8e16c26bdf7..8564db6f22ee3f981053546b0b47d10d3c4c6e1c 100644 (file)
@@ -22,3 +22,4 @@ from .session import async_sessionmaker as async_sessionmaker
 from .session import AsyncAttrs as AsyncAttrs
 from .session import AsyncSession as AsyncSession
 from .session import AsyncSessionTransaction as AsyncSessionTransaction
+from .session import close_all_sessions as close_all_sessions
index 75dd43281d8d780321073512234bc11ddca3b9e7..b4b5ab3fbecb459c615777713678d1d68f438815 100644 (file)
@@ -32,7 +32,7 @@ from .result import _ensure_sync_result
 from .result import AsyncResult
 from .result import AsyncScalarResult
 from ... import util
-from ...orm import close_all_sessions
+from ...orm import close_all_sessions as _sync_close_all_sessions
 from ...orm import object_session
 from ...orm import Session
 from ...orm import SessionTransaction
@@ -1062,11 +1062,11 @@ class AsyncSession(ReversibleProxy[Session]):
         "2.0",
         "The :meth:`.AsyncSession.close_all` method is deprecated and will be "
         "removed in a future release.  Please refer to "
-        ":func:`.session.close_all_sessions`.",
+        ":func:`_asyncio.close_all_sessions`.",
     )
     async def close_all(cls) -> None:
         """Close all :class:`_asyncio.AsyncSession` sessions."""
-        await greenlet_spawn(close_all_sessions)
+        await close_all_sessions()
 
     async def __aenter__(self: _AS) -> _AS:
         return self
@@ -1918,4 +1918,17 @@ def async_session(session: Session) -> Optional[AsyncSession]:
     return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
 
 
+async def close_all_sessions() -> None:
+    """Close all :class:`_asyncio.AsyncSession` sessions.
+
+    .. versionadded:: 2.0.23
+
+    .. seealso::
+
+        :func:`.session.close_all_sessions`
+
+    """
+    await greenlet_spawn(_sync_close_all_sessions)
+
+
 _instance_state._async_provider = async_session  # type: ignore
index 61759f8a93147ad06c276558f207e46f2c9c648a..394e7df45e3beabc50f20fc9443ea37a3c44190e 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import async_object_session
 from sqlalchemy.ext.asyncio import async_sessionmaker
 from sqlalchemy.ext.asyncio import AsyncAttrs
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import close_all_sessions
 from sqlalchemy.ext.asyncio import exc as async_exc
 from sqlalchemy.ext.asyncio.base import ReversibleProxy
 from sqlalchemy.orm import DeclarativeBase
@@ -123,17 +124,52 @@ class AsyncSessionTest(AsyncFixture):
                 )
 
     @async_test
-    async def test_close_all(self, async_session):
-        User = self.classes.User
-        u = User(name="u")
-        async_session.add(u)
-        await async_session.commit()
+    async def test_close_all(self, async_engine):
+        users, User = self.tables.users, self.classes.User
+
+        self.mapper_registry.map_imperatively(User, users)
+
+        s1 = AsyncSession(async_engine)
+        u1 = User()
+        s1.add(u1)
+
+        s2 = AsyncSession(async_engine)
+        u2 = User()
+        s2.add(u2)
+
+        assert u1 in s1
+        assert u2 in s2
+
+        await close_all_sessions()
+
+        assert u1 not in s1
+        assert u2 not in s2
+
+    @async_test
+    async def test_session_close_all_deprecated(self, async_engine):
+        users, User = self.tables.users, self.classes.User
+
+        self.mapper_registry.map_imperatively(User, users)
+
+        s1 = AsyncSession(async_engine)
+        u1 = User()
+        s1.add(u1)
+
+        s2 = AsyncSession(async_engine)
+        u2 = User()
+        s2.add(u2)
+
+        assert u1 in s1
+        assert u2 in s2
+
         with expect_deprecated(
             r"The AsyncSession.close_all\(\) method is deprecated and will "
             "be removed in a future release. "
         ):
             await AsyncSession.close_all()
-        assert async_session.sync_session.identity_map.values() == []
+
+        assert u1 not in s1
+        assert u2 not in s2
 
 
 class AsyncSessionQueryTest(AsyncFixture):