]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add shield() in aexit
authorFederico Caselli <cfederico87@gmail.com>
Fri, 17 Jun 2022 21:12:39 +0000 (23:12 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Jul 2022 15:33:34 +0000 (11:33 -0400)
Added ``asyncio.shield()`` to the connection and session release process
specifically within the ``__aexit__()`` context manager exit, when using
:class:`.AsyncConnection` or :class:`.AsyncSession` as a context manager
that releases the object when the context manager is complete. This appears
to help with task cancellation when using alternate concurrency libraries
such as ``anyio``, ``uvloop`` that otherwise don't provide an async context
for the connection pool to release the connection properly during task
cancellation.

Fixes: #8145
Change-Id: I0b1ea9c3a22a18619341cbb8591225fcd339042c

doc/build/changelog/unreleased_14/8145.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py

diff --git a/doc/build/changelog/unreleased_14/8145.rst b/doc/build/changelog/unreleased_14/8145.rst
new file mode 100644 (file)
index 0000000..4cd6c12
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 8145
+
+    Added ``asyncio.shield()`` to the connection and session release process
+    specifically within the ``__aexit__()`` context manager exit, when using
+    :class:`.AsyncConnection` or :class:`.AsyncSession` as a context manager
+    that releases the object when the context manager is complete. This appears
+    to help with task cancellation when using alternate concurrency libraries
+    such as ``anyio``, ``uvloop`` that otherwise don't provide an async context
+    for the connection pool to release the connection properly during task
+    cancellation.
+
+
index 97d69fcbd29a1eaed4d734052c88af58ec4e609d..796ddccf5a11fe47b95c9522381a905a32088f91 100644 (file)
@@ -6,6 +6,7 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 from __future__ import annotations
 
+import asyncio
 from typing import Any
 from typing import Dict
 from typing import Generator
@@ -698,7 +699,7 @@ class AsyncConnection(
         return self.start().__await__()
 
     async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
-        await self.close()
+        await asyncio.shield(self.close())
 
     # START PROXY METHODS AsyncConnection
 
@@ -855,8 +856,11 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
         async def __aexit__(
             self, type_: Any, value: Any, traceback: Any
         ) -> None:
-            await self.transaction.__aexit__(type_, value, traceback)
-            await self.conn.close()
+            async def go() -> None:
+                await self.transaction.__aexit__(type_, value, traceback)
+                await self.conn.close()
+
+            await asyncio.shield(go())
 
     def __init__(self, sync_engine: Engine):
         if not sync_engine.dialect.is_async:
@@ -956,7 +960,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
 
         """
 
-        return await greenlet_spawn(self.sync_engine.dispose, close=close)
+        await greenlet_spawn(self.sync_engine.dispose, close=close)
 
     # START PROXY METHODS AsyncEngine
 
index be3414cef4aa39f13194a3bdf0cc8bd3f6170ee7..14e08b5c53b46ffd29a50f6a258260ad3c24f1f9 100644 (file)
@@ -6,6 +6,7 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 from __future__ import annotations
 
+import asyncio
 from typing import Any
 from typing import Dict
 from typing import Generic
@@ -837,7 +838,7 @@ class AsyncSession(ReversibleProxy[Session]):
             :meth:`_asyncio.AsyncSession.close`
 
         """
-        return await greenlet_spawn(self.sync_session.close)
+        await greenlet_spawn(self.sync_session.close)
 
     async def invalidate(self) -> None:
         """Close this Session, using connection invalidation.
@@ -855,7 +856,7 @@ class AsyncSession(ReversibleProxy[Session]):
         return self
 
     async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
-        await self.close()
+        await asyncio.shield(self.close())
 
     def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]:
         return _AsyncSessionContextManager(self)
@@ -1516,8 +1517,11 @@ class _AsyncSessionContextManager(Generic[_AS]):
         return self.async_session
 
     async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
-        await self.trans.__aexit__(type_, value, traceback)
-        await self.async_session.__aexit__(type_, value, traceback)
+        async def go() -> None:
+            await self.trans.__aexit__(type_, value, traceback)
+            await self.async_session.__aexit__(type_, value, traceback)
+
+        await asyncio.shield(go())
 
 
 class AsyncSessionTransaction(