From: Bryan不可思议 Date: Wed, 18 Oct 2023 20:31:45 +0000 (-0400) Subject: fix AsyncSession.close_all() X-Git-Tag: rel_2_0_23~24^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=29948f6848a84e5124c886aa6466e76938cb27de;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fix AsyncSession.close_all() Fixed bug with method :meth:`_asyncio.AsyncSession.close_all` that was not working correctly. Also added function :func:`_asyncio.close_all_sessions` that's the equivalent of :func:`_orm.close_all_sessions`. Fixes: #10421 Closes: #10429 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10429 Pull-request-sha: a35c29d41213bd79bfa73ac559a29035ca4e9eb9 Change-Id: If2c3f0130a71b239382c2ea11a3436788ee242be --- diff --git a/doc/build/changelog/changelog_20.rst b/doc/build/changelog/changelog_20.rst index 2c518321ac..39939ee0d8 100644 --- a/doc/build/changelog/changelog_20.rst +++ b/doc/build/changelog/changelog_20.rst @@ -934,7 +934,7 @@ This is a no-argument callable that provides a new asyncio connection, using the asyncio database driver directly. The :func:`.create_async_engine` function will wrap the driver-level connection - in the appropriate structures. Pull request curtesy of Jack Wotherspoon. + in the appropriate structures. Pull request courtesy of Jack Wotherspoon. .. change:: :tags: bug, orm, regression @@ -3100,7 +3100,7 @@ * ``@?`` using :meth:`_postgresql.JSONB.Comparator.path_exists` * ``#-`` using :meth:`_postgresql.JSONB.Comparator.delete_path` - Pull request curtesy of Guilherme Martins Crocetti. + Pull request courtesy of Guilherme Martins Crocetti. .. changelog:: :version: 2.0.0rc1 diff --git a/doc/build/changelog/unreleased_20/10421.rst b/doc/build/changelog/unreleased_20/10421.rst new file mode 100644 index 0000000000..c550647cea --- /dev/null +++ b/doc/build/changelog/unreleased_20/10421.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, asyncio + :tickets: 10421 + + Fixed bug with method :meth:`_asyncio.AsyncSession.close_all` + that was not working correctly. + Also added function :func:`_asyncio.close_all_sessions` that's + the equivalent of :func:`_orm.close_all_sessions`. + Pull request courtesy of Bryan不可思议. diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index e4782f2bbb..0815da29af 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -1097,6 +1097,8 @@ ORM Session API Documentation .. autofunction:: async_session +.. autofunction:: close_all_sessions + .. autoclass:: async_sessionmaker :members: :inherited-members: diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index ad6cd15268..8564db6f22 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -22,3 +22,4 @@ from .session import async_sessionmaker as async_sessionmaker from .session import AsyncAttrs as AsyncAttrs from .session import AsyncSession as AsyncSession from .session import AsyncSessionTransaction as AsyncSessionTransaction +from .session import close_all_sessions as close_all_sessions diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 3e6d4fc71c..7ac11fface 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -1544,7 +1544,7 @@ class async_scoped_session(Generic[_AS]): return self._proxied.info @classmethod - async def close_all(self) -> None: + async def close_all(cls) -> None: r"""Close all :class:`_asyncio.AsyncSession` sessions. .. container:: class_bases @@ -1552,6 +1552,8 @@ class async_scoped_session(Generic[_AS]): Proxied for the :class:`_asyncio.AsyncSession` class on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + .. deprecated:: 2.0 The :meth:`.AsyncSession.close_all` method is deprecated and will be removed in a future release. Please refer to :func:`_asyncio.close_all_sessions`. + """ # noqa: E501 return await AsyncSession.close_all() diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index dcbf26c6c4..b4b5ab3fbe 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -32,6 +32,7 @@ from .result import _ensure_sync_result from .result import AsyncResult from .result import AsyncScalarResult from ... import util +from ...orm import close_all_sessions as _sync_close_all_sessions from ...orm import object_session from ...orm import Session from ...orm import SessionTransaction @@ -1057,9 +1058,15 @@ class AsyncSession(ReversibleProxy[Session]): await greenlet_spawn(self.sync_session.invalidate) @classmethod - async def close_all(self) -> None: + @util.deprecated( + "2.0", + "The :meth:`.AsyncSession.close_all` method is deprecated and will be " + "removed in a future release. Please refer to " + ":func:`_asyncio.close_all_sessions`.", + ) + async def close_all(cls) -> None: """Close all :class:`_asyncio.AsyncSession` sessions.""" - await greenlet_spawn(self.sync_session.close_all) + await close_all_sessions() async def __aenter__(self: _AS) -> _AS: return self @@ -1911,4 +1918,17 @@ def async_session(session: Session) -> Optional[AsyncSession]: return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) +async def close_all_sessions() -> None: + """Close all :class:`_asyncio.AsyncSession` sessions. + + .. versionadded:: 2.0.23 + + .. seealso:: + + :func:`.session.close_all_sessions` + + """ + await greenlet_spawn(_sync_close_all_sessions) + + _instance_state._async_provider = async_session # type: ignore diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index c4527e123f..9c56487c40 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -265,6 +265,13 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]: metadata.update(format_argspec_plus(spec, grouped=False)) metadata["name"] = fn.__name__ + if inspect.iscoroutinefunction(fn): + metadata["prefix"] = "async " + metadata["target_prefix"] = "await " + else: + metadata["prefix"] = "" + metadata["target_prefix"] = "" + # look for __ positional arguments. This is a convention in # SQLAlchemy that arguments should be passed positionally # rather than as keyword @@ -276,16 +283,16 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]: if "__" in repr(spec[0]): code = ( """\ -def %(name)s%(grouped_args)s: - return %(target)s(%(fn)s, %(apply_pos)s) +%(prefix)sdef %(name)s%(grouped_args)s: + return %(target_prefix)s%(target)s(%(fn)s, %(apply_pos)s) """ % metadata ) else: code = ( """\ -def %(name)s%(grouped_args)s: - return %(target)s(%(fn)s, %(apply_kw)s) +%(prefix)sdef %(name)s%(grouped_args)s: + return %(target_prefix)s%(target)s(%(fn)s, %(apply_kw)s) """ % metadata ) diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 19ce55a2d7..e38a0cc52a 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -22,6 +22,7 @@ from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import close_all_sessions from sqlalchemy.ext.asyncio import exc as async_exc from sqlalchemy.ext.asyncio.base import ReversibleProxy from sqlalchemy.orm import DeclarativeBase @@ -41,7 +42,9 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import expect_deprecated +from sqlalchemy.testing.assertions import in_ from sqlalchemy.testing.assertions import is_false +from sqlalchemy.testing.assertions import not_in from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.provision import normalize_sequence from .test_engine_py3k import AsyncFixture as _AsyncFixture @@ -122,6 +125,50 @@ class AsyncSessionTest(AsyncFixture): sync_connection.dialect.default_sequence_base, ) + @async_test + async def test_close_all(self, async_engine): + User = self.classes.User + + s1 = AsyncSession(async_engine) + u1 = User() + s1.add(u1) + + s2 = AsyncSession(async_engine) + u2 = User() + s2.add(u2) + + in_(u1, s1) + in_(u2, s2) + + await close_all_sessions() + + not_in(u1, s1) + not_in(u2, s2) + + @async_test + async def test_session_close_all_deprecated(self, async_engine): + User = self.classes.User + + s1 = AsyncSession(async_engine) + u1 = User() + s1.add(u1) + + s2 = AsyncSession(async_engine) + u2 = User() + s2.add(u2) + + in_(u1, s1) + in_(u2, s2) + + with expect_deprecated( + r"The AsyncSession.close_all\(\) method is deprecated and will " + "be removed in a future release. " + ): + await AsyncSession.close_all() + + not_in(u1, s1) + not_in(u2, s2) + class AsyncSessionQueryTest(AsyncFixture): @async_test