]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement connection binding for AsyncSession
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 7 Jan 2021 03:56:14 +0000 (22:56 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 7 Jan 2021 20:59:59 +0000 (15:59 -0500)
Implemented "connection-binding" for :class:`.AsyncSession`, the ability to
pass an :class:`.AsyncConnection` to create an :class:`.AsyncSession`.
Previously, this use case was not implemented and would use the associated
engine when the connection were passed.  This fixes the issue where the
"join a session to an external transaction" use case would not work
correctly for the :class:`.AsyncSession`.  Additionally, added methods
:meth:`.AsyncConnection.in_transaction`,
:meth:`.AsyncConnection.in_nested_transaction`,
:meth:`.AsyncConnection.get_transaction`.

The :class:`.AsyncEngine`, :class:`.AsyncConnection` and
:class:`.AsyncTransaction` objects may be compared using Python ``==`` or
``!=``, which will compare the two given objects based on the "sync" object
they are proxying towards. This is useful as there are cases particularly
for :class:`.AsyncTransaction` where multiple instances of
:class:`.AsyncTransaction` can be proxying towards the same sync
:class:`_engine.Transaction`, and are actually equivalent.   The
:meth:`.AsyncConnection.get_transaction` method will currently return a new
proxying :class:`.AsyncTransaction` each time as the
:class:`.AsyncTransaction` is not otherwise statefully associated with its
originating :class:`.AsyncConnection`.

Fixes: #5811
Change-Id: I5a3a6b2f088541eee7b0e0f393510e61bc9f986b

doc/build/changelog/unreleased_14/5811.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/base.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py

diff --git a/doc/build/changelog/unreleased_14/5811.rst b/doc/build/changelog/unreleased_14/5811.rst
new file mode 100644 (file)
index 0000000..5ce358c
--- /dev/null
@@ -0,0 +1,30 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 5811
+
+    Implemented "connection-binding" for :class:`.AsyncSession`, the ability to
+    pass an :class:`.AsyncConnection` to create an :class:`.AsyncSession`.
+    Previously, this use case was not implemented and would use the associated
+    engine when the connection were passed.  This fixes the issue where the
+    "join a session to an external transaction" use case would not work
+    correctly for the :class:`.AsyncSession`.  Additionally, added methods
+    :meth:`.AsyncConnection.in_transaction`,
+    :meth:`.AsyncConnection.in_nested_transaction`,
+    :meth:`.AsyncConnection.get_transaction`,
+    :meth:`.AsyncConnection.get_nested_transaction` and
+    :attr:`.AsyncConnection.info` attribute.
+
+.. change::
+    :tags: usecase, asyncio
+
+    The :class:`.AsyncEngine`, :class:`.AsyncConnection` and
+    :class:`.AsyncTransaction` objects may be compared using Python ``==`` or
+    ``!=``, which will compare the two given objects based on the "sync" object
+    they are proxying towards. This is useful as there are cases particularly
+    for :class:`.AsyncTransaction` where multiple instances of
+    :class:`.AsyncTransaction` can be proxying towards the same sync
+    :class:`_engine.Transaction`, and are actually equivalent.   The
+    :meth:`.AsyncConnection.get_transaction` method will currently return a new
+    proxying :class:`.AsyncTransaction` each time as the
+    :class:`.AsyncTransaction` is not otherwise statefully associated with its
+    originating :class:`.AsyncConnection`.
\ No newline at end of file
index 051f9e21a1f420101115302ad07e0edae7742bb6..fa8c5006ee26caceda36611ed65ea8bc347df9f8 100644 (file)
@@ -23,3 +23,20 @@ class StartableContext(abc.ABC):
             "%s context has not been started and object has not been awaited."
             % (self.__class__.__name__)
         )
+
+
+class ProxyComparable:
+    def __hash__(self):
+        return id(self)
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, self.__class__)
+            and self._proxied == other._proxied
+        )
+
+    def __ne__(self, other):
+        return (
+            not isinstance(other, self.__class__)
+            or self._proxied != other._proxied
+        )
index 93adaf78ab0643000b4b09b3d381b9ee39009e7d..5951abc1e997ab9b7b290b48b6b2ef65f1c38fba 100644 (file)
@@ -4,6 +4,7 @@ from typing import Mapping
 from typing import Optional
 
 from . import exc as async_exc
+from .base import ProxyComparable
 from .base import StartableContext
 from .result import AsyncResult
 from ... import exc
@@ -57,7 +58,7 @@ class AsyncConnectable:
         "default_isolation_level",
     ],
 )
-class AsyncConnection(StartableContext, AsyncConnectable):
+class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
     """An asyncio proxy for a :class:`_engine.Connection`.
 
     :class:`_asyncio.AsyncConnection` is acquired using the
@@ -131,6 +132,24 @@ class AsyncConnection(StartableContext, AsyncConnectable):
     def _proxied(self):
         return self.sync_connection
 
+    @property
+    def info(self):
+        """Return the :attr:`_engine.Connection.info` dictionary of the
+        underlying :class:`_engine.Connection`.
+
+        This dictionary is freely writable for user-defined state to be
+        associated with the database connection.
+
+        This attribute is only available if the :class:`.AsyncConnection` is
+        currently connected.   If the :attr:`.AsyncConnection.closed` attribute
+        is ``True``, then accessing this attribute will raise
+        :class:`.ResourceClosedError`.
+
+        .. versionadded:: 1.4.0b2
+
+        """
+        return self.sync_connection.info
+
     def _sync_connection(self):
         if not self.sync_connection:
             self._raise_for_not_started()
@@ -166,6 +185,69 @@ class AsyncConnection(StartableContext, AsyncConnectable):
         conn = self._sync_connection()
         return await greenlet_spawn(conn.get_isolation_level)
 
+    def in_transaction(self):
+        """Return True if a transaction is in progress.
+
+        .. versionadded:: 1.4.0b2
+
+        """
+
+        conn = self._sync_connection()
+
+        return conn.in_transaction()
+
+    def in_nested_transaction(self):
+        """Return True if a transaction is in progress.
+
+        .. versionadded:: 1.4.0b2
+
+        """
+        conn = self._sync_connection()
+
+        return conn.in_nested_transaction()
+
+    def get_transaction(self):
+        """Return an :class:`.AsyncTransaction` representing the current
+        transaction, if any.
+
+        This makes use of the underlying synchronous connection's
+        :meth:`_engine.Connection.get_transaction` method to get the current
+        :class:`_engine.Transaction`, which is then proxied in a new
+        :class:`.AsyncTransaction` object.
+
+        .. versionadded:: 1.4.0b2
+
+        """
+        conn = self._sync_connection()
+
+        trans = conn.get_transaction()
+        if trans is not None:
+            return AsyncTransaction._from_existing_transaction(self, trans)
+        else:
+            return None
+
+    def get_nested_transaction(self):
+        """Return an :class:`.AsyncTransaction` representing the current
+        nested (savepoint) transaction, if any.
+
+        This makes use of the underlying synchronous connection's
+        :meth:`_engine.Connection.get_nested_transaction` method to get the
+        current :class:`_engine.Transaction`, which is then proxied in a new
+        :class:`.AsyncTransaction` object.
+
+        .. versionadded:: 1.4.0b2
+
+        """
+        conn = self._sync_connection()
+
+        trans = conn.get_nested_transaction()
+        if trans is not None:
+            return AsyncTransaction._from_existing_transaction(
+                self, trans, True
+            )
+        else:
+            return None
+
     async def execution_options(self, **opt):
         r"""Set non-SQL options for the connection which take effect
         during execution.
@@ -391,7 +473,7 @@ class AsyncConnection(StartableContext, AsyncConnectable):
     ],
     attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
 )
-class AsyncEngine(AsyncConnectable):
+class AsyncEngine(ProxyComparable, AsyncConnectable):
     """An asyncio proxy for a :class:`_engine.Engine`.
 
     :class:`_asyncio.AsyncEngine` is acquired using the
@@ -513,7 +595,7 @@ class AsyncEngine(AsyncConnectable):
         return await greenlet_spawn(self.sync_engine.dispose)
 
 
-class AsyncTransaction(StartableContext):
+class AsyncTransaction(ProxyComparable, StartableContext):
     """An asyncio proxy for a :class:`_engine.Transaction`."""
 
     __slots__ = ("connection", "sync_transaction", "nested")
@@ -523,11 +605,28 @@ class AsyncTransaction(StartableContext):
         self.sync_transaction: Optional[Transaction] = None
         self.nested = nested
 
+    @classmethod
+    def _from_existing_transaction(
+        cls,
+        connection: AsyncConnection,
+        sync_transaction: Transaction,
+        nested: bool = False,
+    ):
+        obj = cls.__new__(cls)
+        obj.connection = connection
+        obj.sync_transaction = sync_transaction
+        obj.nested = nested
+        return obj
+
     def _sync_transaction(self):
         if not self.sync_transaction:
             self._raise_for_not_started()
         return self.sync_transaction
 
+    @property
+    def _proxied(self):
+        return self.sync_transaction
+
     @property
     def is_valid(self) -> bool:
         return self._sync_transaction().is_valid
@@ -582,7 +681,10 @@ class AsyncTransaction(StartableContext):
             await self.rollback()
 
 
-def _get_sync_engine(async_engine):
+def _get_sync_engine_or_connection(async_engine):
+    if isinstance(async_engine, AsyncConnection):
+        return async_engine.sync_connection
+
     try:
         return async_engine.sync_engine
     except AttributeError as e:
index bac2aa44b760054c3d71dbf64c554ff62928676b..9a8284e64928382f29339b6abf24f6ea319b59d0 100644 (file)
@@ -75,12 +75,13 @@ class AsyncSession:
         kw["future"] = True
         if bind:
             self.bind = engine
-            bind = engine._get_sync_engine(bind)
+            bind = engine._get_sync_engine_or_connection(bind)
 
         if binds:
             self.binds = binds
             binds = {
-                key: engine._get_sync_engine(b) for key, b in binds.items()
+                key: engine._get_sync_engine_or_connection(b)
+                for key, b in binds.items()
             }
 
         self.sync_session = self._proxied = Session(
index 191252bfbb22e7a4e752a57d4a984053f4dadfde..c1afeb90777696203b58b512c6e9c4b33d1de9f9 100644 (file)
@@ -29,6 +29,7 @@ from .assertions import in_  # noqa
 from .assertions import is_  # noqa
 from .assertions import is_false  # noqa
 from .assertions import is_instance_of  # noqa
+from .assertions import is_none  # noqa
 from .assertions import is_not  # noqa
 from .assertions import is_not_  # noqa
 from .assertions import is_true  # noqa
index 0a2aed9d85f6d9a49c13de3282e475d2fa2b4563..b2a4ac66e9c071e99f8b8dacd24246b5d8bb6abb 100644 (file)
@@ -224,6 +224,10 @@ def is_instance_of(a, b, msg=None):
     assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)
 
 
+def is_none(a, msg=None):
+    is_(a, None, msg=msg)
+
+
 def is_true(a, msg=None):
     is_(bool(a), True, msg=msg)
 
index 7dae1411e542ae1884ac76cc12ac85c8bd277850..49bf20baf78c7ba1f9fffd1ad3207930eb6bf0e4 100644 (file)
@@ -23,8 +23,12 @@ from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_none
 from sqlalchemy.testing import is_not
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
+from sqlalchemy.testing import ne_
 from sqlalchemy.util.concurrency import greenlet_spawn
 
 
@@ -72,6 +76,53 @@ class AsyncEngineTest(EngineFixture):
         eq_(async_engine.driver, sync_engine.driver)
         eq_(async_engine.echo, sync_engine.echo)
 
+    @async_test
+    async def test_engine_eq_ne(self, async_engine):
+        e2 = _async_engine.AsyncEngine(async_engine.sync_engine)
+        e3 = testing.engines.testing_engine(asyncio=True)
+
+        eq_(async_engine, e2)
+        ne_(async_engine, e3)
+
+        is_false(async_engine == None)
+
+    @async_test
+    async def test_connection_info(self, async_engine):
+
+        async with async_engine.connect() as conn:
+            conn.info["foo"] = "bar"
+
+            eq_(conn.sync_connection.info, {"foo": "bar"})
+
+    @async_test
+    async def test_connection_eq_ne(self, async_engine):
+
+        async with async_engine.connect() as conn:
+            c2 = _async_engine.AsyncConnection(
+                async_engine, conn.sync_connection
+            )
+
+            eq_(conn, c2)
+
+            async with async_engine.connect() as c3:
+                ne_(conn, c3)
+
+            is_false(conn == None)
+
+    @async_test
+    async def test_transaction_eq_ne(self, async_engine):
+
+        async with async_engine.connect() as conn:
+            t1 = await conn.begin()
+
+            t2 = _async_engine.AsyncTransaction._from_existing_transaction(
+                conn, t1._proxied
+            )
+
+            eq_(t1, t2)
+
+            is_false(t1 == None)
+
     def test_clear_compiled_cache(self, async_engine):
         async_engine.sync_engine._compiled_cache["foo"] = "bar"
         eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
@@ -103,6 +154,48 @@ class AsyncEngineTest(EngineFixture):
         is_(conn.dialect, async_engine.sync_engine.dialect)
         eq_(conn.default_isolation_level, sync_conn.default_isolation_level)
 
+    @async_test
+    async def test_transaction_accessor(self, async_engine):
+        async with async_engine.connect() as conn:
+            is_none(conn.get_transaction())
+            is_false(conn.in_transaction())
+            is_false(conn.in_nested_transaction())
+
+            trans = await conn.begin()
+
+            is_true(conn.in_transaction())
+            is_false(conn.in_nested_transaction())
+
+            is_(
+                trans.sync_transaction, conn.get_transaction().sync_transaction
+            )
+
+            nested = await conn.begin_nested()
+
+            is_true(conn.in_transaction())
+            is_true(conn.in_nested_transaction())
+
+            is_(
+                conn.get_nested_transaction().sync_transaction,
+                nested.sync_transaction,
+            )
+            eq_(conn.get_nested_transaction(), nested)
+
+            is_(
+                trans.sync_transaction, conn.get_transaction().sync_transaction
+            )
+
+            await nested.commit()
+
+            is_true(conn.in_transaction())
+            is_false(conn.in_nested_transaction())
+
+            await trans.rollback()
+
+            is_none(conn.get_transaction())
+            is_false(conn.in_transaction())
+            is_false(conn.in_nested_transaction())
+
     @async_test
     async def test_invalidate(self, async_engine):
         conn = await async_engine.connect()
index dbe84e82c3eef9023681ea3757f1a214c3fd0c0e..e56adec4d3a55b39dae51a23ebf7024a74977999 100644 (file)
@@ -40,6 +40,11 @@ class AsyncSessionTest(AsyncFixture):
             bind=async_engine.sync_engine,
         )
 
+    def test_info(self, async_session):
+        async_session.info["foo"] = "bar"
+
+        eq_(async_session.sync_session.info, {"foo": "bar"})
+
 
 class AsyncSessionQueryTest(AsyncFixture):
     @async_test
@@ -297,6 +302,107 @@ class AsyncSessionTransactionTest(AsyncFixture):
             is_(new_u_merged, u1)
             eq_(u1.name, "new u1")
 
+    @async_test
+    async def test_join_to_external_transaction(self, async_engine):
+        User = self.classes.User
+
+        async with async_engine.connect() as conn:
+            t1 = await conn.begin()
+
+            async_session = AsyncSession(conn)
+
+            aconn = await async_session.connection()
+
+            eq_(aconn.get_transaction(), t1)
+
+            eq_(aconn, conn)
+            is_(aconn.sync_connection, conn.sync_connection)
+
+            u1 = User(id=1, name="u1")
+
+            async_session.add(u1)
+
+            await async_session.commit()
+
+            assert conn.in_transaction()
+            await conn.rollback()
+
+        async with AsyncSession(async_engine) as async_session:
+            result = await async_session.execute(select(User))
+            eq_(result.all(), [])
+
+    @testing.requires.savepoints
+    @async_test
+    async def test_join_to_external_transaction_with_savepoints(
+        self, async_engine
+    ):
+        """This is the full 'join to an external transaction' recipe
+        implemented for async using savepoints.
+
+        It's not particularly simple to understand as we have to switch between
+        async / sync APIs but it works and it's a start.
+
+        """
+
+        User = self.classes.User
+
+        async with async_engine.connect() as conn:
+
+            await conn.begin()
+
+            await conn.begin_nested()
+
+            async_session = AsyncSession(conn)
+
+            @event.listens_for(
+                async_session.sync_session, "after_transaction_end"
+            )
+            def end_savepoint(session, transaction):
+                """here's an event.  inside the event we write blocking
+                style code.    wow will this be fun to try to explain :)
+
+                """
+
+                if conn.closed:
+                    return
+
+                if not conn.in_nested_transaction():
+                    conn.sync_connection.begin_nested()
+
+            aconn = await async_session.connection()
+            is_(aconn.sync_connection, conn.sync_connection)
+
+            u1 = User(id=1, name="u1")
+
+            async_session.add(u1)
+
+            await async_session.commit()
+
+            result = (await async_session.execute(select(User))).all()
+            eq_(len(result), 1)
+
+            u2 = User(id=2, name="u2")
+            async_session.add(u2)
+
+            await async_session.flush()
+
+            result = (await async_session.execute(select(User))).all()
+            eq_(len(result), 2)
+
+            # a rollback inside the session ultimately ends the savepoint
+            await async_session.rollback()
+
+            # but the previous thing we "committed" is still in the DB
+            result = (await async_session.execute(select(User))).all()
+            eq_(len(result), 1)
+
+            assert conn.in_transaction()
+            await conn.rollback()
+
+        async with AsyncSession(async_engine) as async_session:
+            result = await async_session.execute(select(User))
+            eq_(result.all(), [])
+
 
 class AsyncEventTest(AsyncFixture):
     """The engine events all run in their normal synchronous context.