]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Use exclusions from sqlalchemy
authorCaselIT <cfederico87@gmail.com>
Tue, 15 Dec 2020 22:17:20 +0000 (23:17 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 Dec 2020 15:57:43 +0000 (10:57 -0500)
Since alembic 1.5 will support only sqlalchemy 1.3+ it can
safely use the exclusions from sqlalchemy.

Enable test for computed columns with sqlite >= 3.31

Fixes: #759
Change-Id: If629e0c0aa264205fa571e26b3ba8398634c374d

alembic/testing/__init__.py
alembic/testing/exclusions.py [deleted file]
alembic/testing/plugin/pytestplugin.py
tests/requirements.py

index f009da930b3b0287dc0e71e2fc799cd642d7160e..23c0f19bec2a33266c8aec5f265173a89c670ea7 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy.testing import config  # noqa
+from sqlalchemy.testing import exclusions  # noqa
 from sqlalchemy.testing import emits_warning  # noqa
 from sqlalchemy.testing import engines  # noqa
 from sqlalchemy.testing import mock  # noqa
@@ -7,7 +8,6 @@ from sqlalchemy.testing import uses_deprecated  # noqa
 from sqlalchemy.testing.config import requirements as requires  # noqa
 
 from alembic import util  # noqa
-from . import exclusions  # noqa
 from .assertions import assert_raises  # noqa
 from .assertions import assert_raises_message  # noqa
 from .assertions import emits_python_deprecation_warning  # noqa
diff --git a/alembic/testing/exclusions.py b/alembic/testing/exclusions.py
deleted file mode 100644 (file)
index 91f2d5b..0000000
+++ /dev/null
@@ -1,484 +0,0 @@
-# testing/exclusions.py
-# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-
-import contextlib
-import operator
-import re
-
-from sqlalchemy import util as sqla_util
-from sqlalchemy.util import decorator
-
-from . import config
-from . import fixture_functions
-from .. import util
-from ..util.compat import inspect_getargspec
-
-
-def skip_if(predicate, reason=None):
-    rule = compound()
-    pred = _as_predicate(predicate, reason)
-    rule.skips.add(pred)
-    return rule
-
-
-def fails_if(predicate, reason=None):
-    rule = compound()
-    pred = _as_predicate(predicate, reason)
-    rule.fails.add(pred)
-    return rule
-
-
-class compound(object):
-    def __init__(self):
-        self.fails = set()
-        self.skips = set()
-        self.tags = set()
-        self.combinations = {}
-
-    def __add__(self, other):
-        return self.add(other)
-
-    def with_combination(self, **kw):
-        copy = compound()
-        copy.fails.update(self.fails)
-        copy.skips.update(self.skips)
-        copy.tags.update(self.tags)
-        copy.combinations.update((f, kw) for f in copy.fails)
-        copy.combinations.update((s, kw) for s in copy.skips)
-        return copy
-
-    def add(self, *others):
-        copy = compound()
-        copy.fails.update(self.fails)
-        copy.skips.update(self.skips)
-        copy.tags.update(self.tags)
-        for other in others:
-            copy.fails.update(other.fails)
-            copy.skips.update(other.skips)
-            copy.tags.update(other.tags)
-        return copy
-
-    def not_(self):
-        copy = compound()
-        copy.fails.update(NotPredicate(fail) for fail in self.fails)
-        copy.skips.update(NotPredicate(skip) for skip in self.skips)
-        copy.tags.update(self.tags)
-        return copy
-
-    @property
-    def enabled(self):
-        return self.enabled_for_config(config._current)
-
-    def enabled_for_config(self, config):
-        for predicate in self.skips.union(self.fails):
-            if predicate(config):
-                return False
-        else:
-            return True
-
-    def matching_config_reasons(self, config):
-        return [
-            predicate._as_string(config)
-            for predicate in self.skips.union(self.fails)
-            if predicate(config)
-        ]
-
-    def include_test(self, include_tags, exclude_tags):
-        return bool(
-            not self.tags.intersection(exclude_tags)
-            and (not include_tags or self.tags.intersection(include_tags))
-        )
-
-    def _extend(self, other):
-        self.skips.update(other.skips)
-        self.fails.update(other.fails)
-        self.tags.update(other.tags)
-        self.combinations.update(other.combinations)
-
-    def __call__(self, fn):
-        if hasattr(fn, "_sa_exclusion_extend"):
-            fn._sa_exclusion_extend._extend(self)
-            return fn
-
-        @decorator
-        def decorate(fn, *args, **kw):
-            return self._do(config._current, fn, *args, **kw)
-
-        decorated = decorate(fn)
-        decorated._sa_exclusion_extend = self
-        return decorated
-
-    @contextlib.contextmanager
-    def fail_if(self):
-        all_fails = compound()
-        all_fails.fails.update(self.skips.union(self.fails))
-
-        try:
-            yield
-        except Exception as ex:
-            all_fails._expect_failure(config._current, ex, None)
-        else:
-            all_fails._expect_success(config._current, None)
-
-    def _check_combinations(self, combination, predicate):
-        if predicate in self.combinations:
-            for k, v in combination:
-                if (
-                    k in self.combinations[predicate]
-                    and self.combinations[predicate][k] != v
-                ):
-                    return False
-        return True
-
-    def _do(self, cfg, fn, *args, **kw):
-        if len(args) > 1:
-            insp = inspect_getargspec(fn)
-            combination = list(zip(insp.args[1:], args[1:]))
-        else:
-            combination = None
-
-        for skip in self.skips:
-            if self._check_combinations(combination, skip) and skip(cfg):
-                msg = "'%s' : %s" % (
-                    fixture_functions.get_current_test_name(),
-                    skip._as_string(cfg),
-                )
-                config.skip_test(msg)
-
-        try:
-            return_value = fn(*args, **kw)
-        except Exception as ex:
-            self._expect_failure(cfg, ex, combination, name=fn.__name__)
-        else:
-            self._expect_success(cfg, combination, name=fn.__name__)
-            return return_value
-
-    def _expect_failure(self, config, ex, combination, name="block"):
-        for fail in self.fails:
-            if self._check_combinations(combination, fail) and fail(config):
-                if sqla_util.py2k:
-                    str_ex = unicode(ex).encode(  # noqa: F821
-                        "utf-8", errors="ignore"
-                    )
-                else:
-                    str_ex = str(ex)
-                print(
-                    (
-                        "%s failed as expected (%s): %s "
-                        % (name, fail._as_string(config), str_ex)
-                    )
-                )
-                break
-        else:
-            util.raise_from_cause(ex)
-
-    def _expect_success(self, config, combination, name="block"):
-        if not self.fails:
-            return
-
-        for fail in self.fails:
-            if self._check_combinations(combination, fail) and fail(config):
-                raise AssertionError(
-                    "Unexpected success for '%s' (%s)"
-                    % (
-                        name,
-                        " and ".join(
-                            fail._as_string(config) for fail in self.fails
-                        ),
-                    )
-                )
-
-
-def requires_tag(tagname):
-    return tags([tagname])
-
-
-def tags(tagnames):
-    comp = compound()
-    comp.tags.update(tagnames)
-    return comp
-
-
-def only_if(predicate, reason=None):
-    predicate = _as_predicate(predicate)
-    return skip_if(NotPredicate(predicate), reason)
-
-
-def succeeds_if(predicate, reason=None):
-    predicate = _as_predicate(predicate)
-    return fails_if(NotPredicate(predicate), reason)
-
-
-class Predicate(object):
-    @classmethod
-    def as_predicate(cls, predicate, description=None):
-        if isinstance(predicate, compound):
-            return cls.as_predicate(predicate.enabled_for_config, description)
-        elif isinstance(predicate, Predicate):
-            if description and predicate.description is None:
-                predicate.description = description
-            return predicate
-        elif isinstance(predicate, (list, set)):
-            return OrPredicate(
-                [cls.as_predicate(pred) for pred in predicate], description
-            )
-        elif isinstance(predicate, tuple):
-            return SpecPredicate(*predicate)
-        elif isinstance(predicate, sqla_util.string_types):
-            tokens = re.match(
-                r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
-            )
-            if not tokens:
-                raise ValueError(
-                    "Couldn't locate DB name in predicate: %r" % predicate
-                )
-            db = tokens.group(1)
-            op = tokens.group(2)
-            spec = (
-                tuple(int(d) for d in tokens.group(3).split("."))
-                if tokens.group(3)
-                else None
-            )
-
-            return SpecPredicate(db, op, spec, description=description)
-        elif callable(predicate):
-            return LambdaPredicate(predicate, description)
-        else:
-            assert False, "unknown predicate type: %s" % predicate
-
-    def _format_description(self, config, negate=False):
-        bool_ = self(config)
-        if negate:
-            bool_ = not negate
-        return self.description % {
-            "driver": config.db.url.get_driver_name()
-            if config
-            else "<no driver>",
-            "database": config.db.url.get_backend_name()
-            if config
-            else "<no database>",
-            "doesnt_support": "doesn't support" if bool_ else "does support",
-            "does_support": "does support" if bool_ else "doesn't support",
-        }
-
-    def _as_string(self, config=None, negate=False):
-        raise NotImplementedError()
-
-
-class BooleanPredicate(Predicate):
-    def __init__(self, value, description=None):
-        self.value = value
-        self.description = description or "boolean %s" % value
-
-    def __call__(self, config):
-        return self.value
-
-    def _as_string(self, config, negate=False):
-        return self._format_description(config, negate=negate)
-
-
-class SpecPredicate(Predicate):
-    def __init__(self, db, op=None, spec=None, description=None):
-        self.db = db
-        self.op = op
-        self.spec = spec
-        self.description = description
-
-    _ops = {
-        "<": operator.lt,
-        ">": operator.gt,
-        "==": operator.eq,
-        "!=": operator.ne,
-        "<=": operator.le,
-        ">=": operator.ge,
-        "in": operator.contains,
-        "between": lambda val, pair: val >= pair[0] and val <= pair[1],
-    }
-
-    def __call__(self, config):
-        engine = config.db
-
-        if "+" in self.db:
-            dialect, driver = self.db.split("+")
-        else:
-            dialect, driver = self.db, None
-
-        if dialect and engine.name != dialect:
-            return False
-        if driver is not None and engine.driver != driver:
-            return False
-
-        if self.op is not None:
-            assert driver is None, "DBAPI version specs not supported yet"
-
-            version = _server_version(engine)
-            oper = (
-                hasattr(self.op, "__call__") and self.op or self._ops[self.op]
-            )
-            return oper(version, self.spec)
-        else:
-            return True
-
-    def _as_string(self, config, negate=False):
-        if self.description is not None:
-            return self._format_description(config)
-        elif self.op is None:
-            if negate:
-                return "not %s" % self.db
-            else:
-                return "%s" % self.db
-        else:
-            if negate:
-                return "not %s %s %s" % (self.db, self.op, self.spec)
-            else:
-                return "%s %s %s" % (self.db, self.op, self.spec)
-
-
-class LambdaPredicate(Predicate):
-    def __init__(self, lambda_, description=None, args=None, kw=None):
-        spec = inspect_getargspec(lambda_)
-        if not spec[0]:
-            self.lambda_ = lambda db: lambda_()
-        else:
-            self.lambda_ = lambda_
-        self.args = args or ()
-        self.kw = kw or {}
-        if description:
-            self.description = description
-        elif lambda_.__doc__:
-            self.description = lambda_.__doc__
-        else:
-            self.description = "custom function"
-
-    def __call__(self, config):
-        return self.lambda_(config)
-
-    def _as_string(self, config, negate=False):
-        return self._format_description(config)
-
-
-class NotPredicate(Predicate):
-    def __init__(self, predicate, description=None):
-        self.predicate = predicate
-        self.description = description
-
-    def __call__(self, config):
-        return not self.predicate(config)
-
-    def _as_string(self, config, negate=False):
-        if self.description:
-            return self._format_description(config, not negate)
-        else:
-            return self.predicate._as_string(config, not negate)
-
-
-class OrPredicate(Predicate):
-    def __init__(self, predicates, description=None):
-        self.predicates = predicates
-        self.description = description
-
-    def __call__(self, config):
-        for pred in self.predicates:
-            if pred(config):
-                return True
-        return False
-
-    def _eval_str(self, config, negate=False):
-        if negate:
-            conjunction = " and "
-        else:
-            conjunction = " or "
-        return conjunction.join(
-            p._as_string(config, negate=negate) for p in self.predicates
-        )
-
-    def _negation_str(self, config):
-        if self.description is not None:
-            return "Not " + self._format_description(config)
-        else:
-            return self._eval_str(config, negate=True)
-
-    def _as_string(self, config, negate=False):
-        if negate:
-            return self._negation_str(config)
-        else:
-            if self.description is not None:
-                return self._format_description(config)
-            else:
-                return self._eval_str(config)
-
-
-_as_predicate = Predicate.as_predicate
-
-
-def _is_excluded(db, op, spec):
-    return SpecPredicate(db, op, spec)(config._current)
-
-
-def _server_version(engine):
-    """Return a server_version_info tuple."""
-
-    # force metadata to be retrieved
-    conn = engine.connect()
-    version = getattr(engine.dialect, "server_version_info", None)
-    if version is None:
-        version = ()
-    conn.close()
-    return version
-
-
-def db_spec(*dbs):
-    return OrPredicate([Predicate.as_predicate(db) for db in dbs])
-
-
-def open():  # noqa
-    return skip_if(BooleanPredicate(False, "mark as execute"))
-
-
-def closed():
-    return skip_if(BooleanPredicate(True, "marked as skip"))
-
-
-def fails(reason=None):
-    return fails_if(BooleanPredicate(True, reason or "expected to fail"))
-
-
-@decorator
-def future(fn, *arg):
-    return fails_if(LambdaPredicate(fn), "Future feature")
-
-
-def fails_on(db, reason=None):
-    return fails_if(db, reason)
-
-
-def fails_on_everything_except(*dbs):
-    return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
-
-
-def skip(db, reason=None):
-    return skip_if(db, reason)
-
-
-def only_on(dbs, reason=None):
-    return only_if(
-        OrPredicate(
-            [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
-        )
-    )
-
-
-def exclude(db, op, spec, reason=None):
-    return skip_if(SpecPredicate(db, op, spec), reason)
-
-
-def against(config, *queries):
-    assert queries, "no queries sent!"
-    return OrPredicate([Predicate.as_predicate(query) for query in queries])(
-        config
-    )
index ba3d35bbe1e6064273a037b7b078447ff45673cb..6b76a17e9fa6d78d7e7431736f417aeda7e5d2e0 100644 (file)
@@ -11,6 +11,7 @@ except ImportError:
     # assume we're a package, use traditional import
     from . import plugin_base
 
+from functools import update_wrapper
 import inspect
 import itertools
 import operator
@@ -22,6 +23,16 @@ import pytest
 from sqlalchemy.testing.plugin.pytestplugin import *  # noqa
 from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc
 
+py3k = sys.version_info.major >= 3
+
+if py3k:
+    from typing import TYPE_CHECKING
+else:
+    TYPE_CHECKING = False
+
+if TYPE_CHECKING:
+    from typing import Sequence
+
 
 # override selected SQLAlchemy pytest hooks with vendored functionality
 def pytest_configure(config):
@@ -97,6 +108,50 @@ def getargspec(fn):
         return inspect.getargspec(fn)
 
 
+def _pytest_fn_decorator(target):
+    """Port of langhelpers.decorator with pytest-specific tricks."""
+    # from sqlalchemy rel_1_3_14
+
+    from sqlalchemy.util.langhelpers import format_argspec_plus
+    from sqlalchemy.util.compat import inspect_getfullargspec
+
+    def _exec_code_in_env(code, env, fn_name):
+        exec(code, env)
+        return env[fn_name]
+
+    def decorate(fn, add_positional_parameters=()):
+
+        spec = inspect_getfullargspec(fn)
+        if add_positional_parameters:
+            spec.args.extend(add_positional_parameters)
+
+        metadata = dict(target="target", fn="__fn", name=fn.__name__)
+        metadata.update(format_argspec_plus(spec, grouped=False))
+        code = (
+            """\
+def %(name)s(%(args)s):
+    return %(target)s(%(fn)s, %(apply_kw)s)
+"""
+            % metadata
+        )
+        decorated = _exec_code_in_env(
+            code, {"target": target, "__fn": fn}, fn.__name__
+        )
+        if not add_positional_parameters:
+            decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+            decorated.__wrapped__ = fn
+            return update_wrapper(decorated, fn)
+        else:
+            # this is the pytest hacky part.  don't do a full update wrapper
+            # because pytest is really being sneaky about finding the args
+            # for the wrapped function
+            decorated.__module__ = fn.__module__
+            decorated.__name__ = fn.__name__
+            return decorated
+
+    return decorate
+
+
 class PytestFixtureFunctions(plugin_base.FixtureFunctions):
     def skip_test_exception(self, *arg, **kw):
         return pytest.skip.Exception(*arg, **kw)
@@ -109,7 +164,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
     }
 
     def combinations(self, *arg_sets, **kw):
-        """facade for pytest.mark.paramtrize.
+        """Facade for pytest.mark.parametrize.
 
         Automatically derives argument names from the callable which in our
         case is always a method on a class with positional arguments.
@@ -117,6 +172,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
         ids for parameter sets are derived using an optional template.
 
         """
+        # from sqlalchemy rel_1_3_14
         from alembic.testing import exclusions
 
         if sys.version_info.major == 3:
@@ -128,8 +184,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
 
         argnames = kw.pop("argnames", None)
 
-        exclusion_combinations = []
-
         def _filter_exclusions(args):
             result = []
             gathered_exclusions = []
@@ -139,13 +193,12 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
                 else:
                     result.append(a)
 
-            exclusion_combinations.extend(
-                [(exclusion, result) for exclusion in gathered_exclusions]
-            )
-            return result
+            return result, gathered_exclusions
 
         id_ = kw.pop("id_", None)
 
+        tobuild_pytest_params = []
+        has_exclusions = False
         if id_:
             _combination_id_fns = self._combination_id_fns
 
@@ -165,53 +218,88 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
                 for idx, char in enumerate(id_)
                 if char in _combination_id_fns
             ]
-            arg_sets = [
-                pytest.param(
-                    *_arg_getter(_filter_exclusions(arg))[1:],
-                    id="-".join(
-                        comb_fn(getter(arg)) for getter, comb_fn in fns
+
+            for arg in arg_sets:
+                if not isinstance(arg, tuple):
+                    arg = (arg,)
+
+                fn_params, param_exclusions = _filter_exclusions(arg)
+
+                parameters = _arg_getter(fn_params)[1:]
+
+                if param_exclusions:
+                    has_exclusions = True
+
+                tobuild_pytest_params.append(
+                    (
+                        parameters,
+                        param_exclusions,
+                        "-".join(
+                            comb_fn(getter(arg)) for getter, comb_fn in fns
+                        ),
                     )
                 )
-                for arg in [
-                    (arg,) if not isinstance(arg, tuple) else arg
-                    for arg in arg_sets
-                ]
-            ]
+
         else:
-            # ensure using pytest.param so that even a 1-arg paramset
-            # still needs to be a tuple.  otherwise paramtrize tries to
-            # interpret a single arg differently than tuple arg
-            arg_sets = [
-                pytest.param(*_filter_exclusions(arg))
-                for arg in [
-                    (arg,) if not isinstance(arg, tuple) else arg
-                    for arg in arg_sets
-                ]
-            ]
+
+            for arg in arg_sets:
+                if not isinstance(arg, tuple):
+                    arg = (arg,)
+
+                fn_params, param_exclusions = _filter_exclusions(arg)
+
+                if param_exclusions:
+                    has_exclusions = True
+
+                tobuild_pytest_params.append(
+                    (fn_params, param_exclusions, None)
+                )
+
+        pytest_params = []
+        for parameters, param_exclusions, id_ in tobuild_pytest_params:
+            if has_exclusions:
+                parameters += (param_exclusions,)
+
+            param = pytest.param(*parameters, id=id_)
+            pytest_params.append(param)
 
         def decorate(fn):
             if inspect.isclass(fn):
+                if has_exclusions:
+                    raise NotImplementedError(
+                        "exclusions not supported for class level combinations"
+                    )
                 if "_sa_parametrize" not in fn.__dict__:
                     fn._sa_parametrize = []
-                fn._sa_parametrize.append((argnames, arg_sets))
+                fn._sa_parametrize.append((argnames, pytest_params))
                 return fn
             else:
                 if argnames is None:
-                    _argnames = getargspec(fn).args[1:]
+                    _argnames = getargspec(fn).args[1:]  # type: Sequence(str)
                 else:
-                    _argnames = argnames
-
-                if exclusion_combinations:
-                    for exclusion, combination in exclusion_combinations:
-                        combination_by_kw = {
-                            argname: val
-                            for argname, val in zip(_argnames, combination)
-                        }
-                        exclusion = exclusion.with_combination(
-                            **combination_by_kw
-                        )
-                        fn = exclusion(fn)
-                return pytest.mark.parametrize(_argnames, arg_sets)(fn)
+                    _argnames = re.split(
+                        r", *", argnames
+                    )  # type: Sequence(str)
+
+                if has_exclusions:
+                    _argnames += ["_exclusions"]
+
+                    @_pytest_fn_decorator
+                    def check_exclusions(fn, *args, **kw):
+                        _exclusions = args[-1]
+                        if _exclusions:
+                            exlu = exclusions.compound().add(*_exclusions)
+                            fn = exlu(fn)
+                        return fn(*args[0:-1], **kw)
+
+                    def process_metadata(spec):
+                        spec.args.append("_exclusions")
+
+                    fn = check_exclusions(
+                        fn, add_positional_parameters=("_exclusions",)
+                    )
+
+                return pytest.mark.parametrize(_argnames, pytest_params)(fn)
 
         return decorate
 
index eb6066db6c70a5188b3a92bfd28303b07242f2ec..c077462a9bd3563f5c75961d028ad411846dd177 100644 (file)
@@ -169,8 +169,8 @@ class DefaultRequirements(SuiteRequirements):
     def computed_columns(self):
         # TODO: in theory if these could come from SQLAlchemy dialects
         # that would be helpful
-        return self.computed_columns_api + exclusions.only_on(
-            ["postgresql >= 12", "oracle", "mssql", "mysql >= 5.7", "mariadb"]
+        return self.computed_columns_api + exclusions.skip_if(
+            ["postgresql < 12", "sqlite < 3.31", "mysql < 5.7"]
         )
 
     @property