]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement AsyncSessionTransaction._regenerate_proxy_for_target
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Mar 2025 17:55:46 +0000 (13:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Mar 2025 15:11:04 +0000 (11:11 -0400)
Fixed issue where :meth:`.AsyncSession.get_transaction` and
:meth:`.AsyncSession.get_nested_transaction` would fail with
``NotImplementedError`` if the "proxy transaction" used by
:class:`.AsyncSession` were garbage collected and needed regeneration.

Fixes: #12471
Change-Id: Ia8055524618df706d7958786a500cdd25d9d8eaf

doc/build/changelog/unreleased_20/12471.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/base.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
test/ext/asyncio/test_session_py3k.py

diff --git a/doc/build/changelog/unreleased_20/12471.rst b/doc/build/changelog/unreleased_20/12471.rst
new file mode 100644 (file)
index 0000000..d3178b7
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 12471
+
+    Fixed issue where :meth:`.AsyncSession.get_transaction` and
+    :meth:`.AsyncSession.get_nested_transaction` would fail with
+    ``NotImplementedError`` if the "proxy transaction" used by
+    :class:`.AsyncSession` were garbage collected and needed regeneration.
index b53d53b1a4ea549313e3a114875661dca479abc8..ce2c439f1609670f4ac0b010c7be129b99b33db3 100644 (file)
@@ -71,26 +71,26 @@ class ReversibleProxy(Generic[_PT]):
         cls._proxy_objects.pop(ref, None)
 
     @classmethod
-    def _regenerate_proxy_for_target(cls, target: _PT) -> Self:
+    def _regenerate_proxy_for_target(
+        cls, target: _PT, **additional_kw: Any
+    ) -> Self:
         raise NotImplementedError()
 
     @overload
     @classmethod
     def _retrieve_proxy_for_target(
-        cls,
-        target: _PT,
-        regenerate: Literal[True] = ...,
+        cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any
     ) -> Self: ...
 
     @overload
     @classmethod
     def _retrieve_proxy_for_target(
-        cls, target: _PT, regenerate: bool = True
+        cls, target: _PT, regenerate: bool = True, **additional_kw: Any
     ) -> Optional[Self]: ...
 
     @classmethod
     def _retrieve_proxy_for_target(
-        cls, target: _PT, regenerate: bool = True
+        cls, target: _PT, regenerate: bool = True, **additional_kw: Any
     ) -> Optional[Self]:
         try:
             proxy_ref = cls._proxy_objects[weakref.ref(target)]
@@ -102,7 +102,7 @@ class ReversibleProxy(Generic[_PT]):
                 return proxy  # type: ignore
 
         if regenerate:
-            return cls._regenerate_proxy_for_target(target)
+            return cls._regenerate_proxy_for_target(target, **additional_kw)
         else:
             return None
 
index 0595668eb35265dc7d35a0755b8751440dd42d2e..bf3cae63493bd4405ec67a2f41fabcff820f27ab 100644 (file)
@@ -258,7 +258,7 @@ class AsyncConnection(
 
     @classmethod
     def _regenerate_proxy_for_target(
-        cls, target: Connection
+        cls, target: Connection, **additional_kw: Any  # noqa: U100
     ) -> AsyncConnection:
         return AsyncConnection(
             AsyncEngine._retrieve_proxy_for_target(target.engine), target
@@ -1045,7 +1045,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
         return self.sync_engine
 
     @classmethod
-    def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
+    def _regenerate_proxy_for_target(
+        cls, target: Engine, **additional_kw: Any  # noqa: U100
+    ) -> AsyncEngine:
         return AsyncEngine(target)
 
     @contextlib.asynccontextmanager
@@ -1346,7 +1348,7 @@ class AsyncTransaction(
 
     @classmethod
     def _regenerate_proxy_for_target(
-        cls, target: Transaction
+        cls, target: Transaction, **additional_kw: Any  # noqa: U100
     ) -> AsyncTransaction:
         sync_connection = target.connection
         sync_transaction = target
index adb88f53f6e2f09fc3d60884db7ec58213177f09..17be0c8409ebd273f198772edc5e0f0693dec144 100644 (file)
@@ -843,7 +843,9 @@ class AsyncSession(ReversibleProxy[Session]):
         """
         trans = self.sync_session.get_transaction()
         if trans is not None:
-            return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+            return AsyncSessionTransaction._retrieve_proxy_for_target(
+                trans, async_session=self
+            )
         else:
             return None
 
@@ -859,7 +861,9 @@ class AsyncSession(ReversibleProxy[Session]):
 
         trans = self.sync_session.get_nested_transaction()
         if trans is not None:
-            return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+            return AsyncSessionTransaction._retrieve_proxy_for_target(
+                trans, async_session=self
+            )
         else:
             return None
 
@@ -1896,6 +1900,21 @@ class AsyncSessionTransaction(
 
         await greenlet_spawn(self._sync_transaction().commit)
 
+    @classmethod
+    def _regenerate_proxy_for_target(  # type: ignore[override]
+        cls,
+        target: SessionTransaction,
+        async_session: AsyncSession,
+        **additional_kw: Any,  # noqa: U100
+    ) -> AsyncSessionTransaction:
+        sync_transaction = target
+        nested = target.nested
+        obj = cls.__new__(cls)
+        obj.session = async_session
+        obj.sync_transaction = obj._assign_proxied(sync_transaction)
+        obj.nested = nested
+        return obj
+
     async def start(
         self, is_ctxmanager: bool = False
     ) -> AsyncSessionTransaction:
index 2d6ce09da3af052672d8f6c651ad6bcd22b42e4f..5f9bf2e089e87b87544e10a7296d1b64dea6d15f 100644 (file)
@@ -38,6 +38,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import expect_deprecated
@@ -934,6 +935,38 @@ class AsyncProxyTest(AsyncFixture):
         is_(async_session.get_transaction(), None)
         is_(async_session.get_nested_transaction(), None)
 
+    @async_test
+    async def test_get_transaction_gced(self, async_session):
+        """test #12471
+
+        this tests that the AsyncSessionTransaction is regenerated if
+        we don't have any reference to it beforehand.
+
+        """
+        is_(async_session.get_transaction(), None)
+        is_(async_session.get_nested_transaction(), None)
+
+        await async_session.begin()
+
+        trans = async_session.get_transaction()
+        is_not(trans, None)
+        is_(trans.session, async_session)
+        is_false(trans.nested)
+        is_(
+            trans.sync_transaction,
+            async_session.sync_session.get_transaction(),
+        )
+
+        await async_session.begin_nested()
+        nested = async_session.get_nested_transaction()
+        is_not(nested, None)
+        is_true(nested.nested)
+        is_(nested.session, async_session)
+        is_(
+            nested.sync_transaction,
+            async_session.sync_session.get_nested_transaction(),
+        )
+
     @async_test
     async def test_async_object_session(self, async_engine):
         User = self.classes.User