--- /dev/null
+.. change::
+ :tags: bug, asyncio
+ :tickets: 12471
+
+ Fixed issue where :meth:`.AsyncSession.get_transaction` and
+ :meth:`.AsyncSession.get_nested_transaction` would fail with
+ ``NotImplementedError`` if the "proxy transaction" used by
+ :class:`.AsyncSession` were garbage collected and needed regeneration.
cls._proxy_objects.pop(ref, None)
@classmethod
- def _regenerate_proxy_for_target(cls, target: _PT) -> Self:
+ def _regenerate_proxy_for_target(
+ cls, target: _PT, **additional_kw: Any
+ ) -> Self:
raise NotImplementedError()
@overload
@classmethod
def _retrieve_proxy_for_target(
- cls,
- target: _PT,
- regenerate: Literal[True] = ...,
+ cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any
) -> Self: ...
@overload
@classmethod
def _retrieve_proxy_for_target(
- cls, target: _PT, regenerate: bool = True
+ cls, target: _PT, regenerate: bool = True, **additional_kw: Any
) -> Optional[Self]: ...
@classmethod
def _retrieve_proxy_for_target(
- cls, target: _PT, regenerate: bool = True
+ cls, target: _PT, regenerate: bool = True, **additional_kw: Any
) -> Optional[Self]:
try:
proxy_ref = cls._proxy_objects[weakref.ref(target)]
return proxy # type: ignore
if regenerate:
- return cls._regenerate_proxy_for_target(target)
+ return cls._regenerate_proxy_for_target(target, **additional_kw)
else:
return None
@classmethod
def _regenerate_proxy_for_target(
- cls, target: Connection
+ cls, target: Connection, **additional_kw: Any # noqa: U100
) -> AsyncConnection:
return AsyncConnection(
AsyncEngine._retrieve_proxy_for_target(target.engine), target
return self.sync_engine
@classmethod
- def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
+ def _regenerate_proxy_for_target(
+ cls, target: Engine, **additional_kw: Any # noqa: U100
+ ) -> AsyncEngine:
return AsyncEngine(target)
@contextlib.asynccontextmanager
@classmethod
def _regenerate_proxy_for_target(
- cls, target: Transaction
+ cls, target: Transaction, **additional_kw: Any # noqa: U100
) -> AsyncTransaction:
sync_connection = target.connection
sync_transaction = target
"""
trans = self.sync_session.get_transaction()
if trans is not None:
- return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ return AsyncSessionTransaction._retrieve_proxy_for_target(
+ trans, async_session=self
+ )
else:
return None
trans = self.sync_session.get_nested_transaction()
if trans is not None:
- return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ return AsyncSessionTransaction._retrieve_proxy_for_target(
+ trans, async_session=self
+ )
else:
return None
await greenlet_spawn(self._sync_transaction().commit)
+ @classmethod
+ def _regenerate_proxy_for_target( # type: ignore[override]
+ cls,
+ target: SessionTransaction,
+ async_session: AsyncSession,
+ **additional_kw: Any, # noqa: U100
+ ) -> AsyncSessionTransaction:
+ sync_transaction = target
+ nested = target.nested
+ obj = cls.__new__(cls)
+ obj.session = async_session
+ obj.sync_transaction = obj._assign_proxied(sync_transaction)
+ obj.nested = nested
+ return obj
+
async def start(
self, is_ctxmanager: bool = False
) -> AsyncSessionTransaction:
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertions import expect_deprecated
is_(async_session.get_transaction(), None)
is_(async_session.get_nested_transaction(), None)
+ @async_test
+ async def test_get_transaction_gced(self, async_session):
+ """test #12471
+
+ this tests that the AsyncSessionTransaction is regenerated if
+ we don't have any reference to it beforehand.
+
+ """
+ is_(async_session.get_transaction(), None)
+ is_(async_session.get_nested_transaction(), None)
+
+ await async_session.begin()
+
+ trans = async_session.get_transaction()
+ is_not(trans, None)
+ is_(trans.session, async_session)
+ is_false(trans.nested)
+ is_(
+ trans.sync_transaction,
+ async_session.sync_session.get_transaction(),
+ )
+
+ await async_session.begin_nested()
+ nested = async_session.get_nested_transaction()
+ is_not(nested, None)
+ is_true(nested.nested)
+ is_(nested.session, async_session)
+ is_(
+ nested.sync_transaction,
+ async_session.sync_session.get_nested_transaction(),
+ )
+
@async_test
async def test_async_object_session(self, async_engine):
User = self.classes.User