From: Mike Bayer Date: Wed, 2 Jun 2021 16:23:31 +0000 (-0400) Subject: Implement proxy back reference system for asyncio X-Git-Tag: rel_1_4_18~14^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=97d922663a0350c6ce026ecfbde8010ca1bc0c37;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement proxy back reference system for asyncio Implemented a new registry architecture that allows the ``Async`` version of an object, like ``AsyncSession``, ``AsyncConnection``, etc., to be locatable given the proxied "sync" object, i.e. ``Session``, ``Connection``. Previously, to the degree such lookup functions were used, an ``Async`` object would be re-created each time, which was less than ideal as the identity and state of the "async" object would not be preserved across calls. From there, new helper functions :func:`_asyncio.async_object_session`, :func:`_asyncio.async_session` as well as a new :class:`_orm.InstanceState` attribute :attr:`_orm.InstanceState.asyncio_session` have been added, which are used to retrieve the original :class:`_asyncio.AsyncSession` associated with an ORM mapped object, a :class:`_orm.Session` associated with an :class:`_asyncio.AsyncSession`, and an :class:`_asyncio.AsyncSession` associated with an :class:`_orm.InstanceState`, respectively. This patch also implements new methods :meth:`_asyncio.AsyncSession.in_nested_transaction`, :meth:`_asyncio.AsyncSession.get_transaction`, :meth:`_asyncio.AsyncSession.get_nested_transaction`. Fixes: #6319 Change-Id: Ia452a7e7ce9bad3ff8846c7dea8d45c839ac9fac --- diff --git a/doc/build/changelog/unreleased_14/6319.rst b/doc/build/changelog/unreleased_14/6319.rst new file mode 100644 index 0000000000..13db051e14 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6319.rst @@ -0,0 +1,24 @@ +.. change:: + :tags: usecase, asyncio + :tickets: 6319 + + Implemented a new registry architecture that allows the ``Async`` version + of an object, like ``AsyncSession``, ``AsyncConnection``, etc., to be + locatable given the proxied "sync" object, i.e. ``Session``, + ``Connection``. Previously, to the degree such lookup functions were used, + an ``Async`` object would be re-created each time, which was less than + ideal as the identity and state of the "async" object would not be + preserved across calls. + + From there, new helper functions :func:`_asyncio.async_object_session`, + :func:`_asyncio.async_session` as well as a new :class:`_orm.InstanceState` + attribute :attr:`_orm.InstanceState.async_session` have been added, which + are used to retrieve the original :class:`_asyncio.AsyncSession` associated + with an ORM mapped object, a :class:`_orm.Session` associated with an + :class:`_asyncio.AsyncSession`, and an :class:`_asyncio.AsyncSession` + associated with an :class:`_orm.InstanceState`, respectively. + + This patch also implements new methods + :meth:`_asyncio.AsyncSession.in_nested_transaction`, + :meth:`_asyncio.AsyncSession.get_transaction`, + :meth:`_asyncio.AsyncSession.get_nested_transaction`. diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index e56c4b59ac..92471d459e 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -100,7 +100,7 @@ illustrates a complete example including mapper and session configuration:: from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy.ext.asyncio import AsyncSession - from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.future import select from sqlalchemy.orm import declarative_base from sqlalchemy.orm import relationship @@ -452,6 +452,10 @@ cursor. ORM Session API Documentation ----------------------------- +.. autofunction:: async_object_session + +.. autofunction:: async_session + .. autoclass:: AsyncSession :members: diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py index 2fda2d777e..349bc1b753 100644 --- a/lib/sqlalchemy/ext/asyncio/__init__.py +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -14,5 +14,7 @@ from .events import AsyncSessionEvents from .result import AsyncMappingResult from .result import AsyncResult from .result import AsyncScalarResult +from .session import async_object_session +from .session import async_session from .session import AsyncSession from .session import AsyncSessionTransaction diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 76a2fbbde9..3f2c084f4a 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -1,8 +1,50 @@ import abc +import functools +import weakref from . import exc as async_exc +class ReversibleProxy: + # weakref.ref(async proxy object) -> weakref.ref(sync proxied object) + _proxy_objects = {} + + def _assign_proxied(self, target): + if target is not None: + target_ref = weakref.ref(target, ReversibleProxy._target_gced) + proxy_ref = weakref.ref( + self, + functools.partial(ReversibleProxy._target_gced, target_ref), + ) + ReversibleProxy._proxy_objects[target_ref] = proxy_ref + + return target + + @classmethod + def _target_gced(cls, ref, proxy_ref=None): + cls._proxy_objects.pop(ref, None) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + raise NotImplementedError() + + @classmethod + def _retrieve_proxy_for_target(cls, target, regenerate=True): + try: + proxy_ref = cls._proxy_objects[weakref.ref(target)] + except KeyError: + pass + else: + proxy = proxy_ref() + if proxy is not None: + return proxy + + if regenerate: + return cls._regenerate_proxy_for_target(target) + else: + return None + + class StartableContext(abc.ABC): @abc.abstractmethod async def start(self, is_ctxmanager=False): @@ -25,7 +67,7 @@ class StartableContext(abc.ABC): ) -class ProxyComparable: +class ProxyComparable(ReversibleProxy): def __hash__(self): return id(self) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 9cd3cb2f8b..8e5c019191 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -11,6 +11,7 @@ from .result import AsyncResult from ... import exc from ... import util from ...engine import create_engine as _create_engine +from ...engine.base import NestedTransaction from ...future import Connection from ...future import Engine from ...util.concurrency import greenlet_spawn @@ -86,7 +87,13 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): def __init__(self, async_engine, sync_connection=None): self.engine = async_engine self.sync_engine = async_engine.sync_engine - self.sync_connection = sync_connection + self.sync_connection = self._assign_proxied(sync_connection) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + return AsyncConnection( + AsyncEngine._retrieve_proxy_for_target(target.engine), target + ) async def start(self, is_ctxmanager=False): """Start this :class:`_asyncio.AsyncConnection` object's context @@ -95,7 +102,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): """ if self.sync_connection: raise exc.InvalidRequestError("connection is already started") - self.sync_connection = await (greenlet_spawn(self.sync_engine.connect)) + self.sync_connection = self._assign_proxied( + await (greenlet_spawn(self.sync_engine.connect)) + ) return self @property @@ -216,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): trans = conn.get_transaction() if trans is not None: - return AsyncTransaction._from_existing_transaction(self, trans) + return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None @@ -236,9 +245,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): trans = conn.get_nested_transaction() if trans is not None: - return AsyncTransaction._from_existing_transaction( - self, trans, True - ) + return AsyncTransaction._retrieve_proxy_for_target(trans) else: return None @@ -522,7 +529,11 @@ class AsyncEngine(ProxyComparable, AsyncConnectable): "The asyncio extension requires an async driver to be used. " f"The loaded {sync_engine.dialect.driver!r} is not async." ) - self.sync_engine = self._proxied = sync_engine + self.sync_engine = self._proxied = self._assign_proxied(sync_engine) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + return AsyncEngine(target) def begin(self): """Return a context manager which when entered will deliver an @@ -605,17 +616,24 @@ class AsyncTransaction(ProxyComparable, StartableContext): __slots__ = ("connection", "sync_transaction", "nested") def __init__(self, connection, nested=False): - self.connection = connection - self.sync_transaction = None + self.connection = connection # AsyncConnection + self.sync_transaction = None # sqlalchemy.engine.Transaction self.nested = nested @classmethod - def _from_existing_transaction( - cls, connection, sync_transaction, nested=False - ): + def _regenerate_proxy_for_target(cls, target): + sync_connection = target.connection + sync_transaction = target + nested = isinstance(target, NestedTransaction) + + async_connection = AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + assert async_connection is not None + obj = cls.__new__(cls) - obj.connection = connection - obj.sync_transaction = sync_transaction + obj.connection = async_connection + obj.sync_transaction = obj._assign_proxied(sync_transaction) obj.nested = nested return obj @@ -664,10 +682,12 @@ class AsyncTransaction(ProxyComparable, StartableContext): """ - self.sync_transaction = await greenlet_spawn( - self.connection._sync_connection().begin_nested - if self.nested - else self.connection._sync_connection().begin + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.connection._sync_connection().begin_nested + if self.nested + else self.connection._sync_connection().begin + ) ) if is_ctxmanager: self.sync_transaction.__enter__() diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 343465f377..16e15c8731 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -6,9 +6,12 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from . import engine from . import result as _result +from .base import ReversibleProxy from .base import StartableContext from ... import util +from ...orm import object_session from ...orm import Session +from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn @@ -29,6 +32,7 @@ from ...util.concurrency import greenlet_spawn "get_bind", "is_modified", "in_transaction", + "in_nested_transaction", ], attributes=[ "dirty", @@ -41,7 +45,7 @@ from ...util.concurrency import greenlet_spawn "info", ], ) -class AsyncSession: +class AsyncSession(ReversibleProxy): """Asyncio version of :class:`_orm.Session`. @@ -72,8 +76,8 @@ class AsyncSession: for key, b in binds.items() } - self.sync_session = self._proxied = Session( - bind=bind, binds=binds, **kw + self.sync_session = self._proxied = self._assign_proxied( + Session(bind=bind, binds=binds, **kw) ) async def refresh( @@ -242,21 +246,46 @@ class AsyncSession: """ await greenlet_spawn(self.sync_session.flush, objects=objects) + def get_transaction(self): + """Return the current root transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + trans = self.sync_session.get_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self): + """Return the current nested transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + + trans = self.sync_session.get_nested_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + async def connection(self): - r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this - :class:`.Session` object's transactional state. + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to + this :class:`.Session` object's transactional state. """ - # POSSIBLY TODO: here, we see that the sync engine / connection - # that are generated from AsyncEngine / AsyncConnection don't - # provide any backlink from those sync objects back out to the - # async ones. it's not *too* big a deal since AsyncEngine/Connection - # are just proxies and all the state is actually in the sync - # version of things. However! it has to stay that way :) sync_connection = await greenlet_spawn(self.sync_session.connection) - return engine.AsyncConnection( - engine.AsyncEngine(sync_connection.engine), sync_connection + return engine.AsyncConnection._retrieve_proxy_for_target( + sync_connection ) def begin(self, **kw): @@ -363,7 +392,7 @@ class _AsyncSessionContextManager: await self.async_session.__aexit__(type_, value, traceback) -class AsyncSessionTransaction(StartableContext): +class AsyncSessionTransaction(ReversibleProxy, StartableContext): """A wrapper for the ORM :class:`_orm.SessionTransaction` object. This object is provided so that a transaction-holding object @@ -408,10 +437,12 @@ class AsyncSessionTransaction(StartableContext): await greenlet_spawn(self._sync_transaction().commit) async def start(self, is_ctxmanager=False): - self.sync_transaction = await greenlet_spawn( - self.session.sync_session.begin_nested - if self.nested - else self.session.sync_session.begin + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.session.sync_session.begin_nested + if self.nested + else self.session.sync_session.begin + ) ) if is_ctxmanager: self.sync_transaction.__enter__() @@ -421,3 +452,48 @@ class AsyncSessionTransaction(StartableContext): await greenlet_spawn( self._sync_transaction().__exit__, type_, value, traceback ) + + +def async_object_session(instance): + """Return the :class:`_asyncio.AsyncSession` to which the given instance + belongs. + + This function makes use of the sync-API function + :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which + refers to the given instance, and from there links it to the original + :class:`_asyncio.AsyncSession`. + + If the :class:`_asyncio.AsyncSession` has been garbage collected, the + return value is ``None``. + + This functionality is also available from the + :attr:`_orm.InstanceState.async_session` accessor. + + :param instance: an ORM mapped instance + :return: an :class:`_asyncio.AsyncSession` object, or ``None``. + + .. versionadded:: 1.4.18 + + """ + + session = object_session(instance) + if session is not None: + return async_session(session) + else: + return None + + +def async_session(session): + """Return the :class:`_asyncio.AsyncSession` which is proxying the given + :class:`_orm.Session` object, if any. + + :param session: a :class:`_orm.Session` instance. + :return: a :class:`_asyncio.AsyncSession` instance, or ``None``. + + .. versionadded:: 1.4.18 + + """ + return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) + + +_instance_state._async_provider = async_session diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 08390328e4..884e364c68 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -34,6 +34,9 @@ from .. import util # late-populated by session.py _sessions = None +# optionally late-provided by sqlalchemy.ext.asyncio.session +_async_provider = None + @inspection._self_inspects class InstanceState(interfaces.InspectionAttrInfo): @@ -262,6 +265,10 @@ class InstanceState(interfaces.InspectionAttrInfo): Only when the transaction is completed does the object become fully detached under normal circumstances. + .. seealso:: + + :attr:`_orm.InstanceState.async_session` + """ if self.session_id: try: @@ -270,6 +277,34 @@ class InstanceState(interfaces.InspectionAttrInfo): pass return None + @property + def async_session(self): + """Return the owning :class:`_asyncio.AsyncSession` for this instance, + or ``None`` if none available. + + This attribute is only non-None when the :mod:`sqlalchemy.ext.asyncio` + API is in use for this ORM object. The returned + :class:`_asyncio.AsyncSession` object will be a proxy for the + :class:`_orm.Session` object that would be returned from the + :attr:`_orm.InstanceState.session` attribute for this + :class:`_orm.InstanceState`. + + .. versionadded:: 1.4.18 + + .. seealso:: + + :ref:`asyncio_toplevel` + + """ + if _async_provider is None: + return None + + sess = self.session + if sess is not None: + return _async_provider(sess) + else: + return None + @property def object(self): """Return the mapped object represented by this diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 18e55ff92c..59df759e67 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -16,6 +16,9 @@ from sqlalchemy import union_all from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import engine as _async_engine from sqlalchemy.ext.asyncio import exc as asyncio_exc +from sqlalchemy.ext.asyncio.base import ReversibleProxy +from sqlalchemy.ext.asyncio.engine import AsyncConnection +from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.pool import AsyncAdaptedQueuePool from sqlalchemy.testing import assertions from sqlalchemy.testing import async_test @@ -293,8 +296,8 @@ class AsyncEngineTest(EngineFixture): async with async_engine.connect() as conn: t1 = await conn.begin() - t2 = _async_engine.AsyncTransaction._from_existing_transaction( - conn, t1._proxied + t2 = _async_engine.AsyncTransaction._regenerate_proxy_for_target( + t1._proxied ) eq_(t1, t2) @@ -886,3 +889,118 @@ class TextSyncDBAPI(fixtures.TestBase): ) assert res == 1 assert await conn.run_sync(lambda _: 2) == 2 + + +class AsyncProxyTest(EngineFixture, fixtures.TestBase): + @async_test + async def test_get_transaction(self, async_engine): + async with async_engine.connect() as conn: + async with conn.begin() as trans: + + is_(trans.connection, conn) + is_(conn.get_transaction(), trans) + + @async_test + async def test_get_nested_transaction(self, async_engine): + async with async_engine.connect() as conn: + async with conn.begin() as trans: + n1 = await conn.begin_nested() + + is_(conn.get_nested_transaction(), n1) + + n2 = await conn.begin_nested() + + is_(conn.get_nested_transaction(), n2) + + await n2.commit() + + is_(conn.get_nested_transaction(), n1) + + is_(conn.get_transaction(), trans) + + @async_test + async def test_get_connection(self, async_engine): + async with async_engine.connect() as conn: + is_( + AsyncConnection._retrieve_proxy_for_target( + conn.sync_connection + ), + conn, + ) + + def test_regenerate_connection(self, connection): + + async_connection = AsyncConnection._retrieve_proxy_for_target( + connection + ) + + a2 = AsyncConnection._retrieve_proxy_for_target(connection) + is_(async_connection, a2) + is_not(async_connection, None) + + is_(async_connection.engine, a2.engine) + is_not(async_connection.engine, None) + + @testing.requires.predictable_gc + async def test_gc_engine(self, testing_engine): + ReversibleProxy._proxy_objects.clear() + + eq_(len(ReversibleProxy._proxy_objects), 0) + + async_engine = AsyncEngine(testing.db) + + eq_(len(ReversibleProxy._proxy_objects), 1) + + del async_engine + + eq_(len(ReversibleProxy._proxy_objects), 0) + + @testing.requires.predictable_gc + @async_test + async def test_gc_conn(self, testing_engine): + ReversibleProxy._proxy_objects.clear() + + async_engine = AsyncEngine(testing.db) + + eq_(len(ReversibleProxy._proxy_objects), 1) + + async with async_engine.connect() as conn: + eq_(len(ReversibleProxy._proxy_objects), 2) + + async with conn.begin() as trans: + eq_(len(ReversibleProxy._proxy_objects), 3) + + del trans + + del conn + + eq_(len(ReversibleProxy._proxy_objects), 1) + + del async_engine + + eq_(len(ReversibleProxy._proxy_objects), 0) + + def test_regen_conn_but_not_engine(self, async_engine): + + sync_conn = async_engine.sync_engine.connect() + + async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn) + async_conn2 = AsyncConnection._retrieve_proxy_for_target(sync_conn) + + is_(async_conn, async_conn2) + is_(async_conn.engine, async_engine) + + def test_regen_trans_but_not_conn(self, async_engine): + sync_conn = async_engine.sync_engine.connect() + + async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn) + + trans = sync_conn.begin() + + async_t1 = async_conn.get_transaction() + + is_(async_t1.connection, async_conn) + is_(async_t1.sync_transaction, trans) + + async_t2 = async_conn.get_transaction() + is_(async_t1, async_t2) diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index e97e2563ab..1f5c950542 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -1,11 +1,14 @@ from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import select from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import update +from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio.base import ReversibleProxy from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import sessionmaker @@ -503,3 +506,159 @@ class AsyncEventTest(AsyncFixture): canary.mock_calls, [mock.call(async_session.sync_session)], ) + + +class AsyncProxyTest(AsyncFixture): + @async_test + async def test_get_connection_engine_bound(self, async_session): + c1 = await async_session.connection() + + c2 = await async_session.connection() + + is_(c1, c2) + is_(c1.engine, c2.engine) + + @async_test + async def test_get_connection_connection_bound(self, async_engine): + async with async_engine.begin() as conn: + async_session = AsyncSession(conn) + + c1 = await async_session.connection() + + is_(c1, conn) + is_(c1.engine, conn.engine) + + @async_test + async def test_get_transaction(self, async_session): + + is_(async_session.get_transaction(), None) + is_(async_session.get_nested_transaction(), None) + + t1 = await async_session.begin() + + is_(async_session.get_transaction(), t1) + is_(async_session.get_nested_transaction(), None) + + n1 = await async_session.begin_nested() + + is_(async_session.get_transaction(), t1) + is_(async_session.get_nested_transaction(), n1) + + await n1.commit() + + is_(async_session.get_transaction(), t1) + is_(async_session.get_nested_transaction(), None) + + await t1.commit() + + is_(async_session.get_transaction(), None) + is_(async_session.get_nested_transaction(), None) + + @async_test + async def test_async_object_session(self, async_engine): + User = self.classes.User + + s1 = AsyncSession(async_engine) + + s2 = AsyncSession(async_engine) + + u1 = await s1.get(User, 7) + + u2 = User(name="n1") + + s2.add(u2) + + u3 = User(name="n2") + + is_(async_object_session(u1), s1) + is_(async_object_session(u2), s2) + + is_(async_object_session(u3), None) + + await s2.close() + is_(async_object_session(u2), None) + + @async_test + async def test_async_object_session_custom(self, async_engine): + User = self.classes.User + + class MyCustomAsync(AsyncSession): + pass + + s1 = MyCustomAsync(async_engine) + + u1 = await s1.get(User, 7) + + assert isinstance(async_object_session(u1), MyCustomAsync) + + @testing.requires.predictable_gc + @async_test + async def test_async_object_session_del(self, async_engine): + User = self.classes.User + + s1 = AsyncSession(async_engine) + + u1 = await s1.get(User, 7) + + is_(async_object_session(u1), s1) + + await s1.rollback() + del s1 + is_(async_object_session(u1), None) + + @async_test + async def test_inspect_session(self, async_engine): + User = self.classes.User + + s1 = AsyncSession(async_engine) + + s2 = AsyncSession(async_engine) + + u1 = await s1.get(User, 7) + + u2 = User(name="n1") + + s2.add(u2) + + u3 = User(name="n2") + + is_(inspect(u1).async_session, s1) + is_(inspect(u2).async_session, s2) + + is_(inspect(u3).async_session, None) + + def test_inspect_session_no_asyncio_used(self): + from sqlalchemy.orm import Session + + User = self.classes.User + + s1 = Session(testing.db) + u1 = s1.get(User, 7) + + is_(inspect(u1).async_session, None) + + def test_inspect_session_no_asyncio_imported(self): + from sqlalchemy.orm import Session + + with mock.patch("sqlalchemy.orm.state._async_provider", None): + + User = self.classes.User + + s1 = Session(testing.db) + u1 = s1.get(User, 7) + + is_(inspect(u1).async_session, None) + + @testing.requires.predictable_gc + def test_gc(self, async_engine): + ReversibleProxy._proxy_objects.clear() + + eq_(len(ReversibleProxy._proxy_objects), 0) + + async_session = AsyncSession(async_engine) + + eq_(len(ReversibleProxy._proxy_objects), 1) + + del async_session + + eq_(len(ReversibleProxy._proxy_objects), 0)