From: Federico Caselli Date: Thu, 3 Jun 2021 20:38:15 +0000 (+0200) Subject: Propagate asyncio flag from the dialect to selected pool classes X-Git-Tag: rel_1_4_18~4^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d200ba26a0f5b8542ec258d2fcfe0b53a80af42c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Propagate asyncio flag from the dialect to selected pool classes Fixed an issue that presented itself when using the :class:`_pool.NullPool` or the :class:`_pool.StaticPool` with an async engine. This mostly affected the aiosqlite dialect. Fixes: #6575 Change-Id: Ic1e27d99ffcb20ed4de82ea78f430a0f3b629d86 --- diff --git a/doc/build/changelog/unreleased_14/6575.rst b/doc/build/changelog/unreleased_14/6575.rst new file mode 100644 index 0000000000..ee3ac7d89a --- /dev/null +++ b/doc/build/changelog/unreleased_14/6575.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, engine, asyncio + :tickets: 6285 + + Fixed an issue that presented itself when using the :class:`_pool.NullPool` + or the :class:`_pool.StaticPool` with an async engine. This mostly affected + the aiosqlite dialect. diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index e2ed538003..8a3abb82fa 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -53,14 +53,16 @@ class _ConnDialect(object): ) +class _AsyncConnDialect(_ConnDialect): + is_async = True + + class Pool(log.Identified): """Abstract base class for connection pools.""" _dialect = _ConnDialect() - _is_asyncio = False - def __init__( self, creator, @@ -196,6 +198,10 @@ class Pool(log.Identified): for fn, target in events: event.listen(self, target, fn) + @util.hybridproperty + def _is_asyncio(self): + return self._dialect.is_async + @property def _creator(self): return self.__dict__["_creator"] diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index 730293273a..99d0c94d23 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -13,7 +13,7 @@ import traceback import weakref -from .base import _ConnDialect +from .base import _AsyncConnDialect from .base import _ConnectionFairy from .base import _ConnectionRecord from .base import Pool @@ -34,6 +34,7 @@ class QueuePool(Pool): """ + _is_asyncio = False _queue_class = sqla_queue.Queue def __init__( @@ -222,10 +223,6 @@ class QueuePool(Pool): return self._pool.maxsize - self._pool.qsize() + self._overflow -class _AsyncConnDialect(_ConnDialect): - is_async = True - - class AsyncAdaptedQueuePool(QueuePool): _is_asyncio = True _queue_class = sqla_queue.AsyncAdaptedQueue @@ -307,6 +304,8 @@ class SingletonThreadPool(Pool): """ + _is_asyncio = False + def __init__(self, creator, pool_size=5, **kw): Pool.__init__(self, creator, **kw) self._conn = threading.local() diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 1308ee7e06..e506b7529a 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1440,7 +1440,6 @@ class hybridproperty(object): def __get__(self, instance, owner): if instance is None: clsval = self.clslevel(owner) - clsval.__doc__ = self.func.__doc__ return clsval else: return self.func(instance) diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 5b6dcfa45c..70671134f1 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -10,7 +10,8 @@ from sqlalchemy import pool from sqlalchemy import select from sqlalchemy import testing from sqlalchemy.engine import default -from sqlalchemy.pool.impl import _AsyncConnDialect +from sqlalchemy.pool.base import _AsyncConnDialect +from sqlalchemy.pool.base import _ConnDialect from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_context_ok from sqlalchemy.testing import assert_raises_message @@ -280,6 +281,39 @@ class PoolTest(PoolTestBase): if "use_lifo" in pool_args: eq_(p1._pool.use_lifo, p2._pool.use_lifo) + @testing.combinations( + (pool.QueuePool, False), + (pool.AsyncAdaptedQueuePool, True), + (pool.FallbackAsyncAdaptedQueuePool, True), + (pool.NullPool, None), + (pool.SingletonThreadPool, False), + (pool.StaticPool, None), + (pool.AssertionPool, None), + ) + def test_is_asyncio_from_dialect(self, pool_cls, is_async_king): + p = pool_cls(creator=object()) + for is_async in (True, False): + if is_async: + p._dialect = _AsyncConnDialect() + else: + p._dialect = _ConnDialect + if is_async_king is None: + eq_(p._is_asyncio, is_async) + else: + eq_(p._is_asyncio, is_async_king) + + @testing.combinations( + (pool.QueuePool, False), + (pool.AsyncAdaptedQueuePool, True), + (pool.FallbackAsyncAdaptedQueuePool, True), + (pool.NullPool, False), + (pool.SingletonThreadPool, False), + (pool.StaticPool, False), + (pool.AssertionPool, False), + ) + def test_is_asyncio_from_dialect_cls(self, pool_cls, is_async): + eq_(pool_cls._is_asyncio, is_async) + class PoolDialectTest(PoolTestBase): def _dialect(self): diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index d47ef5f3ff..fec8bc6da1 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -230,14 +230,8 @@ class AsyncEngineTest(EngineFixture): is_false(async_engine == None) - # NOTE: this test currently causes the test suite to hang; it previously - # was not actually running the worker thread - # as the testing_engine() fixture - # was rejecting the "transfer_staticpool" keyword argument @async_test - async def temporarily_dont_test_no_attach_to_event_loop( - self, testing_engine - ): + async def test_no_attach_to_event_loop(self, testing_engine): """test #6409""" import asyncio @@ -249,12 +243,11 @@ class AsyncEngineTest(EngineFixture): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - engine = testing_engine(asyncio=True, transfer_staticpool=True) - async def main(): tasks = [task() for _ in range(2)] await asyncio.gather(*tasks) + await engine.dispose() async def task(): async with engine.begin() as connection: @@ -262,6 +255,10 @@ class AsyncEngineTest(EngineFixture): result.all() try: + engine = testing_engine( + asyncio=True, transfer_staticpool=False + ) + asyncio.run(main()) except Exception as err: errs.append(err)