from .. import exc
from .. import util
from ..util import threading
+from ..util.concurrency import AsyncAdaptedLock
class RefCollection(util.MemoizedSlots):
class _CompoundListener(_InstanceLevelDispatch):
__slots__ = "_exec_once_mutex", "_exec_once"
+ def _set_asyncio(self):
+ self._exec_once_mutex = AsyncAdaptedLock()
+
def _memoized_attr__exec_once_mutex(self):
return threading.Lock()
return target
@classmethod
- def _listen(cls, event_key, propagate=False, insert=False, named=False):
- event_key.base_listen(propagate=propagate, insert=insert, named=named)
+ def _listen(
+ cls,
+ event_key,
+ propagate=False,
+ insert=False,
+ named=False,
+ asyncio=False,
+ ):
+ event_key.base_listen(
+ propagate=propagate, insert=insert, named=named, asyncio=asyncio
+ )
@classmethod
def _remove(cls, event_key):
return self._key in _key_to_collection
def base_listen(
- self, propagate=False, insert=False, named=False, retval=None
+ self,
+ propagate=False,
+ insert=False,
+ named=False,
+ retval=None,
+ asyncio=False,
):
target, identifier = self.dispatch_target, self.identifier
dispatch_collection = getattr(target.dispatch, identifier)
+ for_modify = dispatch_collection.for_modify(target.dispatch)
+ if asyncio:
+ for_modify._set_asyncio()
+
if insert:
- dispatch_collection.for_modify(target.dispatch).insert(
- self, propagate
- )
+ for_modify.insert(self, propagate)
else:
- dispatch_collection.for_modify(target.dispatch).append(
- self, propagate
- )
+ for_modify.append(self, propagate)
@property
def _listen_fn(self):
_dialect = _ConnDialect()
+ _is_asyncio = False
+
def __init__(
self,
creator,
else:
return target
+ @classmethod
+ def _listen(cls, event_key, **kw):
+ target = event_key.dispatch_target
+
+ event_key.base_listen(asyncio=target._is_asyncio)
+
def connect(self, dbapi_connection, connection_record):
"""Called at the moment a particular DBAPI connection is first
created for a given :class:`_pool.Pool`.
class AsyncAdaptedQueuePool(QueuePool):
+ _is_asyncio = True
_queue_class = sqla_queue.AsyncAdaptedQueue
del context.driver
return result
+ class AsyncAdaptedLock:
+ def __init__(self):
+ self.mutex = asyncio.Lock()
+
+ def __enter__(self):
+ await_fallback(self.mutex.acquire())
+ return self
+
+ def __exit__(self, *arg, **kw):
+ self.mutex.release()
+
except ImportError: # pragma: no cover
greenlet = None
from ._concurrency_py3k import await_fallback
from ._concurrency_py3k import greenlet
from ._concurrency_py3k import greenlet_spawn
+ from ._concurrency_py3k import AsyncAdaptedLock
else:
asyncio = None
greenlet = None
def greenlet_spawn(fn, *args, **kw):
raise ValueError("Cannot use this function in py2.")
+
+ def AsyncAdaptedLock(*args, **kw):
+ raise ValueError("Cannot use this function in py2.")
class AsyncEngineTest(EngineFixture):
__backend__ = True
+ @async_test
+ async def test_init_once_concurrency(self, async_engine):
+ c1 = async_engine.connect()
+ c2 = async_engine.connect()
+ await asyncio.wait([c1, c2])
+
@async_test
async def test_connect_ctxmanager(self, async_engine):
async with async_engine.connect() as conn: