+++ /dev/null
-# testing/exclusions.py
-# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-
-import contextlib
-import operator
-import re
-
-from sqlalchemy import util as sqla_util
-from sqlalchemy.util import decorator
-
-from . import config
-from . import fixture_functions
-from .. import util
-from ..util.compat import inspect_getargspec
-
-
-def skip_if(predicate, reason=None):
- rule = compound()
- pred = _as_predicate(predicate, reason)
- rule.skips.add(pred)
- return rule
-
-
-def fails_if(predicate, reason=None):
- rule = compound()
- pred = _as_predicate(predicate, reason)
- rule.fails.add(pred)
- return rule
-
-
-class compound(object):
- def __init__(self):
- self.fails = set()
- self.skips = set()
- self.tags = set()
- self.combinations = {}
-
- def __add__(self, other):
- return self.add(other)
-
- def with_combination(self, **kw):
- copy = compound()
- copy.fails.update(self.fails)
- copy.skips.update(self.skips)
- copy.tags.update(self.tags)
- copy.combinations.update((f, kw) for f in copy.fails)
- copy.combinations.update((s, kw) for s in copy.skips)
- return copy
-
- def add(self, *others):
- copy = compound()
- copy.fails.update(self.fails)
- copy.skips.update(self.skips)
- copy.tags.update(self.tags)
- for other in others:
- copy.fails.update(other.fails)
- copy.skips.update(other.skips)
- copy.tags.update(other.tags)
- return copy
-
- def not_(self):
- copy = compound()
- copy.fails.update(NotPredicate(fail) for fail in self.fails)
- copy.skips.update(NotPredicate(skip) for skip in self.skips)
- copy.tags.update(self.tags)
- return copy
-
- @property
- def enabled(self):
- return self.enabled_for_config(config._current)
-
- def enabled_for_config(self, config):
- for predicate in self.skips.union(self.fails):
- if predicate(config):
- return False
- else:
- return True
-
- def matching_config_reasons(self, config):
- return [
- predicate._as_string(config)
- for predicate in self.skips.union(self.fails)
- if predicate(config)
- ]
-
- def include_test(self, include_tags, exclude_tags):
- return bool(
- not self.tags.intersection(exclude_tags)
- and (not include_tags or self.tags.intersection(include_tags))
- )
-
- def _extend(self, other):
- self.skips.update(other.skips)
- self.fails.update(other.fails)
- self.tags.update(other.tags)
- self.combinations.update(other.combinations)
-
- def __call__(self, fn):
- if hasattr(fn, "_sa_exclusion_extend"):
- fn._sa_exclusion_extend._extend(self)
- return fn
-
- @decorator
- def decorate(fn, *args, **kw):
- return self._do(config._current, fn, *args, **kw)
-
- decorated = decorate(fn)
- decorated._sa_exclusion_extend = self
- return decorated
-
- @contextlib.contextmanager
- def fail_if(self):
- all_fails = compound()
- all_fails.fails.update(self.skips.union(self.fails))
-
- try:
- yield
- except Exception as ex:
- all_fails._expect_failure(config._current, ex, None)
- else:
- all_fails._expect_success(config._current, None)
-
- def _check_combinations(self, combination, predicate):
- if predicate in self.combinations:
- for k, v in combination:
- if (
- k in self.combinations[predicate]
- and self.combinations[predicate][k] != v
- ):
- return False
- return True
-
- def _do(self, cfg, fn, *args, **kw):
- if len(args) > 1:
- insp = inspect_getargspec(fn)
- combination = list(zip(insp.args[1:], args[1:]))
- else:
- combination = None
-
- for skip in self.skips:
- if self._check_combinations(combination, skip) and skip(cfg):
- msg = "'%s' : %s" % (
- fixture_functions.get_current_test_name(),
- skip._as_string(cfg),
- )
- config.skip_test(msg)
-
- try:
- return_value = fn(*args, **kw)
- except Exception as ex:
- self._expect_failure(cfg, ex, combination, name=fn.__name__)
- else:
- self._expect_success(cfg, combination, name=fn.__name__)
- return return_value
-
- def _expect_failure(self, config, ex, combination, name="block"):
- for fail in self.fails:
- if self._check_combinations(combination, fail) and fail(config):
- if sqla_util.py2k:
- str_ex = unicode(ex).encode( # noqa: F821
- "utf-8", errors="ignore"
- )
- else:
- str_ex = str(ex)
- print(
- (
- "%s failed as expected (%s): %s "
- % (name, fail._as_string(config), str_ex)
- )
- )
- break
- else:
- util.raise_from_cause(ex)
-
- def _expect_success(self, config, combination, name="block"):
- if not self.fails:
- return
-
- for fail in self.fails:
- if self._check_combinations(combination, fail) and fail(config):
- raise AssertionError(
- "Unexpected success for '%s' (%s)"
- % (
- name,
- " and ".join(
- fail._as_string(config) for fail in self.fails
- ),
- )
- )
-
-
-def requires_tag(tagname):
- return tags([tagname])
-
-
-def tags(tagnames):
- comp = compound()
- comp.tags.update(tagnames)
- return comp
-
-
-def only_if(predicate, reason=None):
- predicate = _as_predicate(predicate)
- return skip_if(NotPredicate(predicate), reason)
-
-
-def succeeds_if(predicate, reason=None):
- predicate = _as_predicate(predicate)
- return fails_if(NotPredicate(predicate), reason)
-
-
-class Predicate(object):
- @classmethod
- def as_predicate(cls, predicate, description=None):
- if isinstance(predicate, compound):
- return cls.as_predicate(predicate.enabled_for_config, description)
- elif isinstance(predicate, Predicate):
- if description and predicate.description is None:
- predicate.description = description
- return predicate
- elif isinstance(predicate, (list, set)):
- return OrPredicate(
- [cls.as_predicate(pred) for pred in predicate], description
- )
- elif isinstance(predicate, tuple):
- return SpecPredicate(*predicate)
- elif isinstance(predicate, sqla_util.string_types):
- tokens = re.match(
- r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
- )
- if not tokens:
- raise ValueError(
- "Couldn't locate DB name in predicate: %r" % predicate
- )
- db = tokens.group(1)
- op = tokens.group(2)
- spec = (
- tuple(int(d) for d in tokens.group(3).split("."))
- if tokens.group(3)
- else None
- )
-
- return SpecPredicate(db, op, spec, description=description)
- elif callable(predicate):
- return LambdaPredicate(predicate, description)
- else:
- assert False, "unknown predicate type: %s" % predicate
-
- def _format_description(self, config, negate=False):
- bool_ = self(config)
- if negate:
- bool_ = not negate
- return self.description % {
- "driver": config.db.url.get_driver_name()
- if config
- else "<no driver>",
- "database": config.db.url.get_backend_name()
- if config
- else "<no database>",
- "doesnt_support": "doesn't support" if bool_ else "does support",
- "does_support": "does support" if bool_ else "doesn't support",
- }
-
- def _as_string(self, config=None, negate=False):
- raise NotImplementedError()
-
-
-class BooleanPredicate(Predicate):
- def __init__(self, value, description=None):
- self.value = value
- self.description = description or "boolean %s" % value
-
- def __call__(self, config):
- return self.value
-
- def _as_string(self, config, negate=False):
- return self._format_description(config, negate=negate)
-
-
-class SpecPredicate(Predicate):
- def __init__(self, db, op=None, spec=None, description=None):
- self.db = db
- self.op = op
- self.spec = spec
- self.description = description
-
- _ops = {
- "<": operator.lt,
- ">": operator.gt,
- "==": operator.eq,
- "!=": operator.ne,
- "<=": operator.le,
- ">=": operator.ge,
- "in": operator.contains,
- "between": lambda val, pair: val >= pair[0] and val <= pair[1],
- }
-
- def __call__(self, config):
- engine = config.db
-
- if "+" in self.db:
- dialect, driver = self.db.split("+")
- else:
- dialect, driver = self.db, None
-
- if dialect and engine.name != dialect:
- return False
- if driver is not None and engine.driver != driver:
- return False
-
- if self.op is not None:
- assert driver is None, "DBAPI version specs not supported yet"
-
- version = _server_version(engine)
- oper = (
- hasattr(self.op, "__call__") and self.op or self._ops[self.op]
- )
- return oper(version, self.spec)
- else:
- return True
-
- def _as_string(self, config, negate=False):
- if self.description is not None:
- return self._format_description(config)
- elif self.op is None:
- if negate:
- return "not %s" % self.db
- else:
- return "%s" % self.db
- else:
- if negate:
- return "not %s %s %s" % (self.db, self.op, self.spec)
- else:
- return "%s %s %s" % (self.db, self.op, self.spec)
-
-
-class LambdaPredicate(Predicate):
- def __init__(self, lambda_, description=None, args=None, kw=None):
- spec = inspect_getargspec(lambda_)
- if not spec[0]:
- self.lambda_ = lambda db: lambda_()
- else:
- self.lambda_ = lambda_
- self.args = args or ()
- self.kw = kw or {}
- if description:
- self.description = description
- elif lambda_.__doc__:
- self.description = lambda_.__doc__
- else:
- self.description = "custom function"
-
- def __call__(self, config):
- return self.lambda_(config)
-
- def _as_string(self, config, negate=False):
- return self._format_description(config)
-
-
-class NotPredicate(Predicate):
- def __init__(self, predicate, description=None):
- self.predicate = predicate
- self.description = description
-
- def __call__(self, config):
- return not self.predicate(config)
-
- def _as_string(self, config, negate=False):
- if self.description:
- return self._format_description(config, not negate)
- else:
- return self.predicate._as_string(config, not negate)
-
-
-class OrPredicate(Predicate):
- def __init__(self, predicates, description=None):
- self.predicates = predicates
- self.description = description
-
- def __call__(self, config):
- for pred in self.predicates:
- if pred(config):
- return True
- return False
-
- def _eval_str(self, config, negate=False):
- if negate:
- conjunction = " and "
- else:
- conjunction = " or "
- return conjunction.join(
- p._as_string(config, negate=negate) for p in self.predicates
- )
-
- def _negation_str(self, config):
- if self.description is not None:
- return "Not " + self._format_description(config)
- else:
- return self._eval_str(config, negate=True)
-
- def _as_string(self, config, negate=False):
- if negate:
- return self._negation_str(config)
- else:
- if self.description is not None:
- return self._format_description(config)
- else:
- return self._eval_str(config)
-
-
-_as_predicate = Predicate.as_predicate
-
-
-def _is_excluded(db, op, spec):
- return SpecPredicate(db, op, spec)(config._current)
-
-
-def _server_version(engine):
- """Return a server_version_info tuple."""
-
- # force metadata to be retrieved
- conn = engine.connect()
- version = getattr(engine.dialect, "server_version_info", None)
- if version is None:
- version = ()
- conn.close()
- return version
-
-
-def db_spec(*dbs):
- return OrPredicate([Predicate.as_predicate(db) for db in dbs])
-
-
-def open(): # noqa
- return skip_if(BooleanPredicate(False, "mark as execute"))
-
-
-def closed():
- return skip_if(BooleanPredicate(True, "marked as skip"))
-
-
-def fails(reason=None):
- return fails_if(BooleanPredicate(True, reason or "expected to fail"))
-
-
-@decorator
-def future(fn, *arg):
- return fails_if(LambdaPredicate(fn), "Future feature")
-
-
-def fails_on(db, reason=None):
- return fails_if(db, reason)
-
-
-def fails_on_everything_except(*dbs):
- return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
-
-
-def skip(db, reason=None):
- return skip_if(db, reason)
-
-
-def only_on(dbs, reason=None):
- return only_if(
- OrPredicate(
- [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
- )
- )
-
-
-def exclude(db, op, spec, reason=None):
- return skip_if(SpecPredicate(db, op, spec), reason)
-
-
-def against(config, *queries):
- assert queries, "no queries sent!"
- return OrPredicate([Predicate.as_predicate(query) for query in queries])(
- config
- )
# assume we're a package, use traditional import
from . import plugin_base
+from functools import update_wrapper
import inspect
import itertools
import operator
from sqlalchemy.testing.plugin.pytestplugin import * # noqa
from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc
+py3k = sys.version_info.major >= 3
+
+if py3k:
+ from typing import TYPE_CHECKING
+else:
+ TYPE_CHECKING = False
+
+if TYPE_CHECKING:
+ from typing import Sequence
+
# override selected SQLAlchemy pytest hooks with vendored functionality
def pytest_configure(config):
return inspect.getargspec(fn)
+def _pytest_fn_decorator(target):
+ """Port of langhelpers.decorator with pytest-specific tricks."""
+ # from sqlalchemy rel_1_3_14
+
+ from sqlalchemy.util.langhelpers import format_argspec_plus
+ from sqlalchemy.util.compat import inspect_getfullargspec
+
+ def _exec_code_in_env(code, env, fn_name):
+ exec(code, env)
+ return env[fn_name]
+
+ def decorate(fn, add_positional_parameters=()):
+
+ spec = inspect_getfullargspec(fn)
+ if add_positional_parameters:
+ spec.args.extend(add_positional_parameters)
+
+ metadata = dict(target="target", fn="__fn", name=fn.__name__)
+ metadata.update(format_argspec_plus(spec, grouped=False))
+ code = (
+ """\
+def %(name)s(%(args)s):
+ return %(target)s(%(fn)s, %(apply_kw)s)
+"""
+ % metadata
+ )
+ decorated = _exec_code_in_env(
+ code, {"target": target, "__fn": fn}, fn.__name__
+ )
+ if not add_positional_parameters:
+ decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ decorated.__wrapped__ = fn
+ return update_wrapper(decorated, fn)
+ else:
+ # this is the pytest hacky part. don't do a full update wrapper
+ # because pytest is really being sneaky about finding the args
+ # for the wrapped function
+ decorated.__module__ = fn.__module__
+ decorated.__name__ = fn.__name__
+ return decorated
+
+ return decorate
+
+
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
def skip_test_exception(self, *arg, **kw):
return pytest.skip.Exception(*arg, **kw)
}
def combinations(self, *arg_sets, **kw):
- """facade for pytest.mark.paramtrize.
+ """Facade for pytest.mark.parametrize.
Automatically derives argument names from the callable which in our
case is always a method on a class with positional arguments.
ids for parameter sets are derived using an optional template.
"""
+ # from sqlalchemy rel_1_3_14
from alembic.testing import exclusions
if sys.version_info.major == 3:
argnames = kw.pop("argnames", None)
- exclusion_combinations = []
-
def _filter_exclusions(args):
result = []
gathered_exclusions = []
else:
result.append(a)
- exclusion_combinations.extend(
- [(exclusion, result) for exclusion in gathered_exclusions]
- )
- return result
+ return result, gathered_exclusions
id_ = kw.pop("id_", None)
+ tobuild_pytest_params = []
+ has_exclusions = False
if id_:
_combination_id_fns = self._combination_id_fns
for idx, char in enumerate(id_)
if char in _combination_id_fns
]
- arg_sets = [
- pytest.param(
- *_arg_getter(_filter_exclusions(arg))[1:],
- id="-".join(
- comb_fn(getter(arg)) for getter, comb_fn in fns
+
+ for arg in arg_sets:
+ if not isinstance(arg, tuple):
+ arg = (arg,)
+
+ fn_params, param_exclusions = _filter_exclusions(arg)
+
+ parameters = _arg_getter(fn_params)[1:]
+
+ if param_exclusions:
+ has_exclusions = True
+
+ tobuild_pytest_params.append(
+ (
+ parameters,
+ param_exclusions,
+ "-".join(
+ comb_fn(getter(arg)) for getter, comb_fn in fns
+ ),
)
)
- for arg in [
- (arg,) if not isinstance(arg, tuple) else arg
- for arg in arg_sets
- ]
- ]
+
else:
- # ensure using pytest.param so that even a 1-arg paramset
- # still needs to be a tuple. otherwise paramtrize tries to
- # interpret a single arg differently than tuple arg
- arg_sets = [
- pytest.param(*_filter_exclusions(arg))
- for arg in [
- (arg,) if not isinstance(arg, tuple) else arg
- for arg in arg_sets
- ]
- ]
+
+ for arg in arg_sets:
+ if not isinstance(arg, tuple):
+ arg = (arg,)
+
+ fn_params, param_exclusions = _filter_exclusions(arg)
+
+ if param_exclusions:
+ has_exclusions = True
+
+ tobuild_pytest_params.append(
+ (fn_params, param_exclusions, None)
+ )
+
+ pytest_params = []
+ for parameters, param_exclusions, id_ in tobuild_pytest_params:
+ if has_exclusions:
+ parameters += (param_exclusions,)
+
+ param = pytest.param(*parameters, id=id_)
+ pytest_params.append(param)
def decorate(fn):
if inspect.isclass(fn):
+ if has_exclusions:
+ raise NotImplementedError(
+ "exclusions not supported for class level combinations"
+ )
if "_sa_parametrize" not in fn.__dict__:
fn._sa_parametrize = []
- fn._sa_parametrize.append((argnames, arg_sets))
+ fn._sa_parametrize.append((argnames, pytest_params))
return fn
else:
if argnames is None:
- _argnames = getargspec(fn).args[1:]
+ _argnames = getargspec(fn).args[1:] # type: Sequence(str)
else:
- _argnames = argnames
-
- if exclusion_combinations:
- for exclusion, combination in exclusion_combinations:
- combination_by_kw = {
- argname: val
- for argname, val in zip(_argnames, combination)
- }
- exclusion = exclusion.with_combination(
- **combination_by_kw
- )
- fn = exclusion(fn)
- return pytest.mark.parametrize(_argnames, arg_sets)(fn)
+ _argnames = re.split(
+ r", *", argnames
+ ) # type: Sequence(str)
+
+ if has_exclusions:
+ _argnames += ["_exclusions"]
+
+ @_pytest_fn_decorator
+ def check_exclusions(fn, *args, **kw):
+ _exclusions = args[-1]
+ if _exclusions:
+ exlu = exclusions.compound().add(*_exclusions)
+ fn = exlu(fn)
+ return fn(*args[0:-1], **kw)
+
+ def process_metadata(spec):
+ spec.args.append("_exclusions")
+
+ fn = check_exclusions(
+ fn, add_positional_parameters=("_exclusions",)
+ )
+
+ return pytest.mark.parametrize(_argnames, pytest_params)(fn)
return decorate