From: Federico Caselli Date: Sun, 16 Aug 2020 08:48:57 +0000 (+0200) Subject: Support testing of async drivers without fallback mode X-Git-Tag: rel_1_4_0b2~74 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=2581655c545a0cf705e0347e81cd092896d3207c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support testing of async drivers without fallback mode Change-Id: I4940d184a4dc790782fcddfb9873af3cca844398 --- diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index f560ece332..81864603dc 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -34,6 +34,7 @@ handling. from .pymysql import MySQLDialect_pymysql from ... import pool +from ... import util from ...util.concurrency import await_fallback from ...util.concurrency import await_only @@ -226,7 +227,7 @@ class AsyncAdapt_aiomysql_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - if async_fallback: + if util.asbool(async_fallback): return AsyncAdaptFallback_aiomysql_connection( self, await_fallback(self.aiomysql.connect(*arg, **kw)), @@ -244,6 +245,8 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): supports_server_side_cursors = True _sscursor = AsyncAdapt_aiomysql_ss_cursor + is_async = True + @classmethod def dbapi(cls): return AsyncAdapt_aiomysql_dbapi( @@ -251,8 +254,14 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): ) @classmethod - def get_pool_class(self, url): - return pool.AsyncAdaptedQueuePool + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool def create_connect_args(self, url): args, kw = super(MySQLDialect_aiomysql, self).create_connect_args(url) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 889293eabd..6b7e782664 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -587,7 +587,7 @@ class AsyncAdapt_asyncpg_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - if async_fallback: + if util.asbool(async_fallback): return AsyncAdaptFallback_asyncpg_connection( self, await_fallback(self.asyncpg.connect(*arg, **kw)), @@ -729,6 +729,7 @@ class PGDialect_asyncpg(PGDialect): REGCLASS: AsyncpgREGCLASS, }, ) + is_async = True @util.memoized_property def _dbapi_version(self): @@ -792,8 +793,14 @@ class PGDialect_asyncpg(PGDialect): return ([], opts) @classmethod - def get_pool_class(self, url): - return pool.AsyncAdaptedQueuePool + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool def is_disconnect(self, e, connection, cursor): if connection: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a754ebe586..0086fe310b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -210,6 +210,8 @@ class DefaultDialect(interfaces.Dialect): """ + is_async = False + CACHE_HIT = CACHE_HIT CACHE_MISS = CACHE_MISS CACHING_DISABLED = CACHING_DISABLED diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py index 353f34333c..7254ec0f7f 100644 --- a/lib/sqlalchemy/pool/__init__.py +++ b/lib/sqlalchemy/pool/__init__.py @@ -29,6 +29,7 @@ from .dbapi_proxy import clear_managers from .dbapi_proxy import manage from .impl import AssertionPool from .impl import AsyncAdaptedQueuePool +from .impl import FallbackAsyncAdaptedQueuePool from .impl import NullPool from .impl import QueuePool from .impl import SingletonThreadPool @@ -46,6 +47,7 @@ __all__ = [ "NullPool", "QueuePool", "AsyncAdaptedQueuePool", + "FallbackAsyncAdaptedQueuePool", "SingletonThreadPool", "StaticPool", ] diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index 38afbc7a1a..312b1b732b 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -226,6 +226,10 @@ class AsyncAdaptedQueuePool(QueuePool): _queue_class = sqla_queue.AsyncAdaptedQueue +class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool): + _queue_class = sqla_queue.FallbackAsyncAdaptedQueue + + class NullPool(Pool): """A Pool which does not pool connections. diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py new file mode 100644 index 0000000000..52386d33e1 --- /dev/null +++ b/lib/sqlalchemy/testing/asyncio.py @@ -0,0 +1,124 @@ +# testing/asyncio.py +# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + + +# functions and wrappers to run tests, fixtures, provisioning and +# setup/teardown in an asyncio event loop, conditionally based on the +# current DB driver being used for a test. + +# note that SQLAlchemy's asyncio integration also supports a method +# of running individual asyncio functions inside of separate event loops +# using "async_fallback" mode; however running whole functions in the event +# loop is a more accurate test for how SQLAlchemy's asyncio features +# would run in the real world. + + +from functools import wraps +import inspect + +from . import config +from ..util.concurrency import _util_async_run + +# may be set to False if the +# --disable-asyncio flag is passed to the test runner. +ENABLE_ASYNCIO = True + + +def _assume_async(fn, *args, **kwargs): + """Run a function in an asyncio loop unconditionally. + + This function is used for provisioning features like + testing a database connection for server info. + + Note that for blocking IO database drivers, this means they block the + event loop. + + """ + + if not ENABLE_ASYNCIO: + return fn(*args, **kwargs) + + return _util_async_run(fn, *args, **kwargs) + + +def _maybe_async_provisioning(fn, *args, **kwargs): + """Run a function in an asyncio loop if any current drivers might need it. + + This function is used for provisioning features that take + place outside of a specific database driver being selected, so if the + current driver that happens to be used for the provisioning operation + is an async driver, it will run in asyncio and not fail. + + Note that for blocking IO database drivers, this means they block the + event loop. + + """ + if not ENABLE_ASYNCIO: + + return fn(*args, **kwargs) + + if config.any_async: + return _util_async_run(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + +def _maybe_async(fn, *args, **kwargs): + """Run a function in an asyncio loop if the current selected driver is + async. + + This function is used for test setup/teardown and tests themselves + where the current DB driver is known. + + + """ + if not ENABLE_ASYNCIO: + + return fn(*args, **kwargs) + + is_async = config._current.is_async + + if is_async: + return _util_async_run(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + +def _maybe_async_wrapper(fn): + """Apply the _maybe_async function to an existing function and return + as a wrapped callable, supporting generator functions as well. + + This is currently used for pytest fixtures that support generator use. + + """ + + if inspect.isgeneratorfunction(fn): + _stop = object() + + def call_next(gen): + try: + return next(gen) + # can't raise StopIteration in an awaitable. + except StopIteration: + return _stop + + @wraps(fn) + def wrap_fixture(*args, **kwargs): + gen = fn(*args, **kwargs) + while True: + value = _maybe_async(call_next, gen) + if value is _stop: + break + yield value + + else: + + @wraps(fn) + def wrap_fixture(*args, **kwargs): + return _maybe_async(fn, *args, **kwargs) + + return wrap_fixture diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 0b8027b844..270ac4c2c0 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -7,6 +7,8 @@ import collections +from .. import util + requirements = None db = None db_url = None @@ -14,6 +16,7 @@ db_opts = None file_config = None test_schema = None test_schema_2 = None +any_async = False _current = None ident = "main" @@ -104,6 +107,10 @@ class Config(object): self.test_schema = "test_schema" self.test_schema_2 = "test_schema_2" + self.is_async = db.dialect.is_async and not util.asbool( + db.url.query.get("async_fallback", False) + ) + _stack = collections.deque() _configs = set() @@ -121,7 +128,15 @@ class Config(object): If there are no configs set up yet, this config also gets set as the "_current". """ + global any_async + cfg = Config(db, db_opts, options, file_config) + + # if any backends include an async driver, then ensure + # all setup/teardown and tests are wrapped in the maybe_async() + # decorator that will set up a greenlet context for async drivers. + any_async = any_async or cfg.is_async + cls._configs.add(cfg) return cfg diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index bb137cb328..d0a1bc0d0d 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -46,7 +46,7 @@ class ConnectionKiller(object): fn() except Exception as e: warnings.warn( - "testing_reaper couldn't " "rollback/close connection: %s" % e + "testing_reaper couldn't rollback/close connection: %s" % e ) def rollback_all(self): @@ -199,9 +199,7 @@ class ReconnectFixture(object): try: fn() except Exception as e: - warnings.warn( - "ReconnectFixture couldn't " "close connection: %s" % e - ) + warnings.warn("ReconnectFixture couldn't close connection: %s" % e) def shutdown(self, stop=False): # TODO: this doesn't cover all cases diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 0ede25176a..a52fdd1967 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -48,6 +48,11 @@ class TestBase(object): # skipped. __skip_if__ = None + # If this class should be wrapped in asyncio compatibility functions + # when using an async engine. This should be set to False only for tests + # that use the asyncio features of sqlalchemy directly + __asyncio_wrap__ = True + def assert_(self, val, msg=None): assert val, msg @@ -90,6 +95,12 @@ class TestBase(object): # engines.drop_all_tables(metadata, config.db) +class AsyncTestBase(TestBase): + """Mixin marking a test as using its own explicit asyncio patterns.""" + + __asyncio_wrap__ = False + + class FutureEngineMixin(object): @classmethod def setup_class(cls): diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 5e41f2cdfc..8b6a7d68ab 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -63,21 +63,21 @@ def setup_options(make_option): make_option( "--log-info", action="callback", - type="string", + type=str, callback=_log, help="turn on info logging for (multiple OK)", ) make_option( "--log-debug", action="callback", - type="string", + type=str, callback=_log, help="turn on debug logging for (multiple OK)", ) make_option( "--db", action="append", - type="string", + type=str, dest="db", help="Use prefab database uri. Multiple OK, " "first one is run by default.", @@ -91,7 +91,7 @@ def setup_options(make_option): make_option( "--dburi", action="append", - type="string", + type=str, dest="dburi", help="Database uri. Multiple OK, " "first one is run by default.", ) @@ -110,6 +110,11 @@ def setup_options(make_option): dest="dropfirst", help="Drop all tables in the target database first", ) + make_option( + "--disable-asyncio", + action="store_true", + help="disable test / fixtures / provisoning running in asyncio", + ) make_option( "--backend-only", action="store_true", @@ -130,20 +135,20 @@ def setup_options(make_option): ) make_option( "--profile-sort", - type="string", + type=str, default="cumulative", dest="profilesort", help="Type of sort for profiling standard output", ) make_option( "--profile-dump", - type="string", + type=str, dest="profiledump", help="Filename where a single profile run will be dumped", ) make_option( "--postgresql-templatedb", - type="string", + type=str, help="name of template database to use for PostgreSQL " "CREATE DATABASE (defaults to current database)", ) @@ -156,7 +161,7 @@ def setup_options(make_option): ) make_option( "--write-idents", - type="string", + type=str, dest="write_idents", help="write out generated follower idents to , " "when -n is used", @@ -172,7 +177,7 @@ def setup_options(make_option): make_option( "--requirements", action="callback", - type="string", + type=str, callback=_requirements_opt, help="requirements class for testing, overrides setup.cfg", ) @@ -188,14 +193,14 @@ def setup_options(make_option): "--include-tag", action="callback", callback=_include_tag, - type="string", + type=str, help="Include tests with tag ", ) make_option( "--exclude-tag", action="callback", callback=_exclude_tag, - type="string", + type=str, help="Exclude tests with tag ", ) make_option( @@ -374,11 +379,19 @@ def _init_symbols(options, file_config): config._fixture_functions = _fixture_fn_class() +@post +def _set_disable_asyncio(opt, file_config): + if opt.disable_asyncio: + from sqlalchemy.testing import asyncio + + asyncio.ENABLE_ASYNCIO = False + + @post def _engine_uri(options, file_config): - from sqlalchemy.testing import config from sqlalchemy import testing + from sqlalchemy.testing import config from sqlalchemy.testing import provision if options.dburi: diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 644ea6dc20..6be64aa610 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -25,11 +25,6 @@ else: if typing.TYPE_CHECKING: from typing import Sequence -try: - import asyncio -except ImportError: - pass - try: import xdist # noqa @@ -126,11 +121,15 @@ def collect_types_fixture(): def pytest_sessionstart(session): - plugin_base.post_begin() + from sqlalchemy.testing import asyncio + + asyncio._assume_async(plugin_base.post_begin) def pytest_sessionfinish(session): - plugin_base.final_process_cleanup() + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup) if session.config.option.dump_pyannotate: from pyannotate_runtime import collect_types @@ -162,23 +161,31 @@ if has_xdist: import uuid def pytest_configure_node(node): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + # the master for each node fills workerinput dictionary # which pytest-xdist will transfer to the subprocess plugin_base.memoize_important_follower_config(node.workerinput) node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] - from sqlalchemy.testing import provision - provision.create_follower_db(node.workerinput["follower_ident"]) + asyncio._maybe_async_provisioning( + provision.create_follower_db, node.workerinput["follower_ident"] + ) def pytest_testnodedown(node, error): from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio - provision.drop_follower_db(node.workerinput["follower_ident"]) + asyncio._maybe_async_provisioning( + provision.drop_follower_db, node.workerinput["follower_ident"] + ) def pytest_collection_modifyitems(session, config, items): + # look for all those classes that specify __backend__ and # expand them out into per-database test cases. @@ -189,6 +196,8 @@ def pytest_collection_modifyitems(session, config, items): # it's to suit the rather odd use case here which is that we are adding # new classes to a module on the fly. + from sqlalchemy.testing import asyncio + rebuilt_items = collections.defaultdict( lambda: collections.defaultdict(list) ) @@ -201,20 +210,26 @@ def pytest_collection_modifyitems(session, config, items): ] test_classes = set(item.parent for item in items) - for test_class in test_classes: - for sub_cls in plugin_base.generate_sub_tests( - test_class.cls, test_class.parent.module - ): - if sub_cls is not test_class.cls: - per_cls_dict = rebuilt_items[test_class.cls] - # support pytest 5.4.0 and above pytest.Class.from_parent - ctor = getattr(pytest.Class, "from_parent", pytest.Class) - for inst in ctor( - name=sub_cls.__name__, parent=test_class.parent.parent - ).collect(): - for t in inst.collect(): - per_cls_dict[t.name].append(t) + def setup_test_classes(): + for test_class in test_classes: + for sub_cls in plugin_base.generate_sub_tests( + test_class.cls, test_class.parent.module + ): + if sub_cls is not test_class.cls: + per_cls_dict = rebuilt_items[test_class.cls] + + # support pytest 5.4.0 and above pytest.Class.from_parent + ctor = getattr(pytest.Class, "from_parent", pytest.Class) + for inst in ctor( + name=sub_cls.__name__, parent=test_class.parent.parent + ).collect(): + for t in inst.collect(): + per_cls_dict[t.name].append(t) + + # class requirements will sometimes need to access the DB to check + # capabilities, so need to do this for async + asyncio._maybe_async_provisioning(setup_test_classes) newitems = [] for item in items: @@ -238,6 +253,10 @@ def pytest_collection_modifyitems(session, config, items): def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(name, obj): + from sqlalchemy.testing import config + + if config.any_async and getattr(obj, "__asyncio_wrap__", True): + obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) @@ -258,6 +277,38 @@ def pytest_pycollect_makeitem(collector, name, obj): return [] +def _apply_maybe_async(obj, recurse=True): + from sqlalchemy.testing import asyncio + + setup_names = {"setup", "setup_class", "teardown", "teardown_class"} + for name, value in vars(obj).items(): + if ( + (callable(value) or isinstance(value, classmethod)) + and not getattr(value, "_maybe_async_applied", False) + and (name.startswith("test_") or name in setup_names) + ): + is_classmethod = False + if isinstance(value, classmethod): + value = value.__func__ + is_classmethod = True + + @_pytest_fn_decorator + def make_async(fn, *args, **kwargs): + return asyncio._maybe_async(fn, *args, **kwargs) + + do_async = make_async(value) + if is_classmethod: + do_async = classmethod(do_async) + do_async._maybe_async_applied = True + + setattr(obj, name, do_async) + if recurse: + for cls in obj.mro()[1:]: + if cls != object: + _apply_maybe_async(cls, False) + return obj + + _current_class = None @@ -297,6 +348,8 @@ def _parametrize_cls(module, cls): def pytest_runtest_setup(item): + from sqlalchemy.testing import asyncio + # here we seem to get called only based on what we collected # in pytest_collection_modifyitems. So to do class-based stuff # we have to tear that out. @@ -307,7 +360,7 @@ def pytest_runtest_setup(item): # ... so we're doing a little dance here to figure it out... if _current_class is None: - class_setup(item.parent.parent) + asyncio._maybe_async(class_setup, item.parent.parent) _current_class = item.parent.parent # this is needed for the class-level, to ensure that the @@ -315,20 +368,22 @@ def pytest_runtest_setup(item): # class-level teardown... def finalize(): global _current_class - class_teardown(item.parent.parent) + asyncio._maybe_async(class_teardown, item.parent.parent) _current_class = None item.parent.parent.addfinalizer(finalize) - test_setup(item) + asyncio._maybe_async(test_setup, item) def pytest_runtest_teardown(item): + from sqlalchemy.testing import asyncio + # ...but this works better as the hook here rather than # using a finalizer, as the finalizer seems to get in the way # of the test reporting failures correctly (you get a bunch of # pytest assertion stuff instead) - test_teardown(item) + asyncio._maybe_async(test_teardown, item) def test_setup(item): @@ -342,7 +397,9 @@ def test_teardown(item): def class_setup(item): - plugin_base.start_test_class(item.cls) + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls) def class_teardown(item): @@ -372,17 +429,19 @@ def _pytest_fn_decorator(target): if add_positional_parameters: spec.args.extend(add_positional_parameters) - metadata = dict(target="target", fn="__fn", name=fn.__name__) + metadata = dict( + __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__ + ) metadata.update(format_argspec_plus(spec, grouped=False)) code = ( """\ def %(name)s(%(args)s): - return %(target)s(%(fn)s, %(apply_kw)s) + return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s) """ % metadata ) decorated = _exec_code_in_env( - code, {"target": target, "__fn": fn}, fn.__name__ + code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__ ) if not add_positional_parameters: decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ @@ -554,14 +613,49 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): return pytest.param(*parameters[1:], id=ident) def fixture(self, *arg, **kw): - return pytest.fixture(*arg, **kw) + from sqlalchemy.testing import config + from sqlalchemy.testing import asyncio + + # wrapping pytest.fixture function. determine if + # decorator was called as @fixture or @fixture(). + if len(arg) > 0 and callable(arg[0]): + # was called as @fixture(), we have the function to wrap. + fn = arg[0] + arg = arg[1:] + else: + # was called as @fixture, don't have the function yet. + fn = None + + # create a pytest.fixture marker. because the fn is not being + # passed, this is always a pytest.FixtureFunctionMarker() + # object (or whatever pytest is calling it when you read this) + # that is waiting for a function. + fixture = pytest.fixture(*arg, **kw) + + # now apply wrappers to the function, including fixture itself + + def wrap(fn): + if config.any_async: + fn = asyncio._maybe_async_wrapper(fn) + # other wrappers may be added here + + # now apply FixtureFunctionMarker + fn = fixture(fn) + return fn + + if fn: + return wrap(fn) + else: + return wrap def get_current_test_name(self): return os.environ.get("PYTEST_CURRENT_TEST") def async_test(self, fn): + from sqlalchemy.testing import asyncio + @_pytest_fn_decorator def decorate(fn, *args, **kwargs): - asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs)) + asyncio._assume_async(fn, *args, **kwargs) return decorate(fn) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index c4f489a69d..fb3d77dc40 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -94,11 +94,11 @@ def generate_db_urls(db_urls, extra_drivers): --dburi postgresql://db2 \ --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true - Noting that the default postgresql driver is psycopg2. the output + Noting that the default postgresql driver is psycopg2, the output would be:: postgresql+psycopg2://db1 - postgresql+asyncpg://db1?async_fallback=true + postgresql+asyncpg://db1 postgresql+psycopg2://db2 postgresql+psycopg2://db3 @@ -108,6 +108,12 @@ def generate_db_urls(db_urls, extra_drivers): for a driver that is both coming from --dburi as well as --dbdrivers, we want to keep it in that dburi. + Driver specific query options can be specified by added them to the + driver name. For example, to enable the async fallback option for + asyncpg:: + + --dburi postgresql://db1 \ + --dbdriver=asyncpg?async_fallback=true """ urls = set() diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index c6626b9e08..bbaf5034f8 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -11,13 +11,24 @@ import random import sys import types +from . import config from . import mock +from .. import inspect +from ..schema import Column +from ..schema import DropConstraint +from ..schema import DropTable +from ..schema import ForeignKeyConstraint +from ..schema import MetaData +from ..schema import Table +from ..sql import schema +from ..sql.sqltypes import Integer from ..util import decorator from ..util import defaultdict from ..util import has_refcount_gc from ..util import inspect_getfullargspec from ..util import py2k + if not has_refcount_gc: def non_refcount_gc_collect(*args): @@ -198,9 +209,9 @@ def fail(msg): def provide_metadata(fn, *args, **kw): """Provide bound MetaData for a single test, dropping afterwards.""" - from . import config + # import cycle that only occurs with py2k's import resolver + # in py3k this can be moved top level. from . import engines - from sqlalchemy import schema metadata = schema.MetaData(config.db) self = args[0] @@ -243,8 +254,6 @@ def flag_combinations(*combinations): """ - from . import config - keys = set() for d in combinations: @@ -264,8 +273,6 @@ def flag_combinations(*combinations): def lambda_combinations(lambda_arg_sets, **kw): - from . import config - args = inspect_getfullargspec(lambda_arg_sets) arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]]) @@ -302,11 +309,8 @@ def resolve_lambda(__fn, **kw): def metadata_fixture(ddl="function"): """Provide MetaData for a pytest fixture.""" - from . import config - def decorate(fn): def run_ddl(self): - from sqlalchemy import schema metadata = self.metadata = schema.MetaData() try: @@ -328,8 +332,6 @@ def force_drop_names(*names): isolating for foreign key cycles """ - from . import config - from sqlalchemy import inspect @decorator def go(fn, *args, **kw): @@ -358,14 +360,6 @@ class adict(dict): def drop_all_tables(engine, inspector, schema=None, include_names=None): - from sqlalchemy import ( - Column, - Table, - Integer, - MetaData, - ForeignKeyConstraint, - ) - from sqlalchemy.schema import DropTable, DropConstraint if include_names is not None: include_names = set(include_names) diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 8ad3be5439..6042e4395a 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -64,7 +64,6 @@ def await_fallback(awaitable: Coroutine) -> Any: :param awaitable: The coroutine to call. """ - # this is called in the context greenlet while running fn current = greenlet.getcurrent() if not isinstance(current, _AsyncIoGreenlet): @@ -135,3 +134,15 @@ class AsyncAdaptedLock: def __exit__(self, *arg, **kw): self.mutex.release() + + +def _util_async_run(fn, *args, **kwargs): + """for test suite/ util only""" + + loop = asyncio.get_event_loop() + if not loop.is_running(): + return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs)) + else: + # allow for a wrapped test function to call another + assert isinstance(greenlet.getcurrent(), _AsyncIoGreenlet) + return fn(*args, **kwargs) diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index f78c0971c7..7b4ff6ba40 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -13,6 +13,7 @@ if compat.py3k: from ._concurrency_py3k import await_fallback from ._concurrency_py3k import greenlet_spawn from ._concurrency_py3k import AsyncAdaptedLock + from ._concurrency_py3k import _util_async_run # noqa F401 from ._concurrency_py3k import asyncio # noqa F401 if not have_greenlet: @@ -38,3 +39,6 @@ if not have_greenlet: def AsyncAdaptedLock(*args, **kw): # noqa F81 _not_implemented() + + def _util_async_run(fn, *arg, **kw): # noqa F81 + return fn(*arg, **kw) diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 3687dc8dc3..a92d6b862e 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -25,6 +25,7 @@ from . import compat from .compat import threading from .concurrency import asyncio from .concurrency import await_fallback +from .concurrency import await_only __all__ = ["Empty", "Full", "Queue"] @@ -202,7 +203,7 @@ class Queue: class AsyncAdaptedQueue: - await_ = staticmethod(await_fallback) + await_ = staticmethod(await_only) def __init__(self, maxsize=0, use_lifo=False): if use_lifo: @@ -265,3 +266,7 @@ class AsyncAdaptedQueue: Empty(), replace_context=err, ) + + +class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue): + await_ = staticmethod(await_fallback) diff --git a/setup.cfg b/setup.cfg index 46fe781044..e151769dac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,7 @@ packages = find: python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.* package_dir = =lib +# TODO remove greenlet from the default requires? install_requires = importlib-metadata;python_version<"3.8" greenlet != 0.4.17;python_version>="3" @@ -120,12 +121,14 @@ default = sqlite:///:memory: sqlite = sqlite:///:memory: sqlite_file = sqlite:///querytest.db postgresql = postgresql://scott:tiger@127.0.0.1:5432/test -asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true +asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test +asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test mysql = mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 pymysql = mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 -aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true +aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 +aiomysql_fallback = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true mariadb = mariadb://scott:tiger@127.0.0.1:3306/test mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008 diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 5e388c0b7d..2284c1326c 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1568,7 +1568,7 @@ class CycleTest(_fixtures.FixtureTest): go() - @testing.fails + @testing.fails() def test_the_counter(self): @assert_cycles() def go(): diff --git a/test/base/test_concurrency_py3k.py b/test/base/test_concurrency_py3k.py index cf1067667d..2cc2075bcd 100644 --- a/test/base/test_concurrency_py3k.py +++ b/test/base/test_concurrency_py3k.py @@ -26,7 +26,7 @@ def go(*fns): return sum(await_only(fn()) for fn in fns) -class TestAsyncioCompat(fixtures.TestBase): +class TestAsyncioCompat(fixtures.AsyncTestBase): @async_test async def test_ok(self): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index ae7a65a3af..38da60a434 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -2734,14 +2734,14 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): def test_where_equal(self): self._test_clause( - self.col == self._data_str, + self.col == self._data_str(), "data_table.range = %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_where_not_equal(self): self._test_clause( - self.col != self._data_str, + self.col != self._data_str(), "data_table.range <> %(range_1)s", sqltypes.BOOLEANTYPE, ) @@ -2760,94 +2760,94 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): def test_where_less_than(self): self._test_clause( - self.col < self._data_str, + self.col < self._data_str(), "data_table.range < %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_where_greater_than(self): self._test_clause( - self.col > self._data_str, + self.col > self._data_str(), "data_table.range > %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_where_less_than_or_equal(self): self._test_clause( - self.col <= self._data_str, + self.col <= self._data_str(), "data_table.range <= %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_where_greater_than_or_equal(self): self._test_clause( - self.col >= self._data_str, + self.col >= self._data_str(), "data_table.range >= %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_contains(self): self._test_clause( - self.col.contains(self._data_str), + self.col.contains(self._data_str()), "data_table.range @> %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_contained_by(self): self._test_clause( - self.col.contained_by(self._data_str), + self.col.contained_by(self._data_str()), "data_table.range <@ %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_overlaps(self): self._test_clause( - self.col.overlaps(self._data_str), + self.col.overlaps(self._data_str()), "data_table.range && %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_strictly_left_of(self): self._test_clause( - self.col << self._data_str, + self.col << self._data_str(), "data_table.range << %(range_1)s", sqltypes.BOOLEANTYPE, ) self._test_clause( - self.col.strictly_left_of(self._data_str), + self.col.strictly_left_of(self._data_str()), "data_table.range << %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_strictly_right_of(self): self._test_clause( - self.col >> self._data_str, + self.col >> self._data_str(), "data_table.range >> %(range_1)s", sqltypes.BOOLEANTYPE, ) self._test_clause( - self.col.strictly_right_of(self._data_str), + self.col.strictly_right_of(self._data_str()), "data_table.range >> %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_not_extend_right_of(self): self._test_clause( - self.col.not_extend_right_of(self._data_str), + self.col.not_extend_right_of(self._data_str()), "data_table.range &< %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_not_extend_left_of(self): self._test_clause( - self.col.not_extend_left_of(self._data_str), + self.col.not_extend_left_of(self._data_str()), "data_table.range &> %(range_1)s", sqltypes.BOOLEANTYPE, ) def test_adjacent_to(self): self._test_clause( - self.col.adjacent_to(self._data_str), + self.col.adjacent_to(self._data_str()), "data_table.range -|- %(range_1)s", sqltypes.BOOLEANTYPE, ) @@ -2920,14 +2920,14 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_insert_text(self, connection): connection.execute( - self.tables.data_table.insert(), {"range": self._data_str} + self.tables.data_table.insert(), {"range": self._data_str()} ) self._assert_data(connection) def test_union_result(self, connection): # insert connection.execute( - self.tables.data_table.insert(), {"range": self._data_str} + self.tables.data_table.insert(), {"range": self._data_str()} ) # select range_ = self.tables.data_table.c.range @@ -2937,7 +2937,7 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_intersection_result(self, connection): # insert connection.execute( - self.tables.data_table.insert(), {"range": self._data_str} + self.tables.data_table.insert(), {"range": self._data_str()} ) # select range_ = self.tables.data_table.c.range @@ -2947,7 +2947,7 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_difference_result(self, connection): # insert connection.execute( - self.tables.data_table.insert(), {"range": self._data_str} + self.tables.data_table.insert(), {"range": self._data_str()} ) # select range_ = self.tables.data_table.c.range @@ -2959,7 +2959,9 @@ class _Int4RangeTests(object): _col_type = INT4RANGE _col_str = "INT4RANGE" - _data_str = "[1,2)" + + def _data_str(self): + return "[1,2)" def _data_obj(self): return self.extras().NumericRange(1, 2) @@ -2969,7 +2971,9 @@ class _Int8RangeTests(object): _col_type = INT8RANGE _col_str = "INT8RANGE" - _data_str = "[9223372036854775806,9223372036854775807)" + + def _data_str(self): + return "[9223372036854775806,9223372036854775807)" def _data_obj(self): return self.extras().NumericRange( @@ -2981,7 +2985,9 @@ class _NumRangeTests(object): _col_type = NUMRANGE _col_str = "NUMRANGE" - _data_str = "[1.0,2.0)" + + def _data_str(self): + return "[1.0,2.0)" def _data_obj(self): return self.extras().NumericRange( @@ -2993,7 +2999,9 @@ class _DateRangeTests(object): _col_type = DATERANGE _col_str = "DATERANGE" - _data_str = "[2013-03-23,2013-03-24)" + + def _data_str(self): + return "[2013-03-23,2013-03-24)" def _data_obj(self): return self.extras().DateRange( @@ -3005,7 +3013,9 @@ class _DateTimeRangeTests(object): _col_type = TSRANGE _col_str = "TSRANGE" - _data_str = "[2013-03-23 14:30,2013-03-23 23:30)" + + def _data_str(self): + return "[2013-03-23 14:30,2013-03-23 23:30)" def _data_obj(self): return self.extras().DateTimeRange( @@ -3031,7 +3041,6 @@ class _DateTimeTZRangeTests(object): self._tstzs = (lower, upper) return self._tstzs - @property def _data_str(self): return "[%s,%s)" % self.tstzs() @@ -3178,7 +3187,7 @@ class JSONRoundTripTest(fixtures.TablesTest): __only_on__ = ("postgresql >= 9.3",) __backend__ = True - test_type = JSON + data_type = JSON @classmethod def define_tables(cls, metadata): @@ -3187,8 +3196,8 @@ class JSONRoundTripTest(fixtures.TablesTest): metadata, Column("id", Integer, primary_key=True), Column("name", String(30), nullable=False), - Column("data", cls.test_type), - Column("nulldata", cls.test_type(none_as_null=True)), + Column("data", cls.data_type), + Column("nulldata", cls.data_type(none_as_null=True)), ) def _fixture_data(self, engine): @@ -3255,7 +3264,7 @@ class JSONRoundTripTest(fixtures.TablesTest): def test_reflect(self): insp = inspect(testing.db) cols = insp.get_columns("data_table") - assert isinstance(cols[2]["type"], self.test_type) + assert isinstance(cols[2]["type"], self.data_type) def test_insert(self, connection): self._test_insert(connection) @@ -3286,7 +3295,7 @@ class JSONRoundTripTest(fixtures.TablesTest): options=dict(json_serializer=dumps, json_deserializer=loads) ) - s = select(cast({"key": "value", "x": "q"}, self.test_type)) + s = select(cast({"key": "value", "x": "q"}, self.data_type)) with engine.begin() as conn: eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"}) @@ -3366,7 +3375,7 @@ class JSONRoundTripTest(fixtures.TablesTest): s = select( cast( {"key": "value", "key2": {"k1": "v1", "k2": "v2"}}, - self.test_type, + self.data_type, ) ) eq_( @@ -3381,7 +3390,7 @@ class JSONRoundTripTest(fixtures.TablesTest): util.u("réveillé"): util.u("réveillé"), "data": {"k1": util.u("drôle")}, }, - self.test_type, + self.data_type, ) ) eq_( @@ -3483,7 +3492,7 @@ class JSONBTest(JSONTest): class JSONBRoundTripTest(JSONRoundTripTest): __requires__ = ("postgresql_jsonb",) - test_type = JSONB + data_type = JSONB @testing.requires.postgresql_utf8_server_encoding def test_unicode_round_trip(self, connection): diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index c0d5a93d50..7989026d13 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -3626,6 +3626,13 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): "pg8000 parses the SQL itself before passing on " "to PG, doesn't parse this", ) + @testing.fails_on( + "postgresql+asyncpg", + "Asyncpg uses preprated statements that are not compatible with how " + "sqlalchemy passes the query. Fails with " + 'ERROR: column "users.name" must appear in the GROUP BY clause' + " or be used in an aggregate function", + ) @testing.fails_on("firebird", "unknown") def test_values_with_boolean_selects(self): """Tests a values clause that works with select boolean diff --git a/test/requirements.py b/test/requirements.py index cb2f4840f6..5413d217e0 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1330,7 +1330,10 @@ class DefaultRequirements(SuiteRequirements): """dialect makes use of await_() to invoke operations on the DBAPI.""" return only_on( - ["postgresql+asyncpg", "mysql+aiomysql", "mariadb+aiomysql"] + LambdaPredicate( + lambda config: config.db.dialect.is_async, + "Async dialect required", + ) ) @property diff --git a/tox.ini b/tox.ini index 355e8446d0..8c0e0b7498 100644 --- a/tox.ini +++ b/tox.ini @@ -74,11 +74,11 @@ setenv= sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file} postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql} - py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg?async_fallback=true --dbdriver pg8000} + py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000} mysql: MYSQL={env:TOX_MYSQL:--db mysql} mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql} - py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql?async_fallback=true} + py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql} mssql: MSSQL={env:TOX_MSSQL:--db mssql}