From a35c29d41213bd79bfa73ac559a29035ca4e9eb9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Bryan=E4=B8=8D=E5=8F=AF=E6=80=9D=E8=AE=AE?= Date: Wed, 18 Oct 2023 20:15:30 +0000 Subject: [PATCH] feat: `close_all_sessions` --- lib/sqlalchemy/ext/asyncio/__init__.py | 1 + lib/sqlalchemy/ext/asyncio/session.py | 19 ++++++++-- test/ext/asyncio/test_session_py3k.py | 48 ++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index ad6cd15268..8564db6f22 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -22,3 +22,4 @@ from .session import async_sessionmaker as async_sessionmaker from .session import AsyncAttrs as AsyncAttrs from .session import AsyncSession as AsyncSession from .session import AsyncSessionTransaction as AsyncSessionTransaction +from .session import close_all_sessions as close_all_sessions diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 75dd43281d..b4b5ab3fbe 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -32,7 +32,7 @@ from .result import _ensure_sync_result from .result import AsyncResult from .result import AsyncScalarResult from ... import util -from ...orm import close_all_sessions +from ...orm import close_all_sessions as _sync_close_all_sessions from ...orm import object_session from ...orm import Session from ...orm import SessionTransaction @@ -1062,11 +1062,11 @@ class AsyncSession(ReversibleProxy[Session]): "2.0", "The :meth:`.AsyncSession.close_all` method is deprecated and will be " "removed in a future release. Please refer to " - ":func:`.session.close_all_sessions`.", + ":func:`_asyncio.close_all_sessions`.", ) async def close_all(cls) -> None: """Close all :class:`_asyncio.AsyncSession` sessions.""" - await greenlet_spawn(close_all_sessions) + await close_all_sessions() async def __aenter__(self: _AS) -> _AS: return self @@ -1918,4 +1918,17 @@ def async_session(session: Session) -> Optional[AsyncSession]: return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) +async def close_all_sessions() -> None: + """Close all :class:`_asyncio.AsyncSession` sessions. + + .. versionadded:: 2.0.23 + + .. seealso:: + + :func:`.session.close_all_sessions` + + """ + await greenlet_spawn(_sync_close_all_sessions) + + _instance_state._async_provider = async_session # type: ignore diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 61759f8a93..394e7df45e 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import close_all_sessions from sqlalchemy.ext.asyncio import exc as async_exc from sqlalchemy.ext.asyncio.base import ReversibleProxy from sqlalchemy.orm import DeclarativeBase @@ -123,17 +124,52 @@ class AsyncSessionTest(AsyncFixture): ) @async_test - async def test_close_all(self, async_session): - User = self.classes.User - u = User(name="u") - async_session.add(u) - await async_session.commit() + async def test_close_all(self, async_engine): + users, User = self.tables.users, self.classes.User + + self.mapper_registry.map_imperatively(User, users) + + s1 = AsyncSession(async_engine) + u1 = User() + s1.add(u1) + + s2 = AsyncSession(async_engine) + u2 = User() + s2.add(u2) + + assert u1 in s1 + assert u2 in s2 + + await close_all_sessions() + + assert u1 not in s1 + assert u2 not in s2 + + @async_test + async def test_session_close_all_deprecated(self, async_engine): + users, User = self.tables.users, self.classes.User + + self.mapper_registry.map_imperatively(User, users) + + s1 = AsyncSession(async_engine) + u1 = User() + s1.add(u1) + + s2 = AsyncSession(async_engine) + u2 = User() + s2.add(u2) + + assert u1 in s1 + assert u2 in s2 + with expect_deprecated( r"The AsyncSession.close_all\(\) method is deprecated and will " "be removed in a future release. " ): await AsyncSession.close_all() - assert async_session.sync_session.identity_map.values() == [] + + assert u1 not in s1 + assert u2 not in s2 class AsyncSessionQueryTest(AsyncFixture): -- 2.47.3