From af0824fd790bad28beb01c11f262ac1ffe8c53be Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Fri, 27 Aug 2021 22:45:56 +0200 Subject: [PATCH] Allow custom sync session class in ``AsyncSession``. The :class:`_asyncio.AsyncSession` now supports overriding which :class:`_orm.Session` it uses as the proxied instance. A custom ``Session`` class can be passed using the :paramref:`.AsyncSession.sync_session_class` parameter or by subclassing the ``AsyncSession`` and specifying a custom :attr:`.AsyncSession.sync_session_class`. Fixes: #6689 Change-Id: Idf9c24eae6c9f4e2fff292ed748feaa449a8deaa --- doc/build/changelog/unreleased_14/6689.rst | 9 ++++ doc/build/orm/extensions/asyncio.rst | 3 ++ lib/sqlalchemy/ext/asyncio/session.py | 52 ++++++++++++++++++++-- test/ext/asyncio/test_session_py3k.py | 50 +++++++++++++++++++-- 4 files changed, 107 insertions(+), 7 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6689.rst diff --git a/doc/build/changelog/unreleased_14/6689.rst b/doc/build/changelog/unreleased_14/6689.rst new file mode 100644 index 0000000000..6abebc5f3c --- /dev/null +++ b/doc/build/changelog/unreleased_14/6689.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: asyncio, usecase + :tickets: 6746 + + The :class:`_asyncio.AsyncSession` now supports overriding which + :class:`_orm.Session` it uses as the proxied instance. A custom ``Session`` + class can be passed using the :paramref:`.AsyncSession.sync_session_class` + parameter or by subclassing the ``AsyncSession`` and specifying a custom + :attr:`.AsyncSession.sync_session_class`. diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index c5fc356d12..940c19a7e0 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -581,6 +581,9 @@ ORM Session API Documentation .. autoclass:: AsyncSession :members: + :exclude-members: sync_session_class + + .. autoattribute:: sync_session_class .. autoclass:: AsyncSessionTransaction :members: diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 5c6e7f5a7c..5c5426d720 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -51,9 +51,16 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True}) class AsyncSession(ReversibleProxy): """Asyncio version of :class:`_orm.Session`. + The :class:`_asyncio.AsyncSession` is a proxy for a traditional + :class:`_orm.Session` instance. .. versionadded:: 1.4 + To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session` + implementations, see the + :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. + + """ _is_asyncio = True @@ -68,7 +75,25 @@ class AsyncSession(ReversibleProxy): dispatch = None - def __init__(self, bind=None, binds=None, **kw): + def __init__(self, bind=None, binds=None, sync_session_class=None, **kw): + r"""Construct a new :class:`_asyncio.AsyncSession`. + + All parameters other than ``sync_session_class`` are passed to the + ``sync_session_class`` callable directly to instantiate a new + :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for + parameter documentation. + + :param sync_session_class: + A :class:`_orm.Session` subclass or other callable which will be used + to construct the :class:`_orm.Session` which will be proxied. This + parameter may be used to provide custom :class:`_orm.Session` + subclasses. Defaults to the + :attr:`_asyncio.AsyncSession.sync_session_class` class-level + attribute. + + .. versionadded:: 1.4.24 + + """ kw["future"] = True if bind: self.bind = bind @@ -81,10 +106,30 @@ class AsyncSession(ReversibleProxy): for key, b in binds.items() } + if sync_session_class: + self.sync_session_class = sync_session_class + self.sync_session = self._proxied = self._assign_proxied( - Session(bind=bind, binds=binds, **kw) + self.sync_session_class(bind=bind, binds=binds, **kw) ) + sync_session_class = Session + """The class or callable that provides the + underlying :class:`_orm.Session` instance for a particular + :class:`_asyncio.AsyncSession`. + + At the class level, this attribute is the default value for the + :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom + subclasses of :class:`_asyncio.AsyncSession` can override this. + + At the instance level, this attribute indicates the current class or + callable that was used to provide the :class:`_orm.Session` instance for + this :class:`_asyncio.AsyncSession` instance. + + .. versionadded:: 1.4.24 + + """ + async def refresh( self, instance, attribute_names=None, with_for_update=None ): @@ -141,7 +186,8 @@ class AsyncSession(ReversibleProxy): **kw ): """Execute a statement and return a buffered - :class:`_engine.Result` object.""" + :class:`_engine.Result` object. + """ if execution_options: execution_options = util.immutabledict(execution_options).union( diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index ebedfedbfb..48faa1ca1e 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -14,11 +14,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio.base import ReversibleProxy from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload +from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import async_test from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from .test_engine_py3k import AsyncFixture as _AsyncFixture from ...orm import _fixtures @@ -724,8 +726,6 @@ class AsyncProxyTest(AsyncFixture): is_(inspect(u3).async_session, None) def test_inspect_session_no_asyncio_used(self): - from sqlalchemy.orm import Session - User = self.classes.User s1 = Session(testing.db) @@ -734,8 +734,6 @@ class AsyncProxyTest(AsyncFixture): is_(inspect(u1).async_session, None) def test_inspect_session_no_asyncio_imported(self): - from sqlalchemy.orm import Session - with mock.patch("sqlalchemy.orm.state._async_provider", None): User = self.classes.User @@ -758,3 +756,47 @@ class AsyncProxyTest(AsyncFixture): del async_session eq_(len(ReversibleProxy._proxy_objects), 0) + + +class _MySession(Session): + pass + + +class _MyAS(AsyncSession): + sync_session_class = _MySession + + +class OverrideSyncSession(AsyncFixture): + def test_default(self, async_engine): + ass = AsyncSession(async_engine) + + is_true(isinstance(ass.sync_session, Session)) + is_(ass.sync_session.__class__, Session) + is_(ass.sync_session_class, Session) + + def test_init_class(self, async_engine): + ass = AsyncSession(async_engine, sync_session_class=_MySession) + + is_true(isinstance(ass.sync_session, _MySession)) + is_(ass.sync_session_class, _MySession) + + def test_init_sessionmaker(self, async_engine): + sm = sessionmaker( + async_engine, class_=AsyncSession, sync_session_class=_MySession + ) + ass = sm() + + is_true(isinstance(ass.sync_session, _MySession)) + is_(ass.sync_session_class, _MySession) + + def test_subclass(self, async_engine): + ass = _MyAS(async_engine) + + is_true(isinstance(ass.sync_session, _MySession)) + is_(ass.sync_session_class, _MySession) + + def test_subclass_override(self, async_engine): + ass = _MyAS(async_engine, sync_session_class=Session) + + is_true(not isinstance(ass.sync_session, _MySession)) + is_(ass.sync_session_class, Session) -- 2.47.2