]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support testing of async drivers without fallback mode
authorFederico Caselli <cfederico87@gmail.com>
Sun, 16 Aug 2020 08:48:57 +0000 (10:48 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 30 Dec 2020 20:49:09 +0000 (15:49 -0500)
Change-Id: I4940d184a4dc790782fcddfb9873af3cca844398

23 files changed:
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/pool/__init__.py
lib/sqlalchemy/pool/impl.py
lib/sqlalchemy/testing/asyncio.py [new file with mode: 0644]
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/engines.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
lib/sqlalchemy/testing/provision.py
lib/sqlalchemy/testing/util.py
lib/sqlalchemy/util/_concurrency_py3k.py
lib/sqlalchemy/util/concurrency.py
lib/sqlalchemy/util/queue.py
setup.cfg
test/aaa_profiling/test_memusage.py
test/base/test_concurrency_py3k.py
test/dialect/postgresql/test_types.py
test/orm/test_deprecations.py
test/requirements.py
tox.ini

index f560ece3321419caab9fcaa04ce67e50dbf85765..81864603dc2ad63fbe88f227562e2423ad4dc868 100644 (file)
@@ -34,6 +34,7 @@ handling.
 
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
+from ... import util
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
@@ -226,7 +227,7 @@ class AsyncAdapt_aiomysql_dbapi:
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
 
-        if async_fallback:
+        if util.asbool(async_fallback):
             return AsyncAdaptFallback_aiomysql_connection(
                 self,
                 await_fallback(self.aiomysql.connect(*arg, **kw)),
@@ -244,6 +245,8 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
     supports_server_side_cursors = True
     _sscursor = AsyncAdapt_aiomysql_ss_cursor
 
+    is_async = True
+
     @classmethod
     def dbapi(cls):
         return AsyncAdapt_aiomysql_dbapi(
@@ -251,8 +254,14 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
         )
 
     @classmethod
-    def get_pool_class(self, url):
-        return pool.AsyncAdaptedQueuePool
+    def get_pool_class(cls, url):
+
+        async_fallback = url.query.get("async_fallback", False)
+
+        if util.asbool(async_fallback):
+            return pool.FallbackAsyncAdaptedQueuePool
+        else:
+            return pool.AsyncAdaptedQueuePool
 
     def create_connect_args(self, url):
         args, kw = super(MySQLDialect_aiomysql, self).create_connect_args(url)
index 889293eabda26bcf69a440e68a583b041d68215e..6b7e78266476d085b55f57299ce9d77c061e4034 100644 (file)
@@ -587,7 +587,7 @@ class AsyncAdapt_asyncpg_dbapi:
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
 
-        if async_fallback:
+        if util.asbool(async_fallback):
             return AsyncAdaptFallback_asyncpg_connection(
                 self,
                 await_fallback(self.asyncpg.connect(*arg, **kw)),
@@ -729,6 +729,7 @@ class PGDialect_asyncpg(PGDialect):
             REGCLASS: AsyncpgREGCLASS,
         },
     )
+    is_async = True
 
     @util.memoized_property
     def _dbapi_version(self):
@@ -792,8 +793,14 @@ class PGDialect_asyncpg(PGDialect):
         return ([], opts)
 
     @classmethod
-    def get_pool_class(self, url):
-        return pool.AsyncAdaptedQueuePool
+    def get_pool_class(cls, url):
+
+        async_fallback = url.query.get("async_fallback", False)
+
+        if util.asbool(async_fallback):
+            return pool.FallbackAsyncAdaptedQueuePool
+        else:
+            return pool.AsyncAdaptedQueuePool
 
     def is_disconnect(self, e, connection, cursor):
         if connection:
index a754ebe586ea9adfde38cd580b2fe27ba01c81bf..0086fe310b75fc295f5e97b9c1583c2c11809088 100644 (file)
@@ -210,6 +210,8 @@ class DefaultDialect(interfaces.Dialect):
 
     """
 
+    is_async = False
+
     CACHE_HIT = CACHE_HIT
     CACHE_MISS = CACHE_MISS
     CACHING_DISABLED = CACHING_DISABLED
index 353f34333c94172c889798b38206e2d442e47573..7254ec0f7f01745477a97a7ecb1872d1b5751eb8 100644 (file)
@@ -29,6 +29,7 @@ from .dbapi_proxy import clear_managers
 from .dbapi_proxy import manage
 from .impl import AssertionPool
 from .impl import AsyncAdaptedQueuePool
+from .impl import FallbackAsyncAdaptedQueuePool
 from .impl import NullPool
 from .impl import QueuePool
 from .impl import SingletonThreadPool
@@ -46,6 +47,7 @@ __all__ = [
     "NullPool",
     "QueuePool",
     "AsyncAdaptedQueuePool",
+    "FallbackAsyncAdaptedQueuePool",
     "SingletonThreadPool",
     "StaticPool",
 ]
index 38afbc7a1a37a6090e17831439bd1670c76aaae6..312b1b732b2175fe88271a34f8983211b2b9e6c3 100644 (file)
@@ -226,6 +226,10 @@ class AsyncAdaptedQueuePool(QueuePool):
     _queue_class = sqla_queue.AsyncAdaptedQueue
 
 
+class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool):
+    _queue_class = sqla_queue.FallbackAsyncAdaptedQueue
+
+
 class NullPool(Pool):
 
     """A Pool which does not pool connections.
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py
new file mode 100644 (file)
index 0000000..52386d3
--- /dev/null
@@ -0,0 +1,124 @@
+# testing/asyncio.py
+# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+
+# functions and wrappers to run tests, fixtures, provisioning and
+# setup/teardown in an asyncio event loop, conditionally based on the
+# current DB driver being used for a test.
+
+# note that SQLAlchemy's asyncio integration also supports a method
+# of running individual asyncio functions inside of separate event loops
+# using "async_fallback" mode; however running whole functions in the event
+# loop is a more accurate test for how SQLAlchemy's asyncio features
+# would run in the real world.
+
+
+from functools import wraps
+import inspect
+
+from . import config
+from ..util.concurrency import _util_async_run
+
+# may be set to False if the
+# --disable-asyncio flag is passed to the test runner.
+ENABLE_ASYNCIO = True
+
+
+def _assume_async(fn, *args, **kwargs):
+    """Run a function in an asyncio loop unconditionally.
+
+    This function is used for provisioning features like
+    testing a database connection for server info.
+
+    Note that for blocking IO database drivers, this means they block the
+    event loop.
+
+    """
+
+    if not ENABLE_ASYNCIO:
+        return fn(*args, **kwargs)
+
+    return _util_async_run(fn, *args, **kwargs)
+
+
+def _maybe_async_provisioning(fn, *args, **kwargs):
+    """Run a function in an asyncio loop if any current drivers might need it.
+
+    This function is used for provisioning features that take
+    place outside of a specific database driver being selected, so if the
+    current driver that happens to be used for the provisioning operation
+    is an async driver, it will run in asyncio and not fail.
+
+    Note that for blocking IO database drivers, this means they block the
+    event loop.
+
+    """
+    if not ENABLE_ASYNCIO:
+
+        return fn(*args, **kwargs)
+
+    if config.any_async:
+        return _util_async_run(fn, *args, **kwargs)
+    else:
+        return fn(*args, **kwargs)
+
+
+def _maybe_async(fn, *args, **kwargs):
+    """Run a function in an asyncio loop if the current selected driver is
+    async.
+
+    This function is used for test setup/teardown and tests themselves
+    where the current DB driver is known.
+
+
+    """
+    if not ENABLE_ASYNCIO:
+
+        return fn(*args, **kwargs)
+
+    is_async = config._current.is_async
+
+    if is_async:
+        return _util_async_run(fn, *args, **kwargs)
+    else:
+        return fn(*args, **kwargs)
+
+
+def _maybe_async_wrapper(fn):
+    """Apply the _maybe_async function to an existing function and return
+    as a wrapped callable, supporting generator functions as well.
+
+    This is currently used for pytest fixtures that support generator use.
+
+    """
+
+    if inspect.isgeneratorfunction(fn):
+        _stop = object()
+
+        def call_next(gen):
+            try:
+                return next(gen)
+                # can't raise StopIteration in an awaitable.
+            except StopIteration:
+                return _stop
+
+        @wraps(fn)
+        def wrap_fixture(*args, **kwargs):
+            gen = fn(*args, **kwargs)
+            while True:
+                value = _maybe_async(call_next, gen)
+                if value is _stop:
+                    break
+                yield value
+
+    else:
+
+        @wraps(fn)
+        def wrap_fixture(*args, **kwargs):
+            return _maybe_async(fn, *args, **kwargs)
+
+    return wrap_fixture
index 0b8027b84404bd6da6f5ac10baa73b1161b453b0..270ac4c2c09729e1ccc2b9fe0d6672f4442971c0 100644 (file)
@@ -7,6 +7,8 @@
 
 import collections
 
+from .. import util
+
 requirements = None
 db = None
 db_url = None
@@ -14,6 +16,7 @@ db_opts = None
 file_config = None
 test_schema = None
 test_schema_2 = None
+any_async = False
 _current = None
 ident = "main"
 
@@ -104,6 +107,10 @@ class Config(object):
         self.test_schema = "test_schema"
         self.test_schema_2 = "test_schema_2"
 
+        self.is_async = db.dialect.is_async and not util.asbool(
+            db.url.query.get("async_fallback", False)
+        )
+
     _stack = collections.deque()
     _configs = set()
 
@@ -121,7 +128,15 @@ class Config(object):
         If there are no configs set up yet, this config also
         gets set as the "_current".
         """
+        global any_async
+
         cfg = Config(db, db_opts, options, file_config)
+
+        # if any backends include an async driver, then ensure
+        # all setup/teardown and tests are wrapped in the maybe_async()
+        # decorator that will set up a greenlet context for async drivers.
+        any_async = any_async or cfg.is_async
+
         cls._configs.add(cfg)
         return cfg
 
index bb137cb328671302cf5a951387f48899c4c3422f..d0a1bc0d0d1ade7eb26ce93b7c370554dbf6ecb5 100644 (file)
@@ -46,7 +46,7 @@ class ConnectionKiller(object):
             fn()
         except Exception as e:
             warnings.warn(
-                "testing_reaper couldn't " "rollback/close connection: %s" % e
+                "testing_reaper couldn't rollback/close connection: %s" % e
             )
 
     def rollback_all(self):
@@ -199,9 +199,7 @@ class ReconnectFixture(object):
         try:
             fn()
         except Exception as e:
-            warnings.warn(
-                "ReconnectFixture couldn't " "close connection: %s" % e
-            )
+            warnings.warn("ReconnectFixture couldn't close connection: %s" % e)
 
     def shutdown(self, stop=False):
         # TODO: this doesn't cover all cases
index 0ede25176a4c7bf439ec488b28378e36e61fd40e..a52fdd1967783e54165ab29fc88ee4af8f96a4a4 100644 (file)
@@ -48,6 +48,11 @@ class TestBase(object):
     # skipped.
     __skip_if__ = None
 
+    # If this class should be wrapped in asyncio compatibility functions
+    # when using an async engine. This should be set to False only for tests
+    # that use the asyncio features of sqlalchemy directly
+    __asyncio_wrap__ = True
+
     def assert_(self, val, msg=None):
         assert val, msg
 
@@ -90,6 +95,12 @@ class TestBase(object):
     #       engines.drop_all_tables(metadata, config.db)
 
 
+class AsyncTestBase(TestBase):
+    """Mixin marking a test as using its own explicit asyncio patterns."""
+
+    __asyncio_wrap__ = False
+
+
 class FutureEngineMixin(object):
     @classmethod
     def setup_class(cls):
index 5e41f2cdfc5b29ae9581a85c93aed461f1a67538..8b6a7d68ab3973e2114bd0724fec4f34a92528bd 100644 (file)
@@ -63,21 +63,21 @@ def setup_options(make_option):
     make_option(
         "--log-info",
         action="callback",
-        type="string",
+        type=str,
         callback=_log,
         help="turn on info logging for <LOG> (multiple OK)",
     )
     make_option(
         "--log-debug",
         action="callback",
-        type="string",
+        type=str,
         callback=_log,
         help="turn on debug logging for <LOG> (multiple OK)",
     )
     make_option(
         "--db",
         action="append",
-        type="string",
+        type=str,
         dest="db",
         help="Use prefab database uri. Multiple OK, "
         "first one is run by default.",
@@ -91,7 +91,7 @@ def setup_options(make_option):
     make_option(
         "--dburi",
         action="append",
-        type="string",
+        type=str,
         dest="dburi",
         help="Database uri.  Multiple OK, " "first one is run by default.",
     )
@@ -110,6 +110,11 @@ def setup_options(make_option):
         dest="dropfirst",
         help="Drop all tables in the target database first",
     )
+    make_option(
+        "--disable-asyncio",
+        action="store_true",
+        help="disable test / fixtures / provisoning running in asyncio",
+    )
     make_option(
         "--backend-only",
         action="store_true",
@@ -130,20 +135,20 @@ def setup_options(make_option):
     )
     make_option(
         "--profile-sort",
-        type="string",
+        type=str,
         default="cumulative",
         dest="profilesort",
         help="Type of sort for profiling standard output",
     )
     make_option(
         "--profile-dump",
-        type="string",
+        type=str,
         dest="profiledump",
         help="Filename where a single profile run will be dumped",
     )
     make_option(
         "--postgresql-templatedb",
-        type="string",
+        type=str,
         help="name of template database to use for PostgreSQL "
         "CREATE DATABASE (defaults to current database)",
     )
@@ -156,7 +161,7 @@ def setup_options(make_option):
     )
     make_option(
         "--write-idents",
-        type="string",
+        type=str,
         dest="write_idents",
         help="write out generated follower idents to <file>, "
         "when -n<num> is used",
@@ -172,7 +177,7 @@ def setup_options(make_option):
     make_option(
         "--requirements",
         action="callback",
-        type="string",
+        type=str,
         callback=_requirements_opt,
         help="requirements class for testing, overrides setup.cfg",
     )
@@ -188,14 +193,14 @@ def setup_options(make_option):
         "--include-tag",
         action="callback",
         callback=_include_tag,
-        type="string",
+        type=str,
         help="Include tests with tag <tag>",
     )
     make_option(
         "--exclude-tag",
         action="callback",
         callback=_exclude_tag,
-        type="string",
+        type=str,
         help="Exclude tests with tag <tag>",
     )
     make_option(
@@ -374,11 +379,19 @@ def _init_symbols(options, file_config):
     config._fixture_functions = _fixture_fn_class()
 
 
+@post
+def _set_disable_asyncio(opt, file_config):
+    if opt.disable_asyncio:
+        from sqlalchemy.testing import asyncio
+
+        asyncio.ENABLE_ASYNCIO = False
+
+
 @post
 def _engine_uri(options, file_config):
 
-    from sqlalchemy.testing import config
     from sqlalchemy import testing
+    from sqlalchemy.testing import config
     from sqlalchemy.testing import provision
 
     if options.dburi:
index 644ea6dc20d0a3c4fecf33e9a6773914213b254d..6be64aa6106b30f4b03339707a8dc6e696b961b8 100644 (file)
@@ -25,11 +25,6 @@ else:
     if typing.TYPE_CHECKING:
         from typing import Sequence
 
-try:
-    import asyncio
-except ImportError:
-    pass
-
 try:
     import xdist  # noqa
 
@@ -126,11 +121,15 @@ def collect_types_fixture():
 
 
 def pytest_sessionstart(session):
-    plugin_base.post_begin()
+    from sqlalchemy.testing import asyncio
+
+    asyncio._assume_async(plugin_base.post_begin)
 
 
 def pytest_sessionfinish(session):
-    plugin_base.final_process_cleanup()
+    from sqlalchemy.testing import asyncio
+
+    asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)
 
     if session.config.option.dump_pyannotate:
         from pyannotate_runtime import collect_types
@@ -162,23 +161,31 @@ if has_xdist:
     import uuid
 
     def pytest_configure_node(node):
+        from sqlalchemy.testing import provision
+        from sqlalchemy.testing import asyncio
+
         # the master for each node fills workerinput dictionary
         # which pytest-xdist will transfer to the subprocess
 
         plugin_base.memoize_important_follower_config(node.workerinput)
 
         node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
-        from sqlalchemy.testing import provision
 
-        provision.create_follower_db(node.workerinput["follower_ident"])
+        asyncio._maybe_async_provisioning(
+            provision.create_follower_db, node.workerinput["follower_ident"]
+        )
 
     def pytest_testnodedown(node, error):
         from sqlalchemy.testing import provision
+        from sqlalchemy.testing import asyncio
 
-        provision.drop_follower_db(node.workerinput["follower_ident"])
+        asyncio._maybe_async_provisioning(
+            provision.drop_follower_db, node.workerinput["follower_ident"]
+        )
 
 
 def pytest_collection_modifyitems(session, config, items):
+
     # look for all those classes that specify __backend__ and
     # expand them out into per-database test cases.
 
@@ -189,6 +196,8 @@ def pytest_collection_modifyitems(session, config, items):
     # it's to suit the rather odd use case here which is that we are adding
     # new classes to a module on the fly.
 
+    from sqlalchemy.testing import asyncio
+
     rebuilt_items = collections.defaultdict(
         lambda: collections.defaultdict(list)
     )
@@ -201,20 +210,26 @@ def pytest_collection_modifyitems(session, config, items):
     ]
 
     test_classes = set(item.parent for item in items)
-    for test_class in test_classes:
-        for sub_cls in plugin_base.generate_sub_tests(
-            test_class.cls, test_class.parent.module
-        ):
-            if sub_cls is not test_class.cls:
-                per_cls_dict = rebuilt_items[test_class.cls]
 
-                # support pytest 5.4.0 and above pytest.Class.from_parent
-                ctor = getattr(pytest.Class, "from_parent", pytest.Class)
-                for inst in ctor(
-                    name=sub_cls.__name__, parent=test_class.parent.parent
-                ).collect():
-                    for t in inst.collect():
-                        per_cls_dict[t.name].append(t)
+    def setup_test_classes():
+        for test_class in test_classes:
+            for sub_cls in plugin_base.generate_sub_tests(
+                test_class.cls, test_class.parent.module
+            ):
+                if sub_cls is not test_class.cls:
+                    per_cls_dict = rebuilt_items[test_class.cls]
+
+                    # support pytest 5.4.0 and above pytest.Class.from_parent
+                    ctor = getattr(pytest.Class, "from_parent", pytest.Class)
+                    for inst in ctor(
+                        name=sub_cls.__name__, parent=test_class.parent.parent
+                    ).collect():
+                        for t in inst.collect():
+                            per_cls_dict[t.name].append(t)
+
+    # class requirements will sometimes need to access the DB to check
+    # capabilities, so need to do this for async
+    asyncio._maybe_async_provisioning(setup_test_classes)
 
     newitems = []
     for item in items:
@@ -238,6 +253,10 @@ def pytest_collection_modifyitems(session, config, items):
 def pytest_pycollect_makeitem(collector, name, obj):
 
     if inspect.isclass(obj) and plugin_base.want_class(name, obj):
+        from sqlalchemy.testing import config
+
+        if config.any_async and getattr(obj, "__asyncio_wrap__", True):
+            obj = _apply_maybe_async(obj)
 
         ctor = getattr(pytest.Class, "from_parent", pytest.Class)
 
@@ -258,6 +277,38 @@ def pytest_pycollect_makeitem(collector, name, obj):
         return []
 
 
+def _apply_maybe_async(obj, recurse=True):
+    from sqlalchemy.testing import asyncio
+
+    setup_names = {"setup", "setup_class", "teardown", "teardown_class"}
+    for name, value in vars(obj).items():
+        if (
+            (callable(value) or isinstance(value, classmethod))
+            and not getattr(value, "_maybe_async_applied", False)
+            and (name.startswith("test_") or name in setup_names)
+        ):
+            is_classmethod = False
+            if isinstance(value, classmethod):
+                value = value.__func__
+                is_classmethod = True
+
+            @_pytest_fn_decorator
+            def make_async(fn, *args, **kwargs):
+                return asyncio._maybe_async(fn, *args, **kwargs)
+
+            do_async = make_async(value)
+            if is_classmethod:
+                do_async = classmethod(do_async)
+            do_async._maybe_async_applied = True
+
+            setattr(obj, name, do_async)
+    if recurse:
+        for cls in obj.mro()[1:]:
+            if cls != object:
+                _apply_maybe_async(cls, False)
+    return obj
+
+
 _current_class = None
 
 
@@ -297,6 +348,8 @@ def _parametrize_cls(module, cls):
 
 
 def pytest_runtest_setup(item):
+    from sqlalchemy.testing import asyncio
+
     # here we seem to get called only based on what we collected
     # in pytest_collection_modifyitems.   So to do class-based stuff
     # we have to tear that out.
@@ -307,7 +360,7 @@ def pytest_runtest_setup(item):
 
     # ... so we're doing a little dance here to figure it out...
     if _current_class is None:
-        class_setup(item.parent.parent)
+        asyncio._maybe_async(class_setup, item.parent.parent)
         _current_class = item.parent.parent
 
         # this is needed for the class-level, to ensure that the
@@ -315,20 +368,22 @@ def pytest_runtest_setup(item):
         # class-level teardown...
         def finalize():
             global _current_class
-            class_teardown(item.parent.parent)
+            asyncio._maybe_async(class_teardown, item.parent.parent)
             _current_class = None
 
         item.parent.parent.addfinalizer(finalize)
 
-    test_setup(item)
+    asyncio._maybe_async(test_setup, item)
 
 
 def pytest_runtest_teardown(item):
+    from sqlalchemy.testing import asyncio
+
     # ...but this works better as the hook here rather than
     # using a finalizer, as the finalizer seems to get in the way
     # of the test reporting failures correctly (you get a bunch of
     # pytest assertion stuff instead)
-    test_teardown(item)
+    asyncio._maybe_async(test_teardown, item)
 
 
 def test_setup(item):
@@ -342,7 +397,9 @@ def test_teardown(item):
 
 
 def class_setup(item):
-    plugin_base.start_test_class(item.cls)
+    from sqlalchemy.testing import asyncio
+
+    asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls)
 
 
 def class_teardown(item):
@@ -372,17 +429,19 @@ def _pytest_fn_decorator(target):
         if add_positional_parameters:
             spec.args.extend(add_positional_parameters)
 
-        metadata = dict(target="target", fn="__fn", name=fn.__name__)
+        metadata = dict(
+            __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
+        )
         metadata.update(format_argspec_plus(spec, grouped=False))
         code = (
             """\
 def %(name)s(%(args)s):
-    return %(target)s(%(fn)s, %(apply_kw)s)
+    return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
 """
             % metadata
         )
         decorated = _exec_code_in_env(
-            code, {"target": target, "__fn": fn}, fn.__name__
+            code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
         )
         if not add_positional_parameters:
             decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
@@ -554,14 +613,49 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
         return pytest.param(*parameters[1:], id=ident)
 
     def fixture(self, *arg, **kw):
-        return pytest.fixture(*arg, **kw)
+        from sqlalchemy.testing import config
+        from sqlalchemy.testing import asyncio
+
+        # wrapping pytest.fixture function.  determine if
+        # decorator was called as @fixture or @fixture().
+        if len(arg) > 0 and callable(arg[0]):
+            # was called as @fixture(), we have the function to wrap.
+            fn = arg[0]
+            arg = arg[1:]
+        else:
+            # was called as @fixture, don't have the function yet.
+            fn = None
+
+        # create a pytest.fixture marker.  because the fn is not being
+        # passed, this is always a pytest.FixtureFunctionMarker()
+        # object (or whatever pytest is calling it when you read this)
+        # that is waiting for a function.
+        fixture = pytest.fixture(*arg, **kw)
+
+        # now apply wrappers to the function, including fixture itself
+
+        def wrap(fn):
+            if config.any_async:
+                fn = asyncio._maybe_async_wrapper(fn)
+            # other wrappers may be added here
+
+            # now apply FixtureFunctionMarker
+            fn = fixture(fn)
+            return fn
+
+        if fn:
+            return wrap(fn)
+        else:
+            return wrap
 
     def get_current_test_name(self):
         return os.environ.get("PYTEST_CURRENT_TEST")
 
     def async_test(self, fn):
+        from sqlalchemy.testing import asyncio
+
         @_pytest_fn_decorator
         def decorate(fn, *args, **kwargs):
-            asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs))
+            asyncio._assume_async(fn, *args, **kwargs)
 
         return decorate(fn)
index c4f489a69df07fa8269f744fe6f0b13f1e30e2b2..fb3d77dc40daf8694cfcc8e516259003e41a4b12 100644 (file)
@@ -94,11 +94,11 @@ def generate_db_urls(db_urls, extra_drivers):
         --dburi postgresql://db2  \
         --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
 
-    Noting that the default postgresql driver is psycopg2.  the output
+    Noting that the default postgresql driver is psycopg2,  the output
     would be::
 
         postgresql+psycopg2://db1
-        postgresql+asyncpg://db1?async_fallback=true
+        postgresql+asyncpg://db1
         postgresql+psycopg2://db2
         postgresql+psycopg2://db3
 
@@ -108,6 +108,12 @@ def generate_db_urls(db_urls, extra_drivers):
     for a driver that is both coming from --dburi as well as --dbdrivers,
     we want to keep it in that dburi.
 
+    Driver specific query options can be specified by added them to the
+    driver name. For example, to enable the async fallback option for
+    asyncpg::
+
+        --dburi postgresql://db1  \
+        --dbdriver=asyncpg?async_fallback=true
 
     """
     urls = set()
index c6626b9e087f9f17c6b9a801c9e461409f43dd85..bbaf5034f85953dd5c8505dfaca6d822c0d8ef1b 100644 (file)
@@ -11,13 +11,24 @@ import random
 import sys
 import types
 
+from . import config
 from . import mock
+from .. import inspect
+from ..schema import Column
+from ..schema import DropConstraint
+from ..schema import DropTable
+from ..schema import ForeignKeyConstraint
+from ..schema import MetaData
+from ..schema import Table
+from ..sql import schema
+from ..sql.sqltypes import Integer
 from ..util import decorator
 from ..util import defaultdict
 from ..util import has_refcount_gc
 from ..util import inspect_getfullargspec
 from ..util import py2k
 
+
 if not has_refcount_gc:
 
     def non_refcount_gc_collect(*args):
@@ -198,9 +209,9 @@ def fail(msg):
 def provide_metadata(fn, *args, **kw):
     """Provide bound MetaData for a single test, dropping afterwards."""
 
-    from . import config
+    # import cycle that only occurs with py2k's import resolver
+    # in py3k this can be moved top level.
     from . import engines
-    from sqlalchemy import schema
 
     metadata = schema.MetaData(config.db)
     self = args[0]
@@ -243,8 +254,6 @@ def flag_combinations(*combinations):
 
     """
 
-    from . import config
-
     keys = set()
 
     for d in combinations:
@@ -264,8 +273,6 @@ def flag_combinations(*combinations):
 
 
 def lambda_combinations(lambda_arg_sets, **kw):
-    from . import config
-
     args = inspect_getfullargspec(lambda_arg_sets)
 
     arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]])
@@ -302,11 +309,8 @@ def resolve_lambda(__fn, **kw):
 def metadata_fixture(ddl="function"):
     """Provide MetaData for a pytest fixture."""
 
-    from . import config
-
     def decorate(fn):
         def run_ddl(self):
-            from sqlalchemy import schema
 
             metadata = self.metadata = schema.MetaData()
             try:
@@ -328,8 +332,6 @@ def force_drop_names(*names):
     isolating for foreign key cycles
 
     """
-    from . import config
-    from sqlalchemy import inspect
 
     @decorator
     def go(fn, *args, **kw):
@@ -358,14 +360,6 @@ class adict(dict):
 
 
 def drop_all_tables(engine, inspector, schema=None, include_names=None):
-    from sqlalchemy import (
-        Column,
-        Table,
-        Integer,
-        MetaData,
-        ForeignKeyConstraint,
-    )
-    from sqlalchemy.schema import DropTable, DropConstraint
 
     if include_names is not None:
         include_names = set(include_names)
index 8ad3be5439b78c16b5c3838357775e54226c4c30..6042e4395a6eed1964e59f486f0ee3a70c116121 100644 (file)
@@ -64,7 +64,6 @@ def await_fallback(awaitable: Coroutine) -> Any:
     :param awaitable: The coroutine to call.
 
     """
-
     # this is called in the context greenlet while running fn
     current = greenlet.getcurrent()
     if not isinstance(current, _AsyncIoGreenlet):
@@ -135,3 +134,15 @@ class AsyncAdaptedLock:
 
     def __exit__(self, *arg, **kw):
         self.mutex.release()
+
+
+def _util_async_run(fn, *args, **kwargs):
+    """for test suite/ util only"""
+
+    loop = asyncio.get_event_loop()
+    if not loop.is_running():
+        return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
+    else:
+        # allow for a wrapped test function to call another
+        assert isinstance(greenlet.getcurrent(), _AsyncIoGreenlet)
+        return fn(*args, **kwargs)
index f78c0971c786dfb56a04ffd722ef5527e3d530e9..7b4ff6ba40c2aae8070ebc9f647821b908c7f450 100644 (file)
@@ -13,6 +13,7 @@ if compat.py3k:
         from ._concurrency_py3k import await_fallback
         from ._concurrency_py3k import greenlet_spawn
         from ._concurrency_py3k import AsyncAdaptedLock
+        from ._concurrency_py3k import _util_async_run  # noqa F401
         from ._concurrency_py3k import asyncio  # noqa F401
 
 if not have_greenlet:
@@ -38,3 +39,6 @@ if not have_greenlet:
 
     def AsyncAdaptedLock(*args, **kw):  # noqa F81
         _not_implemented()
+
+    def _util_async_run(fn, *arg, **kw):  # noqa F81
+        return fn(*arg, **kw)
index 3687dc8dc31f331cb9819a59872335c36e94708e..a92d6b862e364bad8dc2cf5056629065cbcfe4d0 100644 (file)
@@ -25,6 +25,7 @@ from . import compat
 from .compat import threading
 from .concurrency import asyncio
 from .concurrency import await_fallback
+from .concurrency import await_only
 
 
 __all__ = ["Empty", "Full", "Queue"]
@@ -202,7 +203,7 @@ class Queue:
 
 
 class AsyncAdaptedQueue:
-    await_ = staticmethod(await_fallback)
+    await_ = staticmethod(await_only)
 
     def __init__(self, maxsize=0, use_lifo=False):
         if use_lifo:
@@ -265,3 +266,7 @@ class AsyncAdaptedQueue:
                 Empty(),
                 replace_context=err,
             )
+
+
+class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue):
+    await_ = staticmethod(await_fallback)
index 46fe781044233673ebf34c514c1b41229355cdae..e151769daca530ae167a10ca1b3752a51260c174 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,6 +37,7 @@ packages = find:
 python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*
 package_dir =
     =lib
+# TODO remove greenlet from the default requires?
 install_requires =
     importlib-metadata;python_version<"3.8"
     greenlet != 0.4.17;python_version>="3"
@@ -120,12 +121,14 @@ default = sqlite:///:memory:
 sqlite = sqlite:///:memory:
 sqlite_file = sqlite:///querytest.db
 postgresql = postgresql://scott:tiger@127.0.0.1:5432/test
-asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
+asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test
+asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
 pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
 postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test
 mysql = mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
 pymysql = mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
+aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+aiomysql_fallback = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
 mariadb = mariadb://scott:tiger@127.0.0.1:3306/test
 mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
 mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008
index 5e388c0b7d29efa3ab234d293cfbcf39694657c4..2284c1326c86c52a5f58477f14b8aa2ddb775802 100644 (file)
@@ -1568,7 +1568,7 @@ class CycleTest(_fixtures.FixtureTest):
 
         go()
 
-    @testing.fails
+    @testing.fails()
     def test_the_counter(self):
         @assert_cycles()
         def go():
index cf1067667d6a680afb5ec9a60947b0873bf1795d..2cc2075bcd6b9e78f01477281b8231b1ee129dc6 100644 (file)
@@ -26,7 +26,7 @@ def go(*fns):
     return sum(await_only(fn()) for fn in fns)
 
 
-class TestAsyncioCompat(fixtures.TestBase):
+class TestAsyncioCompat(fixtures.AsyncTestBase):
     @async_test
     async def test_ok(self):
 
index ae7a65a3af8a03b8065a50af4a85b92c325ec063..38da60a434056edb9373a47090f883a016cfafab 100644 (file)
@@ -2734,14 +2734,14 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
 
     def test_where_equal(self):
         self._test_clause(
-            self.col == self._data_str,
+            self.col == self._data_str(),
             "data_table.range = %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_where_not_equal(self):
         self._test_clause(
-            self.col != self._data_str,
+            self.col != self._data_str(),
             "data_table.range <> %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
@@ -2760,94 +2760,94 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
 
     def test_where_less_than(self):
         self._test_clause(
-            self.col < self._data_str,
+            self.col < self._data_str(),
             "data_table.range < %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_where_greater_than(self):
         self._test_clause(
-            self.col > self._data_str,
+            self.col > self._data_str(),
             "data_table.range > %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_where_less_than_or_equal(self):
         self._test_clause(
-            self.col <= self._data_str,
+            self.col <= self._data_str(),
             "data_table.range <= %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_where_greater_than_or_equal(self):
         self._test_clause(
-            self.col >= self._data_str,
+            self.col >= self._data_str(),
             "data_table.range >= %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_contains(self):
         self._test_clause(
-            self.col.contains(self._data_str),
+            self.col.contains(self._data_str()),
             "data_table.range @> %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_contained_by(self):
         self._test_clause(
-            self.col.contained_by(self._data_str),
+            self.col.contained_by(self._data_str()),
             "data_table.range <@ %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_overlaps(self):
         self._test_clause(
-            self.col.overlaps(self._data_str),
+            self.col.overlaps(self._data_str()),
             "data_table.range && %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_strictly_left_of(self):
         self._test_clause(
-            self.col << self._data_str,
+            self.col << self._data_str(),
             "data_table.range << %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
         self._test_clause(
-            self.col.strictly_left_of(self._data_str),
+            self.col.strictly_left_of(self._data_str()),
             "data_table.range << %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_strictly_right_of(self):
         self._test_clause(
-            self.col >> self._data_str,
+            self.col >> self._data_str(),
             "data_table.range >> %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
         self._test_clause(
-            self.col.strictly_right_of(self._data_str),
+            self.col.strictly_right_of(self._data_str()),
             "data_table.range >> %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_not_extend_right_of(self):
         self._test_clause(
-            self.col.not_extend_right_of(self._data_str),
+            self.col.not_extend_right_of(self._data_str()),
             "data_table.range &< %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_not_extend_left_of(self):
         self._test_clause(
-            self.col.not_extend_left_of(self._data_str),
+            self.col.not_extend_left_of(self._data_str()),
             "data_table.range &> %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
 
     def test_adjacent_to(self):
         self._test_clause(
-            self.col.adjacent_to(self._data_str),
+            self.col.adjacent_to(self._data_str()),
             "data_table.range -|- %(range_1)s",
             sqltypes.BOOLEANTYPE,
         )
@@ -2920,14 +2920,14 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
 
     def test_insert_text(self, connection):
         connection.execute(
-            self.tables.data_table.insert(), {"range": self._data_str}
+            self.tables.data_table.insert(), {"range": self._data_str()}
         )
         self._assert_data(connection)
 
     def test_union_result(self, connection):
         # insert
         connection.execute(
-            self.tables.data_table.insert(), {"range": self._data_str}
+            self.tables.data_table.insert(), {"range": self._data_str()}
         )
         # select
         range_ = self.tables.data_table.c.range
@@ -2937,7 +2937,7 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
     def test_intersection_result(self, connection):
         # insert
         connection.execute(
-            self.tables.data_table.insert(), {"range": self._data_str}
+            self.tables.data_table.insert(), {"range": self._data_str()}
         )
         # select
         range_ = self.tables.data_table.c.range
@@ -2947,7 +2947,7 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
     def test_difference_result(self, connection):
         # insert
         connection.execute(
-            self.tables.data_table.insert(), {"range": self._data_str}
+            self.tables.data_table.insert(), {"range": self._data_str()}
         )
         # select
         range_ = self.tables.data_table.c.range
@@ -2959,7 +2959,9 @@ class _Int4RangeTests(object):
 
     _col_type = INT4RANGE
     _col_str = "INT4RANGE"
-    _data_str = "[1,2)"
+
+    def _data_str(self):
+        return "[1,2)"
 
     def _data_obj(self):
         return self.extras().NumericRange(1, 2)
@@ -2969,7 +2971,9 @@ class _Int8RangeTests(object):
 
     _col_type = INT8RANGE
     _col_str = "INT8RANGE"
-    _data_str = "[9223372036854775806,9223372036854775807)"
+
+    def _data_str(self):
+        return "[9223372036854775806,9223372036854775807)"
 
     def _data_obj(self):
         return self.extras().NumericRange(
@@ -2981,7 +2985,9 @@ class _NumRangeTests(object):
 
     _col_type = NUMRANGE
     _col_str = "NUMRANGE"
-    _data_str = "[1.0,2.0)"
+
+    def _data_str(self):
+        return "[1.0,2.0)"
 
     def _data_obj(self):
         return self.extras().NumericRange(
@@ -2993,7 +2999,9 @@ class _DateRangeTests(object):
 
     _col_type = DATERANGE
     _col_str = "DATERANGE"
-    _data_str = "[2013-03-23,2013-03-24)"
+
+    def _data_str(self):
+        return "[2013-03-23,2013-03-24)"
 
     def _data_obj(self):
         return self.extras().DateRange(
@@ -3005,7 +3013,9 @@ class _DateTimeRangeTests(object):
 
     _col_type = TSRANGE
     _col_str = "TSRANGE"
-    _data_str = "[2013-03-23 14:30,2013-03-23 23:30)"
+
+    def _data_str(self):
+        return "[2013-03-23 14:30,2013-03-23 23:30)"
 
     def _data_obj(self):
         return self.extras().DateTimeRange(
@@ -3031,7 +3041,6 @@ class _DateTimeTZRangeTests(object):
                 self._tstzs = (lower, upper)
         return self._tstzs
 
-    @property
     def _data_str(self):
         return "[%s,%s)" % self.tstzs()
 
@@ -3178,7 +3187,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
     __only_on__ = ("postgresql >= 9.3",)
     __backend__ = True
 
-    test_type = JSON
+    data_type = JSON
 
     @classmethod
     def define_tables(cls, metadata):
@@ -3187,8 +3196,8 @@ class JSONRoundTripTest(fixtures.TablesTest):
             metadata,
             Column("id", Integer, primary_key=True),
             Column("name", String(30), nullable=False),
-            Column("data", cls.test_type),
-            Column("nulldata", cls.test_type(none_as_null=True)),
+            Column("data", cls.data_type),
+            Column("nulldata", cls.data_type(none_as_null=True)),
         )
 
     def _fixture_data(self, engine):
@@ -3255,7 +3264,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
     def test_reflect(self):
         insp = inspect(testing.db)
         cols = insp.get_columns("data_table")
-        assert isinstance(cols[2]["type"], self.test_type)
+        assert isinstance(cols[2]["type"], self.data_type)
 
     def test_insert(self, connection):
         self._test_insert(connection)
@@ -3286,7 +3295,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
             options=dict(json_serializer=dumps, json_deserializer=loads)
         )
 
-        s = select(cast({"key": "value", "x": "q"}, self.test_type))
+        s = select(cast({"key": "value", "x": "q"}, self.data_type))
         with engine.begin() as conn:
             eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
 
@@ -3366,7 +3375,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
         s = select(
             cast(
                 {"key": "value", "key2": {"k1": "v1", "k2": "v2"}},
-                self.test_type,
+                self.data_type,
             )
         )
         eq_(
@@ -3381,7 +3390,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
                     util.u("réveillé"): util.u("réveillé"),
                     "data": {"k1": util.u("drôle")},
                 },
-                self.test_type,
+                self.data_type,
             )
         )
         eq_(
@@ -3483,7 +3492,7 @@ class JSONBTest(JSONTest):
 class JSONBRoundTripTest(JSONRoundTripTest):
     __requires__ = ("postgresql_jsonb",)
 
-    test_type = JSONB
+    data_type = JSONB
 
     @testing.requires.postgresql_utf8_server_encoding
     def test_unicode_round_trip(self, connection):
index c0d5a93d50d8a38c6bbbfb6750b1f81a5e1a41f9..7989026d13d9e0e9be1106ad9715f682d5e70a57 100644 (file)
@@ -3626,6 +3626,13 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
         "pg8000 parses the SQL itself before passing on "
         "to PG, doesn't parse this",
     )
+    @testing.fails_on(
+        "postgresql+asyncpg",
+        "Asyncpg uses preprated statements that are not compatible with how "
+        "sqlalchemy passes the query. Fails with "
+        'ERROR:  column "users.name" must appear in the GROUP BY clause'
+        " or be used in an aggregate function",
+    )
     @testing.fails_on("firebird", "unknown")
     def test_values_with_boolean_selects(self):
         """Tests a values clause that works with select boolean
index cb2f4840f6d558624dfd17b57055b0d949fab49e..5413d217e01572d8986bff5b835932494a12f940 100644 (file)
@@ -1330,7 +1330,10 @@ class DefaultRequirements(SuiteRequirements):
         """dialect makes use of await_() to invoke operations on the DBAPI."""
 
         return only_on(
-            ["postgresql+asyncpg", "mysql+aiomysql", "mariadb+aiomysql"]
+            LambdaPredicate(
+                lambda config: config.db.dialect.is_async,
+                "Async dialect required",
+            )
         )
 
     @property
diff --git a/tox.ini b/tox.ini
index 355e8446d0aed8a3ab248e9d578d67c54e6a704a..8c0e0b749897cab9926e03bf0028110bb4b5df39 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -74,11 +74,11 @@ setenv=
     sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
 
     postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
-    py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg?async_fallback=true --dbdriver pg8000}
+    py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000}
 
     mysql: MYSQL={env:TOX_MYSQL:--db mysql}
     mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql}
-    py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql?async_fallback=true}
+    py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql}
 
 
     mssql: MSSQL={env:TOX_MSSQL:--db mssql}