From: CaselIT Date: Tue, 15 Dec 2020 22:17:20 +0000 (+0100) Subject: Use exclusions from sqlalchemy X-Git-Tag: rel_1_5_0~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=22af3d1cd92dbd485988be3422f2fc1f2ca9ea0b;p=thirdparty%2Fsqlalchemy%2Falembic.git Use exclusions from sqlalchemy Since alembic 1.5 will support only sqlalchemy 1.3+ it can safely use the exclusions from sqlalchemy. Enable test for computed columns with sqlite >= 3.31 Fixes: #759 Change-Id: If629e0c0aa264205fa571e26b3ba8398634c374d --- diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index f009da93..23c0f19b 100644 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -1,4 +1,5 @@ from sqlalchemy.testing import config # noqa +from sqlalchemy.testing import exclusions # noqa from sqlalchemy.testing import emits_warning # noqa from sqlalchemy.testing import engines # noqa from sqlalchemy.testing import mock # noqa @@ -7,7 +8,6 @@ from sqlalchemy.testing import uses_deprecated # noqa from sqlalchemy.testing.config import requirements as requires # noqa from alembic import util # noqa -from . import exclusions # noqa from .assertions import assert_raises # noqa from .assertions import assert_raises_message # noqa from .assertions import emits_python_deprecation_warning # noqa diff --git a/alembic/testing/exclusions.py b/alembic/testing/exclusions.py deleted file mode 100644 index 91f2d5b6..00000000 --- a/alembic/testing/exclusions.py +++ /dev/null @@ -1,484 +0,0 @@ -# testing/exclusions.py -# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - - -import contextlib -import operator -import re - -from sqlalchemy import util as sqla_util -from sqlalchemy.util import decorator - -from . import config -from . import fixture_functions -from .. import util -from ..util.compat import inspect_getargspec - - -def skip_if(predicate, reason=None): - rule = compound() - pred = _as_predicate(predicate, reason) - rule.skips.add(pred) - return rule - - -def fails_if(predicate, reason=None): - rule = compound() - pred = _as_predicate(predicate, reason) - rule.fails.add(pred) - return rule - - -class compound(object): - def __init__(self): - self.fails = set() - self.skips = set() - self.tags = set() - self.combinations = {} - - def __add__(self, other): - return self.add(other) - - def with_combination(self, **kw): - copy = compound() - copy.fails.update(self.fails) - copy.skips.update(self.skips) - copy.tags.update(self.tags) - copy.combinations.update((f, kw) for f in copy.fails) - copy.combinations.update((s, kw) for s in copy.skips) - return copy - - def add(self, *others): - copy = compound() - copy.fails.update(self.fails) - copy.skips.update(self.skips) - copy.tags.update(self.tags) - for other in others: - copy.fails.update(other.fails) - copy.skips.update(other.skips) - copy.tags.update(other.tags) - return copy - - def not_(self): - copy = compound() - copy.fails.update(NotPredicate(fail) for fail in self.fails) - copy.skips.update(NotPredicate(skip) for skip in self.skips) - copy.tags.update(self.tags) - return copy - - @property - def enabled(self): - return self.enabled_for_config(config._current) - - def enabled_for_config(self, config): - for predicate in self.skips.union(self.fails): - if predicate(config): - return False - else: - return True - - def matching_config_reasons(self, config): - return [ - predicate._as_string(config) - for predicate in self.skips.union(self.fails) - if predicate(config) - ] - - def include_test(self, include_tags, exclude_tags): - return bool( - not self.tags.intersection(exclude_tags) - and (not include_tags or self.tags.intersection(include_tags)) - ) - - def _extend(self, other): - self.skips.update(other.skips) - self.fails.update(other.fails) - self.tags.update(other.tags) - self.combinations.update(other.combinations) - - def __call__(self, fn): - if hasattr(fn, "_sa_exclusion_extend"): - fn._sa_exclusion_extend._extend(self) - return fn - - @decorator - def decorate(fn, *args, **kw): - return self._do(config._current, fn, *args, **kw) - - decorated = decorate(fn) - decorated._sa_exclusion_extend = self - return decorated - - @contextlib.contextmanager - def fail_if(self): - all_fails = compound() - all_fails.fails.update(self.skips.union(self.fails)) - - try: - yield - except Exception as ex: - all_fails._expect_failure(config._current, ex, None) - else: - all_fails._expect_success(config._current, None) - - def _check_combinations(self, combination, predicate): - if predicate in self.combinations: - for k, v in combination: - if ( - k in self.combinations[predicate] - and self.combinations[predicate][k] != v - ): - return False - return True - - def _do(self, cfg, fn, *args, **kw): - if len(args) > 1: - insp = inspect_getargspec(fn) - combination = list(zip(insp.args[1:], args[1:])) - else: - combination = None - - for skip in self.skips: - if self._check_combinations(combination, skip) and skip(cfg): - msg = "'%s' : %s" % ( - fixture_functions.get_current_test_name(), - skip._as_string(cfg), - ) - config.skip_test(msg) - - try: - return_value = fn(*args, **kw) - except Exception as ex: - self._expect_failure(cfg, ex, combination, name=fn.__name__) - else: - self._expect_success(cfg, combination, name=fn.__name__) - return return_value - - def _expect_failure(self, config, ex, combination, name="block"): - for fail in self.fails: - if self._check_combinations(combination, fail) and fail(config): - if sqla_util.py2k: - str_ex = unicode(ex).encode( # noqa: F821 - "utf-8", errors="ignore" - ) - else: - str_ex = str(ex) - print( - ( - "%s failed as expected (%s): %s " - % (name, fail._as_string(config), str_ex) - ) - ) - break - else: - util.raise_from_cause(ex) - - def _expect_success(self, config, combination, name="block"): - if not self.fails: - return - - for fail in self.fails: - if self._check_combinations(combination, fail) and fail(config): - raise AssertionError( - "Unexpected success for '%s' (%s)" - % ( - name, - " and ".join( - fail._as_string(config) for fail in self.fails - ), - ) - ) - - -def requires_tag(tagname): - return tags([tagname]) - - -def tags(tagnames): - comp = compound() - comp.tags.update(tagnames) - return comp - - -def only_if(predicate, reason=None): - predicate = _as_predicate(predicate) - return skip_if(NotPredicate(predicate), reason) - - -def succeeds_if(predicate, reason=None): - predicate = _as_predicate(predicate) - return fails_if(NotPredicate(predicate), reason) - - -class Predicate(object): - @classmethod - def as_predicate(cls, predicate, description=None): - if isinstance(predicate, compound): - return cls.as_predicate(predicate.enabled_for_config, description) - elif isinstance(predicate, Predicate): - if description and predicate.description is None: - predicate.description = description - return predicate - elif isinstance(predicate, (list, set)): - return OrPredicate( - [cls.as_predicate(pred) for pred in predicate], description - ) - elif isinstance(predicate, tuple): - return SpecPredicate(*predicate) - elif isinstance(predicate, sqla_util.string_types): - tokens = re.match( - r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate - ) - if not tokens: - raise ValueError( - "Couldn't locate DB name in predicate: %r" % predicate - ) - db = tokens.group(1) - op = tokens.group(2) - spec = ( - tuple(int(d) for d in tokens.group(3).split(".")) - if tokens.group(3) - else None - ) - - return SpecPredicate(db, op, spec, description=description) - elif callable(predicate): - return LambdaPredicate(predicate, description) - else: - assert False, "unknown predicate type: %s" % predicate - - def _format_description(self, config, negate=False): - bool_ = self(config) - if negate: - bool_ = not negate - return self.description % { - "driver": config.db.url.get_driver_name() - if config - else "", - "database": config.db.url.get_backend_name() - if config - else "", - "doesnt_support": "doesn't support" if bool_ else "does support", - "does_support": "does support" if bool_ else "doesn't support", - } - - def _as_string(self, config=None, negate=False): - raise NotImplementedError() - - -class BooleanPredicate(Predicate): - def __init__(self, value, description=None): - self.value = value - self.description = description or "boolean %s" % value - - def __call__(self, config): - return self.value - - def _as_string(self, config, negate=False): - return self._format_description(config, negate=negate) - - -class SpecPredicate(Predicate): - def __init__(self, db, op=None, spec=None, description=None): - self.db = db - self.op = op - self.spec = spec - self.description = description - - _ops = { - "<": operator.lt, - ">": operator.gt, - "==": operator.eq, - "!=": operator.ne, - "<=": operator.le, - ">=": operator.ge, - "in": operator.contains, - "between": lambda val, pair: val >= pair[0] and val <= pair[1], - } - - def __call__(self, config): - engine = config.db - - if "+" in self.db: - dialect, driver = self.db.split("+") - else: - dialect, driver = self.db, None - - if dialect and engine.name != dialect: - return False - if driver is not None and engine.driver != driver: - return False - - if self.op is not None: - assert driver is None, "DBAPI version specs not supported yet" - - version = _server_version(engine) - oper = ( - hasattr(self.op, "__call__") and self.op or self._ops[self.op] - ) - return oper(version, self.spec) - else: - return True - - def _as_string(self, config, negate=False): - if self.description is not None: - return self._format_description(config) - elif self.op is None: - if negate: - return "not %s" % self.db - else: - return "%s" % self.db - else: - if negate: - return "not %s %s %s" % (self.db, self.op, self.spec) - else: - return "%s %s %s" % (self.db, self.op, self.spec) - - -class LambdaPredicate(Predicate): - def __init__(self, lambda_, description=None, args=None, kw=None): - spec = inspect_getargspec(lambda_) - if not spec[0]: - self.lambda_ = lambda db: lambda_() - else: - self.lambda_ = lambda_ - self.args = args or () - self.kw = kw or {} - if description: - self.description = description - elif lambda_.__doc__: - self.description = lambda_.__doc__ - else: - self.description = "custom function" - - def __call__(self, config): - return self.lambda_(config) - - def _as_string(self, config, negate=False): - return self._format_description(config) - - -class NotPredicate(Predicate): - def __init__(self, predicate, description=None): - self.predicate = predicate - self.description = description - - def __call__(self, config): - return not self.predicate(config) - - def _as_string(self, config, negate=False): - if self.description: - return self._format_description(config, not negate) - else: - return self.predicate._as_string(config, not negate) - - -class OrPredicate(Predicate): - def __init__(self, predicates, description=None): - self.predicates = predicates - self.description = description - - def __call__(self, config): - for pred in self.predicates: - if pred(config): - return True - return False - - def _eval_str(self, config, negate=False): - if negate: - conjunction = " and " - else: - conjunction = " or " - return conjunction.join( - p._as_string(config, negate=negate) for p in self.predicates - ) - - def _negation_str(self, config): - if self.description is not None: - return "Not " + self._format_description(config) - else: - return self._eval_str(config, negate=True) - - def _as_string(self, config, negate=False): - if negate: - return self._negation_str(config) - else: - if self.description is not None: - return self._format_description(config) - else: - return self._eval_str(config) - - -_as_predicate = Predicate.as_predicate - - -def _is_excluded(db, op, spec): - return SpecPredicate(db, op, spec)(config._current) - - -def _server_version(engine): - """Return a server_version_info tuple.""" - - # force metadata to be retrieved - conn = engine.connect() - version = getattr(engine.dialect, "server_version_info", None) - if version is None: - version = () - conn.close() - return version - - -def db_spec(*dbs): - return OrPredicate([Predicate.as_predicate(db) for db in dbs]) - - -def open(): # noqa - return skip_if(BooleanPredicate(False, "mark as execute")) - - -def closed(): - return skip_if(BooleanPredicate(True, "marked as skip")) - - -def fails(reason=None): - return fails_if(BooleanPredicate(True, reason or "expected to fail")) - - -@decorator -def future(fn, *arg): - return fails_if(LambdaPredicate(fn), "Future feature") - - -def fails_on(db, reason=None): - return fails_if(db, reason) - - -def fails_on_everything_except(*dbs): - return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs])) - - -def skip(db, reason=None): - return skip_if(db, reason) - - -def only_on(dbs, reason=None): - return only_if( - OrPredicate( - [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)] - ) - ) - - -def exclude(db, op, spec, reason=None): - return skip_if(SpecPredicate(db, op, spec), reason) - - -def against(config, *queries): - assert queries, "no queries sent!" - return OrPredicate([Predicate.as_predicate(query) for query in queries])( - config - ) diff --git a/alembic/testing/plugin/pytestplugin.py b/alembic/testing/plugin/pytestplugin.py index ba3d35bb..6b76a17e 100644 --- a/alembic/testing/plugin/pytestplugin.py +++ b/alembic/testing/plugin/pytestplugin.py @@ -11,6 +11,7 @@ except ImportError: # assume we're a package, use traditional import from . import plugin_base +from functools import update_wrapper import inspect import itertools import operator @@ -22,6 +23,16 @@ import pytest from sqlalchemy.testing.plugin.pytestplugin import * # noqa from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc +py3k = sys.version_info.major >= 3 + +if py3k: + from typing import TYPE_CHECKING +else: + TYPE_CHECKING = False + +if TYPE_CHECKING: + from typing import Sequence + # override selected SQLAlchemy pytest hooks with vendored functionality def pytest_configure(config): @@ -97,6 +108,50 @@ def getargspec(fn): return inspect.getargspec(fn) +def _pytest_fn_decorator(target): + """Port of langhelpers.decorator with pytest-specific tricks.""" + # from sqlalchemy rel_1_3_14 + + from sqlalchemy.util.langhelpers import format_argspec_plus + from sqlalchemy.util.compat import inspect_getfullargspec + + def _exec_code_in_env(code, env, fn_name): + exec(code, env) + return env[fn_name] + + def decorate(fn, add_positional_parameters=()): + + spec = inspect_getfullargspec(fn) + if add_positional_parameters: + spec.args.extend(add_positional_parameters) + + metadata = dict(target="target", fn="__fn", name=fn.__name__) + metadata.update(format_argspec_plus(spec, grouped=False)) + code = ( + """\ +def %(name)s(%(args)s): + return %(target)s(%(fn)s, %(apply_kw)s) +""" + % metadata + ) + decorated = _exec_code_in_env( + code, {"target": target, "__fn": fn}, fn.__name__ + ) + if not add_positional_parameters: + decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ + decorated.__wrapped__ = fn + return update_wrapper(decorated, fn) + else: + # this is the pytest hacky part. don't do a full update wrapper + # because pytest is really being sneaky about finding the args + # for the wrapped function + decorated.__module__ = fn.__module__ + decorated.__name__ = fn.__name__ + return decorated + + return decorate + + class PytestFixtureFunctions(plugin_base.FixtureFunctions): def skip_test_exception(self, *arg, **kw): return pytest.skip.Exception(*arg, **kw) @@ -109,7 +164,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): } def combinations(self, *arg_sets, **kw): - """facade for pytest.mark.paramtrize. + """Facade for pytest.mark.parametrize. Automatically derives argument names from the callable which in our case is always a method on a class with positional arguments. @@ -117,6 +172,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): ids for parameter sets are derived using an optional template. """ + # from sqlalchemy rel_1_3_14 from alembic.testing import exclusions if sys.version_info.major == 3: @@ -128,8 +184,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): argnames = kw.pop("argnames", None) - exclusion_combinations = [] - def _filter_exclusions(args): result = [] gathered_exclusions = [] @@ -139,13 +193,12 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): else: result.append(a) - exclusion_combinations.extend( - [(exclusion, result) for exclusion in gathered_exclusions] - ) - return result + return result, gathered_exclusions id_ = kw.pop("id_", None) + tobuild_pytest_params = [] + has_exclusions = False if id_: _combination_id_fns = self._combination_id_fns @@ -165,53 +218,88 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): for idx, char in enumerate(id_) if char in _combination_id_fns ] - arg_sets = [ - pytest.param( - *_arg_getter(_filter_exclusions(arg))[1:], - id="-".join( - comb_fn(getter(arg)) for getter, comb_fn in fns + + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + parameters = _arg_getter(fn_params)[1:] + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + ( + parameters, + param_exclusions, + "-".join( + comb_fn(getter(arg)) for getter, comb_fn in fns + ), ) ) - for arg in [ - (arg,) if not isinstance(arg, tuple) else arg - for arg in arg_sets - ] - ] + else: - # ensure using pytest.param so that even a 1-arg paramset - # still needs to be a tuple. otherwise paramtrize tries to - # interpret a single arg differently than tuple arg - arg_sets = [ - pytest.param(*_filter_exclusions(arg)) - for arg in [ - (arg,) if not isinstance(arg, tuple) else arg - for arg in arg_sets - ] - ] + + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + (fn_params, param_exclusions, None) + ) + + pytest_params = [] + for parameters, param_exclusions, id_ in tobuild_pytest_params: + if has_exclusions: + parameters += (param_exclusions,) + + param = pytest.param(*parameters, id=id_) + pytest_params.append(param) def decorate(fn): if inspect.isclass(fn): + if has_exclusions: + raise NotImplementedError( + "exclusions not supported for class level combinations" + ) if "_sa_parametrize" not in fn.__dict__: fn._sa_parametrize = [] - fn._sa_parametrize.append((argnames, arg_sets)) + fn._sa_parametrize.append((argnames, pytest_params)) return fn else: if argnames is None: - _argnames = getargspec(fn).args[1:] + _argnames = getargspec(fn).args[1:] # type: Sequence(str) else: - _argnames = argnames - - if exclusion_combinations: - for exclusion, combination in exclusion_combinations: - combination_by_kw = { - argname: val - for argname, val in zip(_argnames, combination) - } - exclusion = exclusion.with_combination( - **combination_by_kw - ) - fn = exclusion(fn) - return pytest.mark.parametrize(_argnames, arg_sets)(fn) + _argnames = re.split( + r", *", argnames + ) # type: Sequence(str) + + if has_exclusions: + _argnames += ["_exclusions"] + + @_pytest_fn_decorator + def check_exclusions(fn, *args, **kw): + _exclusions = args[-1] + if _exclusions: + exlu = exclusions.compound().add(*_exclusions) + fn = exlu(fn) + return fn(*args[0:-1], **kw) + + def process_metadata(spec): + spec.args.append("_exclusions") + + fn = check_exclusions( + fn, add_positional_parameters=("_exclusions",) + ) + + return pytest.mark.parametrize(_argnames, pytest_params)(fn) return decorate diff --git a/tests/requirements.py b/tests/requirements.py index eb6066db..c077462a 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -169,8 +169,8 @@ class DefaultRequirements(SuiteRequirements): def computed_columns(self): # TODO: in theory if these could come from SQLAlchemy dialects # that would be helpful - return self.computed_columns_api + exclusions.only_on( - ["postgresql >= 12", "oracle", "mssql", "mysql >= 5.7", "mariadb"] + return self.computed_columns_api + exclusions.skip_if( + ["postgresql < 12", "sqlite < 3.31", "mysql < 5.7"] ) @property