git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure correct lock type propagated in pool recreate
author Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jan 2024 21:54:58 +0000 (16:54 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jan 2024 22:20:19 +0000 (17:20 -0500)
Fixed critical issue in asyncio version of the connection pool where
calling :meth:`_asyncio.AsyncEngine.dispose` would produce a new connection
pool that did not fully re-establish the use of asyncio-compatible mutexes,
leading to the use of a plain ``threading.Lock()`` which would then cause
deadlocks in an asyncio context when using concurrency features like
``asyncio.gather()``.

Fixes: #10813
Change-Id: I95ec698b6a1ba79555aa0b28e6bce65fedf3b1fe

doc/build/changelog/unreleased_14/10813.rst [new file with mode: 0644]
lib/sqlalchemy/event/attr.py
test/ext/asyncio/test_engine_py3k.py

diff --git a/doc/build/changelog/unreleased_14/10813.rst b/doc/build/changelog/unreleased_14/10813.rst
new file mode 100644 (file)
index 0000000..d4f72d8
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 10813
+    :versions: 1.4.51, 2.0.25
+
+    Fixed critical issue in asyncio version of the connection pool where
+    calling :meth:`_asyncio.AsyncEngine.dispose` would produce a new connection
+    pool that did not fully re-establish the use of asyncio-compatible mutexes,
+    leading to the use of a plain ``threading.Lock()`` which would then cause
+    deadlocks in an asyncio context when using concurrency features like
+    ``asyncio.gather()``.
index 2a5fccba20209f37fe66456dc04e8cebc5cb8912..585553f629d67e1dcbc53955ed7fad115aa3ef02 100644 (file)
@@ -404,7 +404,12 @@ class _MutexProtocol(Protocol):
 
 
 class _CompoundListener(_InstanceLevelDispatch[_ET]):
-    __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once"
+    __slots__ = (
+        "_exec_once_mutex",
+        "_exec_once",
+        "_exec_w_sync_once",
+        "_is_asyncio",
+    )
 
     _exec_once_mutex: _MutexProtocol
     parent_listeners: Collection[_ListenerFnType]
@@ -412,11 +417,18 @@ class _CompoundListener(_InstanceLevelDispatch[_ET]):
     _exec_once: bool
     _exec_w_sync_once: bool
 
+    def __init__(self, *arg: Any, **kw: Any):
+        super().__init__(*arg, **kw)
+        self._is_asyncio = False
+
     def _set_asyncio(self) -> None:
-        self._exec_once_mutex = AsyncAdaptedLock()
+        self._is_asyncio = True
 
     def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol:
-        return threading.Lock()
+        if self._is_asyncio:
+            return AsyncAdaptedLock()
+        else:
+            return threading.Lock()
 
     def _exec_once_impl(
         self, retry_on_exception: bool, *args: Any, **kw: Any
@@ -525,6 +537,7 @@ class _ListenerCollection(_CompoundListener[_ET]):
     propagate: Set[_ListenerFnType]
 
     def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]):
+        super().__init__()
         if target_cls not in parent._clslevel:
             parent.update_subclass(target_cls)
         self._exec_once = False
@@ -564,6 +577,9 @@ class _ListenerCollection(_CompoundListener[_ET]):
 
         existing_listeners.extend(other_listeners)
 
+        if other._is_asyncio:
+            self._set_asyncio()
+
         to_associate = other.propagate.union(other_listeners)
         registry._stored_in_collection_multi(self, other, to_associate)
 
index adb6b0b6c9d08b31cd9c44c82aaf9a05b9779396..5ca465906a897210dd4360986ca359a5692216be 100644 (file)
@@ -1396,3 +1396,23 @@ class AsyncProxyTest(EngineFixture, fixtures.TestBase):
 
         async_t2 = async_conn.get_transaction()
         is_(async_t1, async_t2)
+
+
+class PoolRegenTest(EngineFixture):
+    @testing.requires.queue_pool
+    @async_test
+    @testing.variation("do_dispose", [True, False])
+    async def test_gather_after_dispose(self, testing_engine, do_dispose):
+        engine = testing_engine(
+            asyncio=True, options=dict(pool_size=10, max_overflow=10)
+        )
+
+        async def thing(engine):
+            async with engine.connect() as conn:
+                await conn.exec_driver_sql("select 1")
+
+        if do_dispose:
+            await engine.dispose()
+
+        tasks = [thing(engine) for _ in range(10)]
+        await asyncio.gather(*tasks)