From: Mike Bayer Date: Wed, 26 Mar 2025 17:55:46 +0000 (-0400) Subject: implement AsyncSessionTransaction._regenerate_proxy_for_target X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0202673a34b1b0cbbda6e2cb06012f77df642085;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement AsyncSessionTransaction._regenerate_proxy_for_target Fixed issue where :meth:`.AsyncSession.get_transaction` and :meth:`.AsyncSession.get_nested_transaction` would fail with ``NotImplementedError`` if the "proxy transaction" used by :class:`.AsyncSession` were garbage collected and needed regeneration. Fixes: #12471 Change-Id: Ia8055524618df706d7958786a500cdd25d9d8eaf --- diff --git a/doc/build/changelog/unreleased_20/12471.rst b/doc/build/changelog/unreleased_20/12471.rst new file mode 100644 index 0000000000..d3178b712a --- /dev/null +++ b/doc/build/changelog/unreleased_20/12471.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, asyncio + :tickets: 12471 + + Fixed issue where :meth:`.AsyncSession.get_transaction` and + :meth:`.AsyncSession.get_nested_transaction` would fail with + ``NotImplementedError`` if the "proxy transaction" used by + :class:`.AsyncSession` were garbage collected and needed regeneration. diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index b53d53b1a4..ce2c439f16 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -71,26 +71,26 @@ class ReversibleProxy(Generic[_PT]): cls._proxy_objects.pop(ref, None) @classmethod - def _regenerate_proxy_for_target(cls, target: _PT) -> Self: + def _regenerate_proxy_for_target( + cls, target: _PT, **additional_kw: Any + ) -> Self: raise NotImplementedError() @overload @classmethod def _retrieve_proxy_for_target( - cls, - target: _PT, - regenerate: Literal[True] = ..., + cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any ) -> Self: ... @overload @classmethod def _retrieve_proxy_for_target( - cls, target: _PT, regenerate: bool = True + cls, target: _PT, regenerate: bool = True, **additional_kw: Any ) -> Optional[Self]: ... @classmethod def _retrieve_proxy_for_target( - cls, target: _PT, regenerate: bool = True + cls, target: _PT, regenerate: bool = True, **additional_kw: Any ) -> Optional[Self]: try: proxy_ref = cls._proxy_objects[weakref.ref(target)] @@ -102,7 +102,7 @@ class ReversibleProxy(Generic[_PT]): return proxy # type: ignore if regenerate: - return cls._regenerate_proxy_for_target(target) + return cls._regenerate_proxy_for_target(target, **additional_kw) else: return None diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 0595668eb3..bf3cae6349 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -258,7 +258,7 @@ class AsyncConnection( @classmethod def _regenerate_proxy_for_target( - cls, target: Connection + cls, target: Connection, **additional_kw: Any # noqa: U100 ) -> AsyncConnection: return AsyncConnection( AsyncEngine._retrieve_proxy_for_target(target.engine), target @@ -1045,7 +1045,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): return self.sync_engine @classmethod - def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: + def _regenerate_proxy_for_target( + cls, target: Engine, **additional_kw: Any # noqa: U100 + ) -> AsyncEngine: return AsyncEngine(target) @contextlib.asynccontextmanager @@ -1346,7 +1348,7 @@ class AsyncTransaction( @classmethod def _regenerate_proxy_for_target( - cls, target: Transaction + cls, target: Transaction, **additional_kw: Any # noqa: U100 ) -> AsyncTransaction: sync_connection = target.connection sync_transaction = target diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index adb88f53f6..17be0c8409 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -843,7 +843,9 @@ class AsyncSession(ReversibleProxy[Session]): """ trans = self.sync_session.get_transaction() if trans is not None: - return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + return AsyncSessionTransaction._retrieve_proxy_for_target( + trans, async_session=self + ) else: return None @@ -859,7 +861,9 @@ class AsyncSession(ReversibleProxy[Session]): trans = self.sync_session.get_nested_transaction() if trans is not None: - return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + return AsyncSessionTransaction._retrieve_proxy_for_target( + trans, async_session=self + ) else: return None @@ -1896,6 +1900,21 @@ class AsyncSessionTransaction( await greenlet_spawn(self._sync_transaction().commit) + @classmethod + def _regenerate_proxy_for_target( # type: ignore[override] + cls, + target: SessionTransaction, + async_session: AsyncSession, + **additional_kw: Any, # noqa: U100 + ) -> AsyncSessionTransaction: + sync_transaction = target + nested = target.nested + obj = cls.__new__(cls) + obj.session = async_session + obj.sync_transaction = obj._assign_proxied(sync_transaction) + obj.nested = nested + return obj + async def start( self, is_ctxmanager: bool = False ) -> AsyncSessionTransaction: diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 2d6ce09da3..5f9bf2e089 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -38,6 +38,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import expect_deprecated @@ -934,6 +935,38 @@ class AsyncProxyTest(AsyncFixture): is_(async_session.get_transaction(), None) is_(async_session.get_nested_transaction(), None) + @async_test + async def test_get_transaction_gced(self, async_session): + """test #12471 + + this tests that the AsyncSessionTransaction is regenerated if + we don't have any reference to it beforehand. + + """ + is_(async_session.get_transaction(), None) + is_(async_session.get_nested_transaction(), None) + + await async_session.begin() + + trans = async_session.get_transaction() + is_not(trans, None) + is_(trans.session, async_session) + is_false(trans.nested) + is_( + trans.sync_transaction, + async_session.sync_session.get_transaction(), + ) + + await async_session.begin_nested() + nested = async_session.get_nested_transaction() + is_not(nested, None) + is_true(nested.nested) + is_(nested.session, async_session) + is_( + nested.sync_transaction, + async_session.sync_session.get_nested_transaction(), + ) + @async_test async def test_async_object_session(self, async_engine): User = self.classes.User