--- /dev/null
+.. change::
+ :tags: usecase, asyncio
+ :tickets: 6319
+
+ Implemented a new registry architecture that allows the ``Async`` version
+ of an object, like ``AsyncSession``, ``AsyncConnection``, etc., to be
+ locatable given the proxied "sync" object, i.e. ``Session``,
+ ``Connection``. Previously, to the degree such lookup functions were used,
+ an ``Async`` object would be re-created each time, which was less than
+ ideal as the identity and state of the "async" object would not be
+ preserved across calls.
+
+ From there, new helper functions :func:`_asyncio.async_object_session`,
+ :func:`_asyncio.async_session` as well as a new :class:`_orm.InstanceState`
+ attribute :attr:`_orm.InstanceState.async_session` have been added, which
+ are used to retrieve the original :class:`_asyncio.AsyncSession` associated
+ with an ORM mapped object, a :class:`_orm.Session` associated with an
+ :class:`_asyncio.AsyncSession`, and an :class:`_asyncio.AsyncSession`
+ associated with an :class:`_orm.InstanceState`, respectively.
+
+ This patch also implements new methods
+ :meth:`_asyncio.AsyncSession.in_nested_transaction`,
+ :meth:`_asyncio.AsyncSession.get_transaction`,
+ :meth:`_asyncio.AsyncSession.get_nested_transaction`.
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.ext.asyncio import create_async_engine
+ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.future import select
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship
ORM Session API Documentation
-----------------------------
+.. autofunction:: async_object_session
+
+.. autofunction:: async_session
+
.. autoclass:: AsyncSession
:members:
from .result import AsyncMappingResult
from .result import AsyncResult
from .result import AsyncScalarResult
+from .session import async_object_session
+from .session import async_session
from .session import AsyncSession
from .session import AsyncSessionTransaction
import abc
+import functools
+import weakref
from . import exc as async_exc
+class ReversibleProxy:
+ # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
+ _proxy_objects = {}
+
+ def _assign_proxied(self, target):
+ if target is not None:
+ target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+ proxy_ref = weakref.ref(
+ self,
+ functools.partial(ReversibleProxy._target_gced, target_ref),
+ )
+ ReversibleProxy._proxy_objects[target_ref] = proxy_ref
+
+ return target
+
+ @classmethod
+ def _target_gced(cls, ref, proxy_ref=None):
+ cls._proxy_objects.pop(ref, None)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ raise NotImplementedError()
+
+ @classmethod
+ def _retrieve_proxy_for_target(cls, target, regenerate=True):
+ try:
+ proxy_ref = cls._proxy_objects[weakref.ref(target)]
+ except KeyError:
+ pass
+ else:
+ proxy = proxy_ref()
+ if proxy is not None:
+ return proxy
+
+ if regenerate:
+ return cls._regenerate_proxy_for_target(target)
+ else:
+ return None
+
+
class StartableContext(abc.ABC):
@abc.abstractmethod
async def start(self, is_ctxmanager=False):
)
-class ProxyComparable:
+class ProxyComparable(ReversibleProxy):
def __hash__(self):
return id(self)
from ... import exc
from ... import util
from ...engine import create_engine as _create_engine
+from ...engine.base import NestedTransaction
from ...future import Connection
from ...future import Engine
from ...util.concurrency import greenlet_spawn
def __init__(self, async_engine, sync_connection=None):
self.engine = async_engine
self.sync_engine = async_engine.sync_engine
- self.sync_connection = sync_connection
+ self.sync_connection = self._assign_proxied(sync_connection)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncConnection(
+ AsyncEngine._retrieve_proxy_for_target(target.engine), target
+ )
async def start(self, is_ctxmanager=False):
"""Start this :class:`_asyncio.AsyncConnection` object's context
"""
if self.sync_connection:
raise exc.InvalidRequestError("connection is already started")
- self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
+ self.sync_connection = self._assign_proxied(
+ await (greenlet_spawn(self.sync_engine.connect))
+ )
return self
@property
trans = conn.get_transaction()
if trans is not None:
- return AsyncTransaction._from_existing_transaction(self, trans)
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
trans = conn.get_nested_transaction()
if trans is not None:
- return AsyncTransaction._from_existing_transaction(
- self, trans, True
- )
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
"The asyncio extension requires an async driver to be used. "
f"The loaded {sync_engine.dialect.driver!r} is not async."
)
- self.sync_engine = self._proxied = sync_engine
+ self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncEngine(target)
def begin(self):
"""Return a context manager which when entered will deliver an
__slots__ = ("connection", "sync_transaction", "nested")
def __init__(self, connection, nested=False):
- self.connection = connection
- self.sync_transaction = None
+ self.connection = connection # AsyncConnection
+ self.sync_transaction = None # sqlalchemy.engine.Transaction
self.nested = nested
@classmethod
- def _from_existing_transaction(
- cls, connection, sync_transaction, nested=False
- ):
+ def _regenerate_proxy_for_target(cls, target):
+ sync_connection = target.connection
+ sync_transaction = target
+ nested = isinstance(target, NestedTransaction)
+
+ async_connection = AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+ assert async_connection is not None
+
obj = cls.__new__(cls)
- obj.connection = connection
- obj.sync_transaction = sync_transaction
+ obj.connection = async_connection
+ obj.sync_transaction = obj._assign_proxied(sync_transaction)
obj.nested = nested
return obj
"""
- self.sync_transaction = await greenlet_spawn(
- self.connection._sync_connection().begin_nested
- if self.nested
- else self.connection._sync_connection().begin
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.connection._sync_connection().begin_nested
+ if self.nested
+ else self.connection._sync_connection().begin
+ )
)
if is_ctxmanager:
self.sync_transaction.__enter__()
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from . import engine
from . import result as _result
+from .base import ReversibleProxy
from .base import StartableContext
from ... import util
+from ...orm import object_session
from ...orm import Session
+from ...orm import state as _instance_state
from ...util.concurrency import greenlet_spawn
"get_bind",
"is_modified",
"in_transaction",
+ "in_nested_transaction",
],
attributes=[
"dirty",
"info",
],
)
-class AsyncSession:
+class AsyncSession(ReversibleProxy):
"""Asyncio version of :class:`_orm.Session`.
for key, b in binds.items()
}
- self.sync_session = self._proxied = Session(
- bind=bind, binds=binds, **kw
+ self.sync_session = self._proxied = self._assign_proxied(
+ Session(bind=bind, binds=binds, **kw)
)
async def refresh(
"""
await greenlet_spawn(self.sync_session.flush, objects=objects)
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ trans = self.sync_session.get_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ trans = self.sync_session.get_nested_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
async def connection(self):
- r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to this
- :class:`.Session` object's transactional state.
+ r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
+ this :class:`.Session` object's transactional state.
"""
- # POSSIBLY TODO: here, we see that the sync engine / connection
- # that are generated from AsyncEngine / AsyncConnection don't
- # provide any backlink from those sync objects back out to the
- # async ones. it's not *too* big a deal since AsyncEngine/Connection
- # are just proxies and all the state is actually in the sync
- # version of things. However! it has to stay that way :)
sync_connection = await greenlet_spawn(self.sync_session.connection)
- return engine.AsyncConnection(
- engine.AsyncEngine(sync_connection.engine), sync_connection
+ return engine.AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
)
def begin(self, **kw):
await self.async_session.__aexit__(type_, value, traceback)
-class AsyncSessionTransaction(StartableContext):
+class AsyncSessionTransaction(ReversibleProxy, StartableContext):
"""A wrapper for the ORM :class:`_orm.SessionTransaction` object.
This object is provided so that a transaction-holding object
await greenlet_spawn(self._sync_transaction().commit)
async def start(self, is_ctxmanager=False):
- self.sync_transaction = await greenlet_spawn(
- self.session.sync_session.begin_nested
- if self.nested
- else self.session.sync_session.begin
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.session.sync_session.begin_nested
+ if self.nested
+ else self.session.sync_session.begin
+ )
)
if is_ctxmanager:
self.sync_transaction.__enter__()
await greenlet_spawn(
self._sync_transaction().__exit__, type_, value, traceback
)
+
+
+def async_object_session(instance):
+ """Return the :class:`_asyncio.AsyncSession` to which the given instance
+ belongs.
+
+ This function makes use of the sync-API function
+ :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which
+ refers to the given instance, and from there links it to the original
+ :class:`_asyncio.AsyncSession`.
+
+ If the :class:`_asyncio.AsyncSession` has been garbage collected, the
+ return value is ``None``.
+
+ This functionality is also available from the
+ :attr:`_orm.InstanceState.async_session` accessor.
+
+ :param instance: an ORM mapped instance
+ :return: an :class:`_asyncio.AsyncSession` object, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ session = object_session(instance)
+ if session is not None:
+ return async_session(session)
+ else:
+ return None
+
+
+def async_session(session):
+ """Return the :class:`_asyncio.AsyncSession` which is proxying the given
+ :class:`_orm.Session` object, if any.
+
+ :param session: a :class:`_orm.Session` instance.
+ :return: a :class:`_asyncio.AsyncSession` instance, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
+
+
+_instance_state._async_provider = async_session
# late-populated by session.py
_sessions = None
+# optionally late-provided by sqlalchemy.ext.asyncio.session
+_async_provider = None
+
@inspection._self_inspects
class InstanceState(interfaces.InspectionAttrInfo):
Only when the transaction is completed does the object become
fully detached under normal circumstances.
+ .. seealso::
+
+ :attr:`_orm.InstanceState.async_session`
+
"""
if self.session_id:
try:
pass
return None
+ @property
+ def async_session(self):
+ """Return the owning :class:`_asyncio.AsyncSession` for this instance,
+ or ``None`` if none available.
+
+ This attribute is only non-None when the :mod:`sqlalchemy.ext.asyncio`
+ API is in use for this ORM object. The returned
+ :class:`_asyncio.AsyncSession` object will be a proxy for the
+ :class:`_orm.Session` object that would be returned from the
+ :attr:`_orm.InstanceState.session` attribute for this
+ :class:`_orm.InstanceState`.
+
+ .. versionadded:: 1.4.18
+
+ .. seealso::
+
+ :ref:`asyncio_toplevel`
+
+ """
+ if _async_provider is None:
+ return None
+
+ sess = self.session
+ if sess is not None:
+ return _async_provider(sess)
+ else:
+ return None
+
@property
def object(self):
"""Return the mapped object represented by this
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import engine as _async_engine
from sqlalchemy.ext.asyncio import exc as asyncio_exc
+from sqlalchemy.ext.asyncio.base import ReversibleProxy
+from sqlalchemy.ext.asyncio.engine import AsyncConnection
+from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.pool import AsyncAdaptedQueuePool
from sqlalchemy.testing import assertions
from sqlalchemy.testing import async_test
async with async_engine.connect() as conn:
t1 = await conn.begin()
- t2 = _async_engine.AsyncTransaction._from_existing_transaction(
- conn, t1._proxied
+ t2 = _async_engine.AsyncTransaction._regenerate_proxy_for_target(
+ t1._proxied
)
eq_(t1, t2)
)
assert res == 1
assert await conn.run_sync(lambda _: 2) == 2
+
+
+class AsyncProxyTest(EngineFixture, fixtures.TestBase):
+ @async_test
+ async def test_get_transaction(self, async_engine):
+ async with async_engine.connect() as conn:
+ async with conn.begin() as trans:
+
+ is_(trans.connection, conn)
+ is_(conn.get_transaction(), trans)
+
+ @async_test
+ async def test_get_nested_transaction(self, async_engine):
+ async with async_engine.connect() as conn:
+ async with conn.begin() as trans:
+ n1 = await conn.begin_nested()
+
+ is_(conn.get_nested_transaction(), n1)
+
+ n2 = await conn.begin_nested()
+
+ is_(conn.get_nested_transaction(), n2)
+
+ await n2.commit()
+
+ is_(conn.get_nested_transaction(), n1)
+
+ is_(conn.get_transaction(), trans)
+
+ @async_test
+ async def test_get_connection(self, async_engine):
+ async with async_engine.connect() as conn:
+ is_(
+ AsyncConnection._retrieve_proxy_for_target(
+ conn.sync_connection
+ ),
+ conn,
+ )
+
+ def test_regenerate_connection(self, connection):
+
+ async_connection = AsyncConnection._retrieve_proxy_for_target(
+ connection
+ )
+
+ a2 = AsyncConnection._retrieve_proxy_for_target(connection)
+ is_(async_connection, a2)
+ is_not(async_connection, None)
+
+ is_(async_connection.engine, a2.engine)
+ is_not(async_connection.engine, None)
+
+ @testing.requires.predictable_gc
+ async def test_gc_engine(self, testing_engine):
+ ReversibleProxy._proxy_objects.clear()
+
+ eq_(len(ReversibleProxy._proxy_objects), 0)
+
+ async_engine = AsyncEngine(testing.db)
+
+ eq_(len(ReversibleProxy._proxy_objects), 1)
+
+ del async_engine
+
+ eq_(len(ReversibleProxy._proxy_objects), 0)
+
+ @testing.requires.predictable_gc
+ @async_test
+ async def test_gc_conn(self, testing_engine):
+ ReversibleProxy._proxy_objects.clear()
+
+ async_engine = AsyncEngine(testing.db)
+
+ eq_(len(ReversibleProxy._proxy_objects), 1)
+
+ async with async_engine.connect() as conn:
+ eq_(len(ReversibleProxy._proxy_objects), 2)
+
+ async with conn.begin() as trans:
+ eq_(len(ReversibleProxy._proxy_objects), 3)
+
+ del trans
+
+ del conn
+
+ eq_(len(ReversibleProxy._proxy_objects), 1)
+
+ del async_engine
+
+ eq_(len(ReversibleProxy._proxy_objects), 0)
+
+ def test_regen_conn_but_not_engine(self, async_engine):
+
+ sync_conn = async_engine.sync_engine.connect()
+
+ async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
+ async_conn2 = AsyncConnection._retrieve_proxy_for_target(sync_conn)
+
+ is_(async_conn, async_conn2)
+ is_(async_conn.engine, async_engine)
+
+ def test_regen_trans_but_not_conn(self, async_engine):
+ sync_conn = async_engine.sync_engine.connect()
+
+ async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
+
+ trans = sync_conn.begin()
+
+ async_t1 = async_conn.get_transaction()
+
+ is_(async_t1.connection, async_conn)
+ is_(async_t1.sync_transaction, trans)
+
+ async_t2 = async_conn.get_transaction()
+ is_(async_t1, async_t2)
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
+from sqlalchemy import inspect
from sqlalchemy import select
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import update
+from sqlalchemy.ext.asyncio import async_object_session
from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio.base import ReversibleProxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import sessionmaker
canary.mock_calls,
[mock.call(async_session.sync_session)],
)
+
+
+class AsyncProxyTest(AsyncFixture):
+ @async_test
+ async def test_get_connection_engine_bound(self, async_session):
+ c1 = await async_session.connection()
+
+ c2 = await async_session.connection()
+
+ is_(c1, c2)
+ is_(c1.engine, c2.engine)
+
+ @async_test
+ async def test_get_connection_connection_bound(self, async_engine):
+ async with async_engine.begin() as conn:
+ async_session = AsyncSession(conn)
+
+ c1 = await async_session.connection()
+
+ is_(c1, conn)
+ is_(c1.engine, conn.engine)
+
+ @async_test
+ async def test_get_transaction(self, async_session):
+
+ is_(async_session.get_transaction(), None)
+ is_(async_session.get_nested_transaction(), None)
+
+ t1 = await async_session.begin()
+
+ is_(async_session.get_transaction(), t1)
+ is_(async_session.get_nested_transaction(), None)
+
+ n1 = await async_session.begin_nested()
+
+ is_(async_session.get_transaction(), t1)
+ is_(async_session.get_nested_transaction(), n1)
+
+ await n1.commit()
+
+ is_(async_session.get_transaction(), t1)
+ is_(async_session.get_nested_transaction(), None)
+
+ await t1.commit()
+
+ is_(async_session.get_transaction(), None)
+ is_(async_session.get_nested_transaction(), None)
+
+ @async_test
+ async def test_async_object_session(self, async_engine):
+ User = self.classes.User
+
+ s1 = AsyncSession(async_engine)
+
+ s2 = AsyncSession(async_engine)
+
+ u1 = await s1.get(User, 7)
+
+ u2 = User(name="n1")
+
+ s2.add(u2)
+
+ u3 = User(name="n2")
+
+ is_(async_object_session(u1), s1)
+ is_(async_object_session(u2), s2)
+
+ is_(async_object_session(u3), None)
+
+ await s2.close()
+ is_(async_object_session(u2), None)
+
+ @async_test
+ async def test_async_object_session_custom(self, async_engine):
+ User = self.classes.User
+
+ class MyCustomAsync(AsyncSession):
+ pass
+
+ s1 = MyCustomAsync(async_engine)
+
+ u1 = await s1.get(User, 7)
+
+ assert isinstance(async_object_session(u1), MyCustomAsync)
+
+ @testing.requires.predictable_gc
+ @async_test
+ async def test_async_object_session_del(self, async_engine):
+ User = self.classes.User
+
+ s1 = AsyncSession(async_engine)
+
+ u1 = await s1.get(User, 7)
+
+ is_(async_object_session(u1), s1)
+
+ await s1.rollback()
+ del s1
+ is_(async_object_session(u1), None)
+
+ @async_test
+ async def test_inspect_session(self, async_engine):
+ User = self.classes.User
+
+ s1 = AsyncSession(async_engine)
+
+ s2 = AsyncSession(async_engine)
+
+ u1 = await s1.get(User, 7)
+
+ u2 = User(name="n1")
+
+ s2.add(u2)
+
+ u3 = User(name="n2")
+
+ is_(inspect(u1).async_session, s1)
+ is_(inspect(u2).async_session, s2)
+
+ is_(inspect(u3).async_session, None)
+
+ def test_inspect_session_no_asyncio_used(self):
+ from sqlalchemy.orm import Session
+
+ User = self.classes.User
+
+ s1 = Session(testing.db)
+ u1 = s1.get(User, 7)
+
+ is_(inspect(u1).async_session, None)
+
+ def test_inspect_session_no_asyncio_imported(self):
+ from sqlalchemy.orm import Session
+
+ with mock.patch("sqlalchemy.orm.state._async_provider", None):
+
+ User = self.classes.User
+
+ s1 = Session(testing.db)
+ u1 = s1.get(User, 7)
+
+ is_(inspect(u1).async_session, None)
+
+ @testing.requires.predictable_gc
+ def test_gc(self, async_engine):
+ ReversibleProxy._proxy_objects.clear()
+
+ eq_(len(ReversibleProxy._proxy_objects), 0)
+
+ async_session = AsyncSession(async_engine)
+
+ eq_(len(ReversibleProxy._proxy_objects), 1)
+
+ del async_session
+
+ eq_(len(ReversibleProxy._proxy_objects), 0)