]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improved usage of ``asyncio.shield()``
authorFederico Caselli <cfederico87@gmail.com>
Tue, 13 Sep 2022 19:23:12 +0000 (21:23 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Sep 2022 13:02:33 +0000 (09:02 -0400)
Fixes: #8516
Change-Id: Ifd8f5e5f42d9fbcd5b8d00bddc81ff6be690a75e

doc/build/changelog/unreleased_14/8516.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py

diff --git a/doc/build/changelog/unreleased_14/8516.rst b/doc/build/changelog/unreleased_14/8516.rst
new file mode 100644 (file)
index 0000000..2f83586
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 8516
+
+    Improved implementation of ``asyncio.shield()`` used in context managers as
+    added in :ticket:`8145`, such that the "close" operation is enclosed within
+    an ``asyncio.Task`` which is then strongly referenced as the operation
+    proceeds. This is per Python documentation indicating that the task is
+    otherwise not strongly referenced.
index e8ac10a3d0861db456c5c4664284b71687b3a82a..4d0c872b301d00ebed2b8f90fce70ad2a5c56b4e 100644 (file)
@@ -694,7 +694,8 @@ class AsyncConnection(
         return self.start().__await__()
 
     async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
-        await asyncio.shield(self.close())
+        task = asyncio.create_task(self.close())
+        await asyncio.shield(task)
 
     # START PROXY METHODS AsyncConnection
 
@@ -860,7 +861,8 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
                 await self.transaction.__aexit__(type_, value, traceback)
                 await self.conn.close()
 
-            await asyncio.shield(go())
+            task = asyncio.create_task(go())
+            await asyncio.shield(task)
 
     def __init__(self, sync_engine: Engine):
         if not sync_engine.dialect.is_async:
index d4b6b6d50aa7a13daa139cafd34c719a23f97885..60b77f3ea63ea90304acd0a63449b36db112e691 100644 (file)
@@ -851,7 +851,8 @@ class AsyncSession(ReversibleProxy[Session]):
         return self
 
     async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
-        await asyncio.shield(self.close())
+        task = asyncio.create_task(self.close())
+        await asyncio.shield(task)
 
     def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]:
         return _AsyncSessionContextManager(self)
@@ -1516,7 +1517,8 @@ class _AsyncSessionContextManager(Generic[_AS]):
             await self.trans.__aexit__(type_, value, traceback)
             await self.async_session.__aexit__(type_, value, traceback)
 
-        await asyncio.shield(go())
+        task = asyncio.create_task(go())
+        await asyncio.shield(task)
 
 
 class AsyncSessionTransaction(