]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Propagate asyncio flag from the dialect to selected pool classes
authorFederico Caselli <cfederico87@gmail.com>
Thu, 3 Jun 2021 20:38:15 +0000 (22:38 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 8 Jun 2021 20:02:42 +0000 (22:02 +0200)
Fixed an issue that presented itself when using the :class:`_pool.NullPool`
or the :class:`_pool.StaticPool` with an async engine. This mostly affected
the aiosqlite dialect.

Fixes: #6575
Change-Id: Ic1e27d99ffcb20ed4de82ea78f430a0f3b629d86

doc/build/changelog/unreleased_14/6575.rst [new file with mode: 0644]
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/pool/impl.py
lib/sqlalchemy/util/langhelpers.py
test/engine/test_pool.py
test/ext/asyncio/test_engine_py3k.py

diff --git a/doc/build/changelog/unreleased_14/6575.rst b/doc/build/changelog/unreleased_14/6575.rst
new file mode 100644 (file)
index 0000000..ee3ac7d
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, engine, asyncio
+    :tickets: 6285
+
+    Fixed an issue that presented itself when using the :class:`_pool.NullPool`
+    or the :class:`_pool.StaticPool` with an async engine. This mostly affected
+    the aiosqlite dialect.
index e2ed538003762005493378f0b5e365d6343a73e6..8a3abb82fad2033849255e659228a518782b6bf1 100644 (file)
@@ -53,14 +53,16 @@ class _ConnDialect(object):
         )
 
 
+class _AsyncConnDialect(_ConnDialect):
+    is_async = True
+
+
 class Pool(log.Identified):
 
     """Abstract base class for connection pools."""
 
     _dialect = _ConnDialect()
 
-    _is_asyncio = False
-
     def __init__(
         self,
         creator,
@@ -196,6 +198,10 @@ class Pool(log.Identified):
             for fn, target in events:
                 event.listen(self, target, fn)
 
+    @util.hybridproperty
+    def _is_asyncio(self):
+        return self._dialect.is_async
+
     @property
     def _creator(self):
         return self.__dict__["_creator"]
index 730293273adaa94315a23fd4b204c85179aab1d6..99d0c94d23e67db355381ac490d677e9922b63fb 100644 (file)
@@ -13,7 +13,7 @@
 import traceback
 import weakref
 
-from .base import _ConnDialect
+from .base import _AsyncConnDialect
 from .base import _ConnectionFairy
 from .base import _ConnectionRecord
 from .base import Pool
@@ -34,6 +34,7 @@ class QueuePool(Pool):
 
     """
 
+    _is_asyncio = False
     _queue_class = sqla_queue.Queue
 
     def __init__(
@@ -222,10 +223,6 @@ class QueuePool(Pool):
         return self._pool.maxsize - self._pool.qsize() + self._overflow
 
 
-class _AsyncConnDialect(_ConnDialect):
-    is_async = True
-
-
 class AsyncAdaptedQueuePool(QueuePool):
     _is_asyncio = True
     _queue_class = sqla_queue.AsyncAdaptedQueue
@@ -307,6 +304,8 @@ class SingletonThreadPool(Pool):
 
     """
 
+    _is_asyncio = False
+
     def __init__(self, creator, pool_size=5, **kw):
         Pool.__init__(self, creator, **kw)
         self._conn = threading.local()
index 1308ee7e0619c9f69ae02434890f98e9ca3e5ed0..e506b7529a9a84a63ab2c974299fbef0ca857a35 100644 (file)
@@ -1440,7 +1440,6 @@ class hybridproperty(object):
     def __get__(self, instance, owner):
         if instance is None:
             clsval = self.clslevel(owner)
-            clsval.__doc__ = self.func.__doc__
             return clsval
         else:
             return self.func(instance)
index 5b6dcfa45cbb557cefd31116be137ffe66e64d21..70671134f1f588853022851892fca7b4ed2a1c0b 100644 (file)
@@ -10,7 +10,8 @@ from sqlalchemy import pool
 from sqlalchemy import select
 from sqlalchemy import testing
 from sqlalchemy.engine import default
-from sqlalchemy.pool.impl import _AsyncConnDialect
+from sqlalchemy.pool.base import _AsyncConnDialect
+from sqlalchemy.pool.base import _ConnDialect
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_context_ok
 from sqlalchemy.testing import assert_raises_message
@@ -280,6 +281,39 @@ class PoolTest(PoolTestBase):
         if "use_lifo" in pool_args:
             eq_(p1._pool.use_lifo, p2._pool.use_lifo)
 
+    @testing.combinations(
+        (pool.QueuePool, False),
+        (pool.AsyncAdaptedQueuePool, True),
+        (pool.FallbackAsyncAdaptedQueuePool, True),
+        (pool.NullPool, None),
+        (pool.SingletonThreadPool, False),
+        (pool.StaticPool, None),
+        (pool.AssertionPool, None),
+    )
+    def test_is_asyncio_from_dialect(self, pool_cls, is_async_king):
+        p = pool_cls(creator=object())
+        for is_async in (True, False):
+            if is_async:
+                p._dialect = _AsyncConnDialect()
+            else:
+                p._dialect = _ConnDialect
+            if is_async_king is None:
+                eq_(p._is_asyncio, is_async)
+            else:
+                eq_(p._is_asyncio, is_async_king)
+
+    @testing.combinations(
+        (pool.QueuePool, False),
+        (pool.AsyncAdaptedQueuePool, True),
+        (pool.FallbackAsyncAdaptedQueuePool, True),
+        (pool.NullPool, False),
+        (pool.SingletonThreadPool, False),
+        (pool.StaticPool, False),
+        (pool.AssertionPool, False),
+    )
+    def test_is_asyncio_from_dialect_cls(self, pool_cls, is_async):
+        eq_(pool_cls._is_asyncio, is_async)
+
 
 class PoolDialectTest(PoolTestBase):
     def _dialect(self):
index d47ef5f3ffb2f5882c233ee3c4ec88fc279b3722..fec8bc6da1e78c657cf4805dc4f1da46ee15e214 100644 (file)
@@ -230,14 +230,8 @@ class AsyncEngineTest(EngineFixture):
 
         is_false(async_engine == None)
 
-    # NOTE: this test currently causes the test suite to hang; it previously
-    # was not actually running the worker thread
-    # as the testing_engine() fixture
-    # was rejecting the "transfer_staticpool" keyword argument
     @async_test
-    async def temporarily_dont_test_no_attach_to_event_loop(
-        self, testing_engine
-    ):
+    async def test_no_attach_to_event_loop(self, testing_engine):
         """test #6409"""
 
         import asyncio
@@ -249,12 +243,11 @@ class AsyncEngineTest(EngineFixture):
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
 
-            engine = testing_engine(asyncio=True, transfer_staticpool=True)
-
             async def main():
                 tasks = [task() for _ in range(2)]
 
                 await asyncio.gather(*tasks)
+                await engine.dispose()
 
             async def task():
                 async with engine.begin() as connection:
@@ -262,6 +255,10 @@ class AsyncEngineTest(EngineFixture):
                     result.all()
 
             try:
+                engine = testing_engine(
+                    asyncio=True, transfer_staticpool=False
+                )
+
                 asyncio.run(main())
             except Exception as err:
                 errs.append(err)