]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Repair async test refactor
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Jan 2021 15:55:21 +0000 (10:55 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 3 Jan 2021 00:08:25 +0000 (19:08 -0500)
in I4940d184a4dc790782fcddfb9873af3cca844398 we reworked how async
tests run but apparently the async tests in test/ext/asyncio
are reporting success without being run.   This patch pushes
pytestplugin further so that it won't instrument any test
or function overall that declares itself async. This removes
the need for the __async_wrap__ flag and also allows us to
use a more strict "run_async_test" function that always
runs the asyncio event loop from the top.

Also start working asyncio into main testing suite.

Change-Id: If7144e951a9db67eb7ea73b377f81c4440d39819

lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/testing/asyncio.py
lib/sqlalchemy/testing/engines.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
lib/sqlalchemy/util/_concurrency_py3k.py
lib/sqlalchemy/util/concurrency.py
test/base/test_concurrency_py3k.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py

index 16edcc2b2a6c4e50d7c833b96dcb8f35087c9daa..93adaf78ab0643000b4b09b3d381b9ee39009e7d 100644 (file)
@@ -41,7 +41,7 @@ def create_async_engine(*arg, **kw):
 
 
 class AsyncConnectable:
-    __slots__ = "_slots_dispatch"
+    __slots__ = "_slots_dispatch", "__weakref__"
 
 
 @util.create_proxy_methods(
index 52386d33e19c4f803020dde374d4b8e3efebfd43..bdf730a4c1534155f50d6a1d1567d9f80282fc0a 100644 (file)
@@ -22,12 +22,17 @@ import inspect
 
 from . import config
 from ..util.concurrency import _util_async_run
+from ..util.concurrency import _util_async_run_coroutine_function
 
 # may be set to False if the
 # --disable-asyncio flag is passed to the test runner.
 ENABLE_ASYNCIO = True
 
 
+def _run_coroutine_function(fn, *args, **kwargs):
+    return _util_async_run_coroutine_function(fn, *args, **kwargs)
+
+
 def _assume_async(fn, *args, **kwargs):
     """Run a function in an asyncio loop unconditionally.
 
index d0a1bc0d0d1ade7eb26ce93b7c370554dbf6ecb5..4d4563afb0c0dd82dfd4250b186fded05c760472 100644 (file)
@@ -97,7 +97,10 @@ class ConnectionKiller(object):
 
         self.conns = set()
         for rec in list(self.testing_engines):
-            rec.dispose()
+            if hasattr(rec, "sync_engine"):
+                rec.sync_engine.dispose()
+            else:
+                rec.dispose()
 
     def assert_all_closed(self):
         for rec in self.proxy_refs:
@@ -236,10 +239,12 @@ def reconnecting_engine(url=None, options=None):
     return engine
 
 
-def testing_engine(url=None, options=None, future=False):
+def testing_engine(url=None, options=None, future=False, asyncio=False):
     """Produce an engine configured by --options with optional overrides."""
 
-    if future or config.db and config.db._is_future:
+    if asyncio:
+        from sqlalchemy.ext.asyncio import create_async_engine as create_engine
+    elif future or config.db and config.db._is_future:
         from sqlalchemy.future import create_engine
     else:
         from sqlalchemy import create_engine
@@ -263,7 +268,10 @@ def testing_engine(url=None, options=None, future=False):
         default_opt.update(options)
 
     engine = create_engine(url, **options)
-    engine._has_events = True  # enable event blocks, helps with profiling
+    if asyncio:
+        engine.sync_engine._has_events = True
+    else:
+        engine._has_events = True  # enable event blocks, helps with profiling
 
     if isinstance(engine.pool, pool.QueuePool):
         engine.pool._timeout = 0
index a52fdd1967783e54165ab29fc88ee4af8f96a4a4..0ede25176a4c7bf439ec488b28378e36e61fd40e 100644 (file)
@@ -48,11 +48,6 @@ class TestBase(object):
     # skipped.
     __skip_if__ = None
 
-    # If this class should be wrapped in asyncio compatibility functions
-    # when using an async engine. This should be set to False only for tests
-    # that use the asyncio features of sqlalchemy directly
-    __asyncio_wrap__ = True
-
     def assert_(self, val, msg=None):
         assert val, msg
 
@@ -95,12 +90,6 @@ class TestBase(object):
     #       engines.drop_all_tables(metadata, config.db)
 
 
-class AsyncTestBase(TestBase):
-    """Mixin marking a test as using its own explicit asyncio patterns."""
-
-    __asyncio_wrap__ = False
-
-
 class FutureEngineMixin(object):
     @classmethod
     def setup_class(cls):
index 6be64aa6106b30f4b03339707a8dc6e696b961b8..46468a07dcb70ab55f4e000fdd2b1c5acbfe8413 100644 (file)
@@ -255,7 +255,7 @@ def pytest_pycollect_makeitem(collector, name, obj):
     if inspect.isclass(obj) and plugin_base.want_class(name, obj):
         from sqlalchemy.testing import config
 
-        if config.any_async and getattr(obj, "__asyncio_wrap__", True):
+        if config.any_async:
             obj = _apply_maybe_async(obj)
 
         ctor = getattr(pytest.Class, "from_parent", pytest.Class)
@@ -277,6 +277,13 @@ def pytest_pycollect_makeitem(collector, name, obj):
         return []
 
 
+def _is_wrapped_coroutine_function(fn):
+    while hasattr(fn, "__wrapped__"):
+        fn = fn.__wrapped__
+
+    return inspect.iscoroutinefunction(fn)
+
+
 def _apply_maybe_async(obj, recurse=True):
     from sqlalchemy.testing import asyncio
 
@@ -286,6 +293,7 @@ def _apply_maybe_async(obj, recurse=True):
             (callable(value) or isinstance(value, classmethod))
             and not getattr(value, "_maybe_async_applied", False)
             and (name.startswith("test_") or name in setup_names)
+            and not _is_wrapped_coroutine_function(value)
         ):
             is_classmethod = False
             if isinstance(value, classmethod):
@@ -656,6 +664,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
 
         @_pytest_fn_decorator
         def decorate(fn, *args, **kwargs):
-            asyncio._assume_async(fn, *args, **kwargs)
+            asyncio._run_coroutine_function(fn, *args, **kwargs)
 
         return decorate(fn)
index 6042e4395a6eed1964e59f486f0ee3a70c116121..663d3e0f42cca07b028a02932d0a495a22b36c05 100644 (file)
@@ -136,6 +136,18 @@ class AsyncAdaptedLock:
         self.mutex.release()
 
 
+def _util_async_run_coroutine_function(fn, *args, **kwargs):
+    """for test suite/ util only"""
+
+    loop = asyncio.get_event_loop()
+    if loop.is_running():
+        raise Exception(
+            "for async run coroutine we expect that no greenlet or event "
+            "loop is running when we start out"
+        )
+    return loop.run_until_complete(fn(*args, **kwargs))
+
+
 def _util_async_run(fn, *args, **kwargs):
     """for test suite/ util only"""
 
index 7b4ff6ba40c2aae8070ebc9f647821b908c7f450..c44efba6202ba301b9d17f07623da3a670205d13 100644 (file)
@@ -14,6 +14,9 @@ if compat.py3k:
         from ._concurrency_py3k import greenlet_spawn
         from ._concurrency_py3k import AsyncAdaptedLock
         from ._concurrency_py3k import _util_async_run  # noqa F401
+        from ._concurrency_py3k import (
+            _util_async_run_coroutine_function,
+        )  # noqa F401, E501
         from ._concurrency_py3k import asyncio  # noqa F401
 
 if not have_greenlet:
@@ -42,3 +45,6 @@ if not have_greenlet:
 
     def _util_async_run(fn, *arg, **kw):  # noqa F81
         return fn(*arg, **kw)
+
+    def _util_async_run_coroutine_function(fn, *arg, **kw):  # noqa F81
+        _not_implemented()
index 2cc2075bcd6b9e78f01477281b8231b1ee129dc6..e7ae8c9ad203ffd8aadb4c0bfbde2a1b78301958 100644 (file)
@@ -26,7 +26,7 @@ def go(*fns):
     return sum(await_only(fn()) for fn in fns)
 
 
-class TestAsyncioCompat(fixtures.AsyncTestBase):
+class TestAsyncioCompat(fixtures.TestBase):
     @async_test
     async def test_ok(self):
 
@@ -53,7 +53,8 @@ class TestAsyncioCompat(fixtures.AsyncTestBase):
         to_await = run1()
         await_fallback(to_await)
 
-    def test_await_only_no_greenlet(self):
+    @async_test
+    async def test_await_only_no_greenlet(self):
         to_await = run1()
         with expect_raises_message(
             exc.InvalidRequestError,
@@ -62,7 +63,7 @@ class TestAsyncioCompat(fixtures.AsyncTestBase):
             await_only(to_await)
 
         # ensure no warning
-        await_fallback(to_await)
+        await greenlet_spawn(await_fallback, to_await)
 
     @async_test
     async def test_await_fallback_error(self):
index cd1e16ed9166dc4abc7638207a83b7cf663da2b4..7dae1411e542ae1884ac76cc12ac85c8bd277850 100644 (file)
@@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import engine as _async_engine
 from sqlalchemy.ext.asyncio import exc as asyncio_exc
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import combinations
+from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
@@ -32,7 +33,7 @@ class EngineFixture(fixtures.TablesTest):
 
     @testing.fixture
     def async_engine(self):
-        return create_async_engine(testing.db.url)
+        return engines.testing_engine(asyncio=True)
 
     @classmethod
     def define_tables(cls, metadata):
@@ -55,6 +56,12 @@ class EngineFixture(fixtures.TablesTest):
 class AsyncEngineTest(EngineFixture):
     __backend__ = True
 
+    @testing.fails("the failure is the test")
+    @async_test
+    async def test_we_are_definitely_running_async_tests(self, async_engine):
+        async with async_engine.connect() as conn:
+            eq_(await conn.scalar(text("select 1")), 2)
+
     def test_proxied_attrs_engine(self, async_engine):
         sync_engine = async_engine.sync_engine
 
index 37e1b807b11fa9cce34696c74ba64c086c8ca1e7..dbe84e82c3eef9023681ea3757f1a214c3fd0c0e 100644 (file)
@@ -5,10 +5,10 @@ from sqlalchemy import select
 from sqlalchemy import testing
 from sqlalchemy import update
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.testing import async_test
+from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
@@ -24,7 +24,7 @@ class AsyncFixture(_fixtures.FixtureTest):
 
     @testing.fixture
     def async_engine(self):
-        return create_async_engine(testing.db.url)
+        return engines.testing_engine(asyncio=True)
 
     @testing.fixture
     def async_session(self, async_engine):