]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Re-port bootstrap /plugin_base to get combinatoric functions
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 28 Nov 2019 00:30:08 +0000 (19:30 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Nov 2019 17:29:16 +0000 (12:29 -0500)
The py.test combinatoric functions are available in
SQLAlchemy's testing fixtures for version 1.3 and 1.4 only,
not 1.1 or 1.2.

This also needs to port the py.test test collection
function which was improved in SQLAlchemy 1.3 and also
made to accommodate for pytest.combinations, so we must
still vendor a minimum set of the "bootstrap" mechanism
for now.

Change-Id: I21b426cd686a920b2ae60132681afbb645a68246

alembic/testing/__init__.py
alembic/testing/fixture_functions.py [new file with mode: 0644]
alembic/testing/plugin/__init__.py [new file with mode: 0644]
alembic/testing/plugin/bootstrap.py [new file with mode: 0644]
alembic/testing/plugin/plugin_base.py [new file with mode: 0644]
alembic/testing/plugin/pytestplugin.py [new file with mode: 0644]
alembic/testing/util.py [new file with mode: 0644]
tests/conftest.py
tests/test_autogen_diffs.py

index 4b669268b101abf451a3a78983d1a2d986006c74..f1884a97ed9ff100a738bacaed48b2dab747b5a5 100644 (file)
@@ -16,4 +16,6 @@ from .assertions import is_false  # noqa
 from .assertions import is_not_  # noqa
 from .assertions import is_true  # noqa
 from .assertions import ne_  # noqa
+from .fixture_functions import combinations  # noqa
+from .fixture_functions import fixture  # noqa
 from .fixtures import TestBase  # noqa
diff --git a/alembic/testing/fixture_functions.py b/alembic/testing/fixture_functions.py
new file mode 100644 (file)
index 0000000..2640693
--- /dev/null
@@ -0,0 +1,79 @@
+_fixture_functions = None  # installed by plugin_base
+
+
+def combinations(*comb, **kw):
+    r"""Deliver multiple versions of a test based on positional combinations.
+
+    This is a facade over pytest.mark.parametrize.
+
+
+    :param \*comb: argument combinations.  These are tuples that will be passed
+     positionally to the decorated function.
+
+    :param argnames: optional list of argument names.   These are the names
+     of the arguments in the test function that correspond to the entries
+     in each argument tuple.   pytest.mark.parametrize requires this, however
+     the combinations function will derive it automatically if not present
+     by using ``inspect.getfullargspec(fn).args[1:]``.  Note this assumes the
+     first argument is "self" which is discarded.
+
+    :param id\_: optional id template.  This is a string template that
+     describes how the "id" for each parameter set should be defined, if any.
+     The number of characters in the template should match the number of
+     entries in each argument tuple.   Each character describes how the
+     corresponding entry in the argument tuple should be handled, as far as
+     whether or not it is included in the arguments passed to the function, as
+     well as if it is included in the tokens used to create the id of the
+     parameter set.
+
+     If omitted, the argument combinations are passed to parametrize as is.  If
+     passed, each argument combination is turned into a pytest.param() object,
+     mapping the elements of the argument tuple to produce an id based on a
+     character value in the same position within the string template using the
+     following scheme::
+
+        i - the given argument is a string that is part of the id only, don't
+            pass it as an argument
+
+        n - the given argument should be passed and it should be added to the
+            id by calling the .__name__ attribute
+
+        r - the given argument should be passed and it should be added to the
+            id by calling repr()
+
+        s - the given argument should be passed and it should be added to the
+            id by calling str()
+
+        a - (argument) the given argument should be passed and it should not
+            be used to generated the id
+
+     e.g.::
+
+        @testing.combinations(
+            (operator.eq, "eq"),
+            (operator.ne, "ne"),
+            (operator.gt, "gt"),
+            (operator.lt, "lt"),
+            id_="na"
+        )
+        def test_operator(self, opfunc, name):
+            pass
+
+    The above combination will call ``.__name__`` on the first member of
+    each tuple and use that as the "id" to pytest.param().
+
+
+    """
+    return _fixture_functions.combinations(*comb, **kw)
+
+
+def fixture(*arg, **kw):
+    return _fixture_functions.fixture(*arg, **kw)
+
+
+def get_current_test_name():
+    return _fixture_functions.get_current_test_name()
+
+
+def skip_test(msg):
+    raise _fixture_functions.skip_test_exception(msg)
diff --git a/alembic/testing/plugin/__init__.py b/alembic/testing/plugin/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/alembic/testing/plugin/bootstrap.py b/alembic/testing/plugin/bootstrap.py
new file mode 100644 (file)
index 0000000..8200ec1
--- /dev/null
@@ -0,0 +1,35 @@
+"""
+Bootstrapper for test framework plugins.
+
+This is vendored from SQLAlchemy so that we can use local overrides
+for plugin_base.py and pytestplugin.py.
+
+"""
+
+
+import os
+import sys
+
+
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
+
+
+def load_file_as_module(name):
+    path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
+    if sys.version_info >= (3, 3):
+        from importlib import machinery
+
+        mod = machinery.SourceFileLoader(name, path).load_module()
+    else:
+        import imp
+
+        mod = imp.load_source(name, path)
+    return mod
+
+
+if to_bootstrap == "pytest":
+    sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
+    sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
+else:
+    raise Exception("unknown bootstrap: %s" % to_bootstrap)  # noqa
diff --git a/alembic/testing/plugin/plugin_base.py b/alembic/testing/plugin/plugin_base.py
new file mode 100644 (file)
index 0000000..dc31c58
--- /dev/null
@@ -0,0 +1,83 @@
+"""vendored plugin_base functions from the most recent SQLAlchemy versions.
+
+Alembic tests need to run on older versions of SQLAlchemy that don't
+necessarily have all the latest testing fixtures.
+
+"""
+from __future__ import absolute_import
+
+import abc
+import sys
+
+from sqlalchemy.testing.plugin.plugin_base import *  # noqa
+from sqlalchemy.testing.plugin.plugin_base import post
+
+py3k = sys.version_info >= (3, 0)
+
+if py3k:
+
+    ABC = abc.ABC
+else:
+
+    class ABC(object):
+        __metaclass__ = abc.ABCMeta
+
+
+# override selected SQLAlchemy pytest hooks with vendored functionality
+def want_class(name, cls):
+    from sqlalchemy.testing import config
+    from sqlalchemy.testing import fixtures
+
+    if not issubclass(cls, fixtures.TestBase):
+        return False
+    elif name.startswith("_"):
+        return False
+    elif (
+        config.options.backend_only
+        and not getattr(cls, "__backend__", False)
+        and not getattr(cls, "__sparse_backend__", False)
+    ):
+        return False
+    else:
+        return True
+
+
+@post
+def _init_symbols(options, file_config):
+    from sqlalchemy.testing import config
+    from alembic.testing import fixture_functions as alembic_config
+
+    config._fixture_functions = (
+        alembic_config._fixture_functions
+    ) = _fixture_fn_class()
+
+
+class FixtureFunctions(ABC):
+    @abc.abstractmethod
+    def skip_test_exception(self, *arg, **kw):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def combinations(self, *args, **kw):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def param_ident(self, *args, **kw):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def fixture(self, *arg, **kw):
+        raise NotImplementedError()
+
+    def get_current_test_name(self):
+        raise NotImplementedError()
+
+
+_fixture_fn_class = None
+
+
+def set_fixture_functions(fixture_fn_class):
+    from sqlalchemy.testing.plugin import plugin_base
+
+    global _fixture_fn_class
+    _fixture_fn_class = plugin_base._fixture_fn_class = fixture_fn_class
diff --git a/alembic/testing/plugin/pytestplugin.py b/alembic/testing/plugin/pytestplugin.py
new file mode 100644 (file)
index 0000000..0b2da89
--- /dev/null
@@ -0,0 +1,217 @@
+"""vendored pytestplugin functions from the most recent SQLAlchemy versions.
+
+Alembic tests need to run on older versions of SQLAlchemy that don't
+necessarily have all the latest testing fixtures.
+
+"""
+try:
+    # installed by bootstrap.py
+    import sqla_plugin_base as plugin_base
+except ImportError:
+    # assume we're a package, use traditional import
+    from . import plugin_base
+
+import inspect
+import itertools
+import operator
+import os
+import re
+import sys
+
+import pytest
+from sqlalchemy.testing.plugin.pytestplugin import *  # noqa
+from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc
+
+
+# override selected SQLAlchemy pytest hooks with vendored functionality
+def pytest_configure(config):
+    spc(config)
+
+    plugin_base.set_fixture_functions(PytestFixtureFunctions)
+
+
+def pytest_pycollect_makeitem(collector, name, obj):
+
+    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
+        return [
+            pytest.Class(parametrize_cls.__name__, parent=collector)
+            for parametrize_cls in _parametrize_cls(collector.module, obj)
+        ]
+    elif (
+        inspect.isfunction(obj)
+        and isinstance(collector, pytest.Instance)
+        and plugin_base.want_method(collector.cls, obj)
+    ):
+        # None means, fall back to default logic, which includes
+        # method-level parametrize
+        return None
+    else:
+        # empty list means skip this item
+        return []
+
+
+_current_class = None
+
+
+def _parametrize_cls(module, cls):
+    """implement a class-based version of pytest parametrize."""
+
+    if "_sa_parametrize" not in cls.__dict__:
+        return [cls]
+
+    _sa_parametrize = cls._sa_parametrize
+    classes = []
+    for full_param_set in itertools.product(
+        *[params for argname, params in _sa_parametrize]
+    ):
+        cls_variables = {}
+
+        for argname, param in zip(
+            [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
+        ):
+            if not argname:
+                raise TypeError("need argnames for class-based combinations")
+            argname_split = re.split(r",\s*", argname)
+            for arg, val in zip(argname_split, param.values):
+                cls_variables[arg] = val
+        parametrized_name = "_".join(
+            # token is a string, but in py2k py.test is giving us a unicode,
+            # so call str() on it.
+            str(re.sub(r"\W", "", token))
+            for param in full_param_set
+            for token in param.id.split("-")
+        )
+        name = "%s_%s" % (cls.__name__, parametrized_name)
+        newcls = type.__new__(type, name, (cls,), cls_variables)
+        setattr(module, name, newcls)
+        classes.append(newcls)
+    return classes
+
+
+def getargspec(fn):
+    if sys.version_info.major == 3:
+        return inspect.getfullargspec(fn)
+    else:
+        return inspect.getargspec(fn)
+
+
+class PytestFixtureFunctions(plugin_base.FixtureFunctions):
+    def skip_test_exception(self, *arg, **kw):
+        return pytest.skip.Exception(*arg, **kw)
+
+    _combination_id_fns = {
+        "i": lambda obj: obj,
+        "r": repr,
+        "s": str,
+        "n": operator.attrgetter("__name__"),
+    }
+
+    def combinations(self, *arg_sets, **kw):
+        """facade for pytest.mark.paramtrize.
+
+        Automatically derives argument names from the callable which in our
+        case is always a method on a class with positional arguments.
+
+        ids for parameter sets are derived using an optional template.
+
+        """
+        from sqlalchemy.testing import exclusions
+
+        if sys.version_info.major == 3:
+            if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
+                arg_sets = list(arg_sets[0])
+        else:
+            if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
+                arg_sets = list(arg_sets[0])
+
+        argnames = kw.pop("argnames", None)
+
+        exclusion_combinations = []
+
+        def _filter_exclusions(args):
+            result = []
+            gathered_exclusions = []
+            for a in args:
+                if isinstance(a, exclusions.compound):
+                    gathered_exclusions.append(a)
+                else:
+                    result.append(a)
+
+            exclusion_combinations.extend(
+                [(exclusion, result) for exclusion in gathered_exclusions]
+            )
+            return result
+
+        id_ = kw.pop("id_", None)
+
+        if id_:
+            _combination_id_fns = self._combination_id_fns
+
+            # because itemgetter is not consistent for one argument vs.
+            # multiple, make it multiple in all cases and use a slice
+            # to omit the first argument
+            _arg_getter = operator.itemgetter(
+                0,
+                *[
+                    idx
+                    for idx, char in enumerate(id_)
+                    if char in ("n", "r", "s", "a")
+                ]
+            )
+            fns = [
+                (operator.itemgetter(idx), _combination_id_fns[char])
+                for idx, char in enumerate(id_)
+                if char in _combination_id_fns
+            ]
+            arg_sets = [
+                pytest.param(
+                    *_arg_getter(_filter_exclusions(arg))[1:],
+                    id="-".join(
+                        comb_fn(getter(arg)) for getter, comb_fn in fns
+                    )
+                )
+                for arg in arg_sets
+            ]
+        else:
+            # ensure using pytest.param so that even a 1-arg paramset
+            # still needs to be a tuple.  otherwise paramtrize tries to
+            # interpret a single arg differently than tuple arg
+            arg_sets = [
+                pytest.param(*_filter_exclusions(arg)) for arg in arg_sets
+            ]
+
+        def decorate(fn):
+            if inspect.isclass(fn):
+                if "_sa_parametrize" not in fn.__dict__:
+                    fn._sa_parametrize = []
+                fn._sa_parametrize.append((argnames, arg_sets))
+                return fn
+            else:
+                if argnames is None:
+                    _argnames = getargspec(fn).args[1:]
+                else:
+                    _argnames = argnames
+
+                if exclusion_combinations:
+                    for exclusion, combination in exclusion_combinations:
+                        combination_by_kw = {
+                            argname: val
+                            for argname, val in zip(_argnames, combination)
+                        }
+                        exclusion = exclusion.with_combination(
+                            **combination_by_kw
+                        )
+                        fn = exclusion(fn)
+                return pytest.mark.parametrize(_argnames, arg_sets)(fn)
+
+        return decorate
+
+    def param_ident(self, *parameters):
+        ident = parameters[0]
+        return pytest.param(*parameters[1:], id=ident)
+
+    def fixture(self, *arg, **kw):
+        return pytest.fixture(*arg, **kw)
+
+    def get_current_test_name(self):
+        return os.environ.get("PYTEST_CURRENT_TEST")
diff --git a/alembic/testing/util.py b/alembic/testing/util.py
new file mode 100644 (file)
index 0000000..87dfcd5
--- /dev/null
@@ -0,0 +1,80 @@
+# testing/util.py
+# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+
+def flag_combinations(*combinations):
+    """A facade around @testing.combinations() oriented towards boolean
+    keyword-based arguments.
+
+    Basically generates a nice looking identifier based on the keywords
+    and also sets up the argument names.
+
+    E.g.::
+
+        @testing.flag_combinations(
+            dict(lazy=False, passive=False),
+            dict(lazy=True, passive=False),
+            dict(lazy=False, passive=True),
+            dict(lazy=False, passive=True, raiseload=True),
+        )
+
+
+    would result in::
+
+        @testing.combinations(
+            ('', False, False, False),
+            ('lazy', True, False, False),
+            ('lazy_passive', True, True, False),
+            ('lazy_passive', True, True, True),
+            id_='iaaa',
+            argnames='lazy,passive,raiseload'
+        )
+
+    """
+    from sqlalchemy.testing import config
+
+    keys = set()
+
+    for d in combinations:
+        keys.update(d)
+
+    keys = sorted(keys)
+
+    return config.combinations(
+        *[
+            ("_".join(k for k in keys if d.get(k, False)),)
+            + tuple(d.get(k, False) for k in keys)
+            for d in combinations
+        ],
+        id_="i" + ("a" * len(keys)),
+        argnames=",".join(keys)
+    )
+
+
+def metadata_fixture(ddl="function"):
+    """Provide MetaData for a pytest fixture."""
+
+    from sqlalchemy.testing import config
+    from . import fixture_functions
+
+    def decorate(fn):
+        def run_ddl(self):
+            from sqlalchemy import schema
+
+            metadata = self.metadata = schema.MetaData()
+            try:
+                result = fn(self, metadata)
+                metadata.create_all(config.db)
+                # TODO:
+                # somehow get a per-function dml erase fixture here
+                yield result
+            finally:
+                metadata.drop_all(config.db)
+
+        return fixture_functions.fixture(scope=ddl)(run_ddl)
+
+    return decorate
index 4290f42843c3c54af297664a2c1c142731bca0a2..a83dff5833678cd5c7c8b62acce8583ba6b44025 100755 (executable)
@@ -8,15 +8,26 @@ installs SQLAlchemy's testing plugin into the local environment.
 """
 import os
 
-import sqlalchemy
+import pytest
+
+pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
+
 
 # ideally, SQLAlchemy would allow us to just import bootstrap,
 # but for now we have to use its "load from a file" approach
 
+# use bootstrapping so that test plugins are loaded
+# without touching the main library before coverage starts
 bootstrap_file = os.path.join(
-    os.path.dirname(sqlalchemy.__file__), "testing", "plugin", "bootstrap.py"
+    os.path.dirname(__file__),
+    "..",
+    "alembic",
+    "testing",
+    "plugin",
+    "bootstrap.py",
 )
 
+
 with open(bootstrap_file) as f:
     code = compile(f.read(), "bootstrap.py", "exec")
     to_bootstrap = "pytest"
index ef581c277984a87fda25cac62402c6f9a9892971..ca818e6dab0234160657de4c32042d28cfa32992 100644 (file)
@@ -29,6 +29,7 @@ from sqlalchemy.types import NULLTYPE
 from sqlalchemy.types import VARBINARY
 
 from alembic import autogenerate
+from alembic import testing
 from alembic.migration import MigrationContext
 from alembic.operations import ops
 from alembic.testing import assert_raises_message
@@ -633,7 +634,8 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
 
 
 class CompareTypeSpecificityTest(TestBase):
-    def _fixture(self):
+    @testing.fixture
+    def impl_fixture(self):
         from alembic.ddl import impl
         from sqlalchemy.engine import default
 
@@ -641,7 +643,7 @@ class CompareTypeSpecificityTest(TestBase):
             default.DefaultDialect(), None, False, True, None, {}
         )
 
-    def test_typedec_to_nonstandard(self):
+    def test_typedec_to_nonstandard(self, impl_fixture):
         class PasswordType(TypeDecorator):
             impl = VARBINARY
 
@@ -655,65 +657,40 @@ class CompareTypeSpecificityTest(TestBase):
                     impl = VARBINARY(self.length)
                 return dialect.type_descriptor(impl)
 
-        impl = self._fixture()
-        impl.compare_type(
+        impl_fixture.compare_type(
             Column("x", sqlite.NUMERIC(50)), Column("x", PasswordType(50))
         )
 
-    def test_string(self):
-        t1 = String(30)
-        t2 = String(40)
-        t3 = VARCHAR(30)
-        t4 = Integer
-
-        impl = self._fixture()
-        is_(impl.compare_type(Column("x", t3), Column("x", t1)), False)
-        is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
-        is_(impl.compare_type(Column("x", t3), Column("x", t4)), True)
-
-    def test_numeric(self):
-        t1 = Numeric(10, 5)
-        t2 = Numeric(12, 5)
-        t3 = DECIMAL(10, 5)
-        t4 = DateTime
-
-        impl = self._fixture()
-        is_(impl.compare_type(Column("x", t3), Column("x", t1)), False)
-        is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
-        is_(impl.compare_type(Column("x", t3), Column("x", t4)), True)
-
-    def test_numeric_noprecision(self):
-        t1 = Numeric()
-        t2 = Numeric(scale=5)
-
-        impl = self._fixture()
-        is_(impl.compare_type(Column("x", t1), Column("x", t2)), False)
-
-    def test_integer(self):
-        t1 = Integer()
-        t2 = SmallInteger()
-        t3 = BIGINT()
-        t4 = String()
-        t5 = INTEGER()
-        t6 = BigInteger()
-
-        impl = self._fixture()
-        is_(impl.compare_type(Column("x", t5), Column("x", t1)), False)
-        is_(impl.compare_type(Column("x", t3), Column("x", t1)), True)
-        is_(impl.compare_type(Column("x", t3), Column("x", t6)), False)
-        is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
-        is_(impl.compare_type(Column("x", t5), Column("x", t2)), True)
-        is_(impl.compare_type(Column("x", t1), Column("x", t4)), True)
-
-    def test_datetime(self):
-        t1 = DateTime()
-        t2 = DateTime(timezone=False)
-        t3 = DateTime(timezone=True)
-
-        impl = self._fixture()
-        is_(impl.compare_type(Column("x", t1), Column("x", t2)), False)
-        is_(impl.compare_type(Column("x", t1), Column("x", t3)), True)
-        is_(impl.compare_type(Column("x", t2), Column("x", t3)), True)
+    @testing.combinations(
+        (VARCHAR(30), String(30), False),
+        (VARCHAR(30), String(40), True),
+        (VARCHAR(30), Integer(), True),
+        (DECIMAL(10, 5), Numeric(10, 5), False),
+        (DECIMAL(10, 5), Numeric(12, 5), True),
+        (DECIMAL(10, 5), DateTime(), True),
+        (Numeric(), Numeric(scale=5), False),
+        (INTEGER(), Integer(), False),
+        (BIGINT(), Integer(), True),
+        (BIGINT(), BigInteger(), False),
+        (BIGINT(), SmallInteger(), True),
+        (INTEGER(), SmallInteger(), True),
+        (Integer(), String(), True),
+        (DateTime(), DateTime(timezone=False), False),
+        (DateTime(), DateTime(timezone=True), True),
+        (DateTime(timezone=False), DateTime(timezone=True), True),
+        id_="ssa",
+        argnames="compare_from,compare_to,expected",
+    )
+    def test_compare_type(
+        self, impl_fixture, compare_from, compare_to, expected
+    ):
+
+        is_(
+            impl_fixture.compare_type(
+                Column("x", compare_from), Column("x", compare_to)
+            ),
+            expected,
+        )
 
 
 class AutogenSystemColTest(AutogenTest, TestBase):