]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow custom sync session class in ``AsyncSession``.
authorFederico Caselli <cfederico87@gmail.com>
Fri, 27 Aug 2021 20:45:56 +0000 (22:45 +0200)
committermike bayer <mike_mp@zzzcomputing.com>
Mon, 30 Aug 2021 14:52:37 +0000 (14:52 +0000)
The :class:`_asyncio.AsyncSession` now supports overriding which
:class:`_orm.Session` it uses as the proxied instance. A custom ``Session``
class can be passed using the :paramref:`.AsyncSession.sync_session_class`
parameter or by subclassing the ``AsyncSession`` and specifying a custom
:attr:`.AsyncSession.sync_session_class`.

Fixes: #6689
Change-Id: Idf9c24eae6c9f4e2fff292ed748feaa449a8deaa

doc/build/changelog/unreleased_14/6689.rst [new file with mode: 0644]
doc/build/orm/extensions/asyncio.rst
lib/sqlalchemy/ext/asyncio/session.py
test/ext/asyncio/test_session_py3k.py

diff --git a/doc/build/changelog/unreleased_14/6689.rst b/doc/build/changelog/unreleased_14/6689.rst
new file mode 100644 (file)
index 0000000..6abebc5
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: asyncio, usecase
+    :tickets: 6746
+
+    The :class:`_asyncio.AsyncSession` now supports overriding which
+    :class:`_orm.Session` it uses as the proxied instance. A custom ``Session``
+    class can be passed using the :paramref:`.AsyncSession.sync_session_class`
+    parameter or by subclassing the ``AsyncSession`` and specifying a custom
+    :attr:`.AsyncSession.sync_session_class`.
index c5fc356d1205beda1c7162a3a0385559633944ce..940c19a7e0becb7893dfab03989b4c970f78ee25 100644 (file)
@@ -581,6 +581,9 @@ ORM Session API Documentation
 
 .. autoclass:: AsyncSession
    :members:
+   :exclude-members: sync_session_class
+
+   .. autoattribute:: sync_session_class
 
 .. autoclass:: AsyncSessionTransaction
    :members:
index 5c6e7f5a7c23cedfacc046862c62d2d554591811..5c5426d7201f11f418299396266a05076dbe6732 100644 (file)
@@ -51,9 +51,16 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
 class AsyncSession(ReversibleProxy):
     """Asyncio version of :class:`_orm.Session`.
 
+    The :class:`_asyncio.AsyncSession` is a proxy for a traditional
+    :class:`_orm.Session` instance.
 
     .. versionadded:: 1.4
 
+    To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
+    implementations, see the
+    :paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
+
+
     """
 
     _is_asyncio = True
@@ -68,7 +75,25 @@ class AsyncSession(ReversibleProxy):
 
     dispatch = None
 
-    def __init__(self, bind=None, binds=None, **kw):
+    def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+        r"""Construct a new :class:`_asyncio.AsyncSession`.
+
+        All parameters other than ``sync_session_class`` are passed to the
+        ``sync_session_class`` callable directly to instantiate a new
+        :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
+        parameter documentation.
+
+        :param sync_session_class:
+          A :class:`_orm.Session` subclass or other callable which will be used
+          to construct the :class:`_orm.Session` which will be proxied. This
+          parameter may be used to provide custom :class:`_orm.Session`
+          subclasses. Defaults to the
+          :attr:`_asyncio.AsyncSession.sync_session_class` class-level
+          attribute.
+
+          .. versionadded:: 1.4.24
+
+        """
         kw["future"] = True
         if bind:
             self.bind = bind
@@ -81,10 +106,30 @@ class AsyncSession(ReversibleProxy):
                 for key, b in binds.items()
             }
 
+        if sync_session_class:
+            self.sync_session_class = sync_session_class
+
         self.sync_session = self._proxied = self._assign_proxied(
-            Session(bind=bind, binds=binds, **kw)
+            self.sync_session_class(bind=bind, binds=binds, **kw)
         )
 
+    sync_session_class = Session
+    """The class or callable that provides the
+    underlying :class:`_orm.Session` instance for a particular
+    :class:`_asyncio.AsyncSession`.
+
+    At the class level, this attribute is the default value for the
+    :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
+    subclasses of :class:`_asyncio.AsyncSession` can override this.
+
+    At the instance level, this attribute indicates the current class or
+    callable that was used to provide the :class:`_orm.Session` instance for
+    this :class:`_asyncio.AsyncSession` instance.
+
+    .. versionadded:: 1.4.24
+
+    """
+
     async def refresh(
         self, instance, attribute_names=None, with_for_update=None
     ):
@@ -141,7 +186,8 @@ class AsyncSession(ReversibleProxy):
         **kw
     ):
         """Execute a statement and return a buffered
-        :class:`_engine.Result` object."""
+        :class:`_engine.Result` object.
+        """
 
         if execution_options:
             execution_options = util.immutabledict(execution_options).union(
index ebedfedbfba0926a0ab3607d54bc0669c0e50f4d..48faa1ca1e2f9b7decbf8e6a140494ddeedde1a3 100644 (file)
@@ -14,11 +14,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.asyncio.base import ReversibleProxy
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
+from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from .test_engine_py3k import AsyncFixture as _AsyncFixture
 from ...orm import _fixtures
@@ -724,8 +726,6 @@ class AsyncProxyTest(AsyncFixture):
         is_(inspect(u3).async_session, None)
 
     def test_inspect_session_no_asyncio_used(self):
-        from sqlalchemy.orm import Session
-
         User = self.classes.User
 
         s1 = Session(testing.db)
@@ -734,8 +734,6 @@ class AsyncProxyTest(AsyncFixture):
         is_(inspect(u1).async_session, None)
 
     def test_inspect_session_no_asyncio_imported(self):
-        from sqlalchemy.orm import Session
-
         with mock.patch("sqlalchemy.orm.state._async_provider", None):
 
             User = self.classes.User
@@ -758,3 +756,47 @@ class AsyncProxyTest(AsyncFixture):
         del async_session
 
         eq_(len(ReversibleProxy._proxy_objects), 0)
+
+
+class _MySession(Session):
+    pass
+
+
+class _MyAS(AsyncSession):
+    sync_session_class = _MySession
+
+
+class OverrideSyncSession(AsyncFixture):
+    def test_default(self, async_engine):
+        ass = AsyncSession(async_engine)
+
+        is_true(isinstance(ass.sync_session, Session))
+        is_(ass.sync_session.__class__, Session)
+        is_(ass.sync_session_class, Session)
+
+    def test_init_class(self, async_engine):
+        ass = AsyncSession(async_engine, sync_session_class=_MySession)
+
+        is_true(isinstance(ass.sync_session, _MySession))
+        is_(ass.sync_session_class, _MySession)
+
+    def test_init_sessionmaker(self, async_engine):
+        sm = sessionmaker(
+            async_engine, class_=AsyncSession, sync_session_class=_MySession
+        )
+        ass = sm()
+
+        is_true(isinstance(ass.sync_session, _MySession))
+        is_(ass.sync_session_class, _MySession)
+
+    def test_subclass(self, async_engine):
+        ass = _MyAS(async_engine)
+
+        is_true(isinstance(ass.sync_session, _MySession))
+        is_(ass.sync_session_class, _MySession)
+
+    def test_subclass_override(self, async_engine):
+        ass = _MyAS(async_engine, sync_session_class=Session)
+
+        is_true(not isinstance(ass.sync_session, _MySession))
+        is_(ass.sync_session_class, Session)