]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Adapt event exec_once_mutex to asyncio
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Sep 2020 12:04:09 +0000 (08:04 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Sep 2020 12:04:09 +0000 (08:04 -0400)
The pool makes use of a threading.Lock() for the
"first_connect" event.  If the pool is async, make sure this
is a greenlet-adapted asyncio lock.

Fixes: #5581
Change-Id: If52415839c7ed82135465f1fe93b95d86c305820

lib/sqlalchemy/event/attr.py
lib/sqlalchemy/event/base.py
lib/sqlalchemy/event/registry.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/pool/events.py
lib/sqlalchemy/pool/impl.py
lib/sqlalchemy/util/_concurrency_py3k.py
lib/sqlalchemy/util/concurrency.py
test/ext/asyncio/test_engine_py3k.py

index 87c6e980f8abeb07edea58ad4f4e167f6d531697..abb264f98f154c43afd1ea4ed9e80f022a8f8908 100644 (file)
@@ -41,6 +41,7 @@ from . import registry
 from .. import exc
 from .. import util
 from ..util import threading
+from ..util.concurrency import AsyncAdaptedLock
 
 
 class RefCollection(util.MemoizedSlots):
@@ -277,6 +278,9 @@ class _EmptyListener(_InstanceLevelDispatch):
 class _CompoundListener(_InstanceLevelDispatch):
     __slots__ = "_exec_once_mutex", "_exec_once"
 
+    def _set_asyncio(self):
+        self._exec_once_mutex = AsyncAdaptedLock()
+
     def _memoized_attr__exec_once_mutex(self):
         return threading.Lock()
 
index a87c1fe4484962d77911b5278307284f79cd49a9..c78080738f9d261893e83e34395da46e4c99aa0f 100644 (file)
@@ -241,8 +241,17 @@ class Events(util.with_metaclass(_EventMeta, object)):
                 return target
 
     @classmethod
-    def _listen(cls, event_key, propagate=False, insert=False, named=False):
-        event_key.base_listen(propagate=propagate, insert=insert, named=named)
+    def _listen(
+        cls,
+        event_key,
+        propagate=False,
+        insert=False,
+        named=False,
+        asyncio=False,
+    ):
+        event_key.base_listen(
+            propagate=propagate, insert=insert, named=named, asyncio=asyncio
+        )
 
     @classmethod
     def _remove(cls, event_key):
index 19b9174b71c41e7621df3fac8460add8a1caee77..144dd45dc6e127e46f4ea7a57ad84933ca4f8cfc 100644 (file)
@@ -244,21 +244,26 @@ class _EventKey(object):
         return self._key in _key_to_collection
 
     def base_listen(
-        self, propagate=False, insert=False, named=False, retval=None
+        self,
+        propagate=False,
+        insert=False,
+        named=False,
+        retval=None,
+        asyncio=False,
     ):
 
         target, identifier = self.dispatch_target, self.identifier
 
         dispatch_collection = getattr(target.dispatch, identifier)
 
+        for_modify = dispatch_collection.for_modify(target.dispatch)
+        if asyncio:
+            for_modify._set_asyncio()
+
         if insert:
-            dispatch_collection.for_modify(target.dispatch).insert(
-                self, propagate
-            )
+            for_modify.insert(self, propagate)
         else:
-            dispatch_collection.for_modify(target.dispatch).append(
-                self, propagate
-            )
+            for_modify.append(self, propagate)
 
     @property
     def _listen_fn(self):
index f20b63cf54e1c3acd82099a62a255d292f1eb530..87383fef717a6a1d76b44e4d625d6eb405a9f431 100644 (file)
@@ -59,6 +59,8 @@ class Pool(log.Identified):
 
     _dialect = _ConnDialect()
 
+    _is_asyncio = False
+
     def __init__(
         self,
         creator,
index 3954f907f46236dba8c0d296e9d82e15e92b6467..9443877a91843eba8f6cedfc90f416fb3fae7cac 100644 (file)
@@ -54,6 +54,12 @@ class PoolEvents(event.Events):
         else:
             return target
 
+    @classmethod
+    def _listen(cls, event_key, **kw):
+        target = event_key.dispatch_target
+
+        event_key.base_listen(asyncio=target._is_asyncio)
+
     def connect(self, dbapi_connection, connection_record):
         """Called at the moment a particular DBAPI connection is first
         created for a given :class:`_pool.Pool`.
index e1a9f00db186e687de442d551c63372281aeeb98..ffdd63671aaa88d843ab677dab92e54d96d218f4 100644 (file)
@@ -218,6 +218,7 @@ class QueuePool(Pool):
 
 
 class AsyncAdaptedQueuePool(QueuePool):
+    _is_asyncio = True
     _queue_class = sqla_queue.AsyncAdaptedQueue
 
 
index 3b112ff7db4c6f33be7e4a9a3183e4f746f303fa..82125b7713080d8a12fd5d82bd556a1652451881 100644 (file)
@@ -96,6 +96,17 @@ try:
             del context.driver
         return result
 
+    class AsyncAdaptedLock:
+        def __init__(self):
+            self.mutex = asyncio.Lock()
+
+        def __enter__(self):
+            await_fallback(self.mutex.acquire())
+            return self
+
+        def __exit__(self, *arg, **kw):
+            self.mutex.release()
+
 
 except ImportError:  # pragma: no cover
     greenlet = None
index 4c4ea20d12afc5c1cf81d6751299b14e5973d891..e0883aa6835fc67a6d0eae6a67518527d67b91bf 100644 (file)
@@ -7,6 +7,7 @@ if compat.py3k:
     from ._concurrency_py3k import await_fallback
     from ._concurrency_py3k import greenlet
     from ._concurrency_py3k import greenlet_spawn
+    from ._concurrency_py3k import AsyncAdaptedLock
 else:
     asyncio = None
     greenlet = None
@@ -19,3 +20,6 @@ else:
 
     def greenlet_spawn(fn, *args, **kw):
         raise ValueError("Cannot use this function in py2.")
+
+    def AsyncAdaptedLock(*args, **kw):
+        raise ValueError("Cannot use this function in py2.")
index 59c47c4032fad65d273ff3944bda71d232ba902b..a5d167c2e73f86222283318044102c7e41a59216 100644 (file)
@@ -50,6 +50,12 @@ class EngineFixture(fixtures.TablesTest):
 class AsyncEngineTest(EngineFixture):
     __backend__ = True
 
+    @async_test
+    async def test_init_once_concurrency(self, async_engine):
+        c1 = async_engine.connect()
+        c2 = async_engine.connect()
+        await asyncio.wait([c1, c2])
+
     @async_test
     async def test_connect_ctxmanager(self, async_engine):
         async with async_engine.connect() as conn: