from .assertions import is_not_ # noqa
from .assertions import is_true # noqa
from .assertions import ne_ # noqa
+from .fixture_functions import combinations # noqa
+from .fixture_functions import fixture # noqa
from .fixtures import TestBase # noqa
--- /dev/null
+_fixture_functions = None # installed by plugin_base
+
+
+def combinations(*comb, **kw):
+ r"""Deliver multiple versions of a test based on positional combinations.
+
+ This is a facade over pytest.mark.parametrize.
+
+
+ :param \*comb: argument combinations. These are tuples that will be passed
+ positionally to the decorated function.
+
+ :param argnames: optional list of argument names. These are the names
+ of the arguments in the test function that correspond to the entries
+ in each argument tuple. pytest.mark.parametrize requires this, however
+ the combinations function will derive it automatically if not present
+ by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
+ first argument is "self" which is discarded.
+
+ :param id\_: optional id template. This is a string template that
+ describes how the "id" for each parameter set should be defined, if any.
+ The number of characters in the template should match the number of
+ entries in each argument tuple. Each character describes how the
+ corresponding entry in the argument tuple should be handled, as far as
+ whether or not it is included in the arguments passed to the function, as
+ well as if it is included in the tokens used to create the id of the
+ parameter set.
+
+ If omitted, the argument combinations are passed to parametrize as is. If
+ passed, each argument combination is turned into a pytest.param() object,
+ mapping the elements of the argument tuple to produce an id based on a
+ character value in the same position within the string template using the
+ following scheme::
+
+ i - the given argument is a string that is part of the id only, don't
+ pass it as an argument
+
+ n - the given argument should be passed and it should be added to the
+ id by calling the .__name__ attribute
+
+ r - the given argument should be passed and it should be added to the
+ id by calling repr()
+
+ s - the given argument should be passed and it should be added to the
+ id by calling str()
+
+ a - (argument) the given argument should be passed and it should not
+ be used to generated the id
+
+ e.g.::
+
+ @testing.combinations(
+ (operator.eq, "eq"),
+ (operator.ne, "ne"),
+ (operator.gt, "gt"),
+ (operator.lt, "lt"),
+ id_="na"
+ )
+ def test_operator(self, opfunc, name):
+ pass
+
+ The above combination will call ``.__name__`` on the first member of
+ each tuple and use that as the "id" to pytest.param().
+
+
+ """
+ return _fixture_functions.combinations(*comb, **kw)
+
+
+def fixture(*arg, **kw):
+ return _fixture_functions.fixture(*arg, **kw)
+
+
+def get_current_test_name():
+ return _fixture_functions.get_current_test_name()
+
+
+def skip_test(msg):
+ raise _fixture_functions.skip_test_exception(msg)
--- /dev/null
+"""
+Bootstrapper for test framework plugins.
+
+This is vendored from SQLAlchemy so that we can use local overrides
+for plugin_base.py and pytestplugin.py.
+
+"""
+
+
+import os
+import sys
+
+
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
+
+
+def load_file_as_module(name):
+ path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
+ if sys.version_info >= (3, 3):
+ from importlib import machinery
+
+ mod = machinery.SourceFileLoader(name, path).load_module()
+ else:
+ import imp
+
+ mod = imp.load_source(name, path)
+ return mod
+
+
+if to_bootstrap == "pytest":
+ sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
+ sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
+else:
+ raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
--- /dev/null
+"""vendored plugin_base functions from the most recent SQLAlchemy versions.
+
+Alembic tests need to run on older versions of SQLAlchemy that don't
+necessarily have all the latest testing fixtures.
+
+"""
+from __future__ import absolute_import
+
+import abc
+import sys
+
+from sqlalchemy.testing.plugin.plugin_base import * # noqa
+from sqlalchemy.testing.plugin.plugin_base import post
+
+py3k = sys.version_info >= (3, 0)
+
+if py3k:
+
+ ABC = abc.ABC
+else:
+
+ class ABC(object):
+ __metaclass__ = abc.ABCMeta
+
+
+# override selected SQLAlchemy pytest hooks with vendored functionality
+def want_class(name, cls):
+ from sqlalchemy.testing import config
+ from sqlalchemy.testing import fixtures
+
+ if not issubclass(cls, fixtures.TestBase):
+ return False
+ elif name.startswith("_"):
+ return False
+ elif (
+ config.options.backend_only
+ and not getattr(cls, "__backend__", False)
+ and not getattr(cls, "__sparse_backend__", False)
+ ):
+ return False
+ else:
+ return True
+
+
+@post
+def _init_symbols(options, file_config):
+ from sqlalchemy.testing import config
+ from alembic.testing import fixture_functions as alembic_config
+
+ config._fixture_functions = (
+ alembic_config._fixture_functions
+ ) = _fixture_fn_class()
+
+
+class FixtureFunctions(ABC):
+ @abc.abstractmethod
+ def skip_test_exception(self, *arg, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def combinations(self, *args, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def param_ident(self, *args, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def fixture(self, *arg, **kw):
+ raise NotImplementedError()
+
+ def get_current_test_name(self):
+ raise NotImplementedError()
+
+
+_fixture_fn_class = None
+
+
+def set_fixture_functions(fixture_fn_class):
+ from sqlalchemy.testing.plugin import plugin_base
+
+ global _fixture_fn_class
+ _fixture_fn_class = plugin_base._fixture_fn_class = fixture_fn_class
--- /dev/null
+"""vendored pytestplugin functions from the most recent SQLAlchemy versions.
+
+Alembic tests need to run on older versions of SQLAlchemy that don't
+necessarily have all the latest testing fixtures.
+
+"""
+try:
+ # installed by bootstrap.py
+ import sqla_plugin_base as plugin_base
+except ImportError:
+ # assume we're a package, use traditional import
+ from . import plugin_base
+
+import inspect
+import itertools
+import operator
+import os
+import re
+import sys
+
+import pytest
+from sqlalchemy.testing.plugin.pytestplugin import * # noqa
+from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc
+
+
+# override selected SQLAlchemy pytest hooks with vendored functionality
+def pytest_configure(config):
+ spc(config)
+
+ plugin_base.set_fixture_functions(PytestFixtureFunctions)
+
+
+def pytest_pycollect_makeitem(collector, name, obj):
+
+ if inspect.isclass(obj) and plugin_base.want_class(name, obj):
+ return [
+ pytest.Class(parametrize_cls.__name__, parent=collector)
+ for parametrize_cls in _parametrize_cls(collector.module, obj)
+ ]
+ elif (
+ inspect.isfunction(obj)
+ and isinstance(collector, pytest.Instance)
+ and plugin_base.want_method(collector.cls, obj)
+ ):
+ # None means, fall back to default logic, which includes
+ # method-level parametrize
+ return None
+ else:
+ # empty list means skip this item
+ return []
+
+
+_current_class = None
+
+
+def _parametrize_cls(module, cls):
+ """implement a class-based version of pytest parametrize."""
+
+ if "_sa_parametrize" not in cls.__dict__:
+ return [cls]
+
+ _sa_parametrize = cls._sa_parametrize
+ classes = []
+ for full_param_set in itertools.product(
+ *[params for argname, params in _sa_parametrize]
+ ):
+ cls_variables = {}
+
+ for argname, param in zip(
+ [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
+ ):
+ if not argname:
+ raise TypeError("need argnames for class-based combinations")
+ argname_split = re.split(r",\s*", argname)
+ for arg, val in zip(argname_split, param.values):
+ cls_variables[arg] = val
+ parametrized_name = "_".join(
+ # token is a string, but in py2k py.test is giving us a unicode,
+ # so call str() on it.
+ str(re.sub(r"\W", "", token))
+ for param in full_param_set
+ for token in param.id.split("-")
+ )
+ name = "%s_%s" % (cls.__name__, parametrized_name)
+ newcls = type.__new__(type, name, (cls,), cls_variables)
+ setattr(module, name, newcls)
+ classes.append(newcls)
+ return classes
+
+
+def getargspec(fn):
+ if sys.version_info.major == 3:
+ return inspect.getfullargspec(fn)
+ else:
+ return inspect.getargspec(fn)
+
+
+class PytestFixtureFunctions(plugin_base.FixtureFunctions):
+ def skip_test_exception(self, *arg, **kw):
+ return pytest.skip.Exception(*arg, **kw)
+
+ _combination_id_fns = {
+ "i": lambda obj: obj,
+ "r": repr,
+ "s": str,
+ "n": operator.attrgetter("__name__"),
+ }
+
+ def combinations(self, *arg_sets, **kw):
+ """facade for pytest.mark.paramtrize.
+
+ Automatically derives argument names from the callable which in our
+ case is always a method on a class with positional arguments.
+
+ ids for parameter sets are derived using an optional template.
+
+ """
+ from sqlalchemy.testing import exclusions
+
+ if sys.version_info.major == 3:
+ if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
+ arg_sets = list(arg_sets[0])
+ else:
+ if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
+ arg_sets = list(arg_sets[0])
+
+ argnames = kw.pop("argnames", None)
+
+ exclusion_combinations = []
+
+ def _filter_exclusions(args):
+ result = []
+ gathered_exclusions = []
+ for a in args:
+ if isinstance(a, exclusions.compound):
+ gathered_exclusions.append(a)
+ else:
+ result.append(a)
+
+ exclusion_combinations.extend(
+ [(exclusion, result) for exclusion in gathered_exclusions]
+ )
+ return result
+
+ id_ = kw.pop("id_", None)
+
+ if id_:
+ _combination_id_fns = self._combination_id_fns
+
+ # because itemgetter is not consistent for one argument vs.
+ # multiple, make it multiple in all cases and use a slice
+ # to omit the first argument
+ _arg_getter = operator.itemgetter(
+ 0,
+ *[
+ idx
+ for idx, char in enumerate(id_)
+ if char in ("n", "r", "s", "a")
+ ]
+ )
+ fns = [
+ (operator.itemgetter(idx), _combination_id_fns[char])
+ for idx, char in enumerate(id_)
+ if char in _combination_id_fns
+ ]
+ arg_sets = [
+ pytest.param(
+ *_arg_getter(_filter_exclusions(arg))[1:],
+ id="-".join(
+ comb_fn(getter(arg)) for getter, comb_fn in fns
+ )
+ )
+ for arg in arg_sets
+ ]
+ else:
+ # ensure using pytest.param so that even a 1-arg paramset
+ # still needs to be a tuple. otherwise paramtrize tries to
+ # interpret a single arg differently than tuple arg
+ arg_sets = [
+ pytest.param(*_filter_exclusions(arg)) for arg in arg_sets
+ ]
+
+ def decorate(fn):
+ if inspect.isclass(fn):
+ if "_sa_parametrize" not in fn.__dict__:
+ fn._sa_parametrize = []
+ fn._sa_parametrize.append((argnames, arg_sets))
+ return fn
+ else:
+ if argnames is None:
+ _argnames = getargspec(fn).args[1:]
+ else:
+ _argnames = argnames
+
+ if exclusion_combinations:
+ for exclusion, combination in exclusion_combinations:
+ combination_by_kw = {
+ argname: val
+ for argname, val in zip(_argnames, combination)
+ }
+ exclusion = exclusion.with_combination(
+ **combination_by_kw
+ )
+ fn = exclusion(fn)
+ return pytest.mark.parametrize(_argnames, arg_sets)(fn)
+
+ return decorate
+
+ def param_ident(self, *parameters):
+ ident = parameters[0]
+ return pytest.param(*parameters[1:], id=ident)
+
+ def fixture(self, *arg, **kw):
+ return pytest.fixture(*arg, **kw)
+
+ def get_current_test_name(self):
+ return os.environ.get("PYTEST_CURRENT_TEST")
--- /dev/null
+# testing/util.py
+# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+
+def flag_combinations(*combinations):
+ """A facade around @testing.combinations() oriented towards boolean
+ keyword-based arguments.
+
+ Basically generates a nice looking identifier based on the keywords
+ and also sets up the argument names.
+
+ E.g.::
+
+ @testing.flag_combinations(
+ dict(lazy=False, passive=False),
+ dict(lazy=True, passive=False),
+ dict(lazy=False, passive=True),
+ dict(lazy=False, passive=True, raiseload=True),
+ )
+
+
+ would result in::
+
+ @testing.combinations(
+ ('', False, False, False),
+ ('lazy', True, False, False),
+ ('lazy_passive', True, True, False),
+ ('lazy_passive', True, True, True),
+ id_='iaaa',
+ argnames='lazy,passive,raiseload'
+ )
+
+ """
+ from sqlalchemy.testing import config
+
+ keys = set()
+
+ for d in combinations:
+ keys.update(d)
+
+ keys = sorted(keys)
+
+ return config.combinations(
+ *[
+ ("_".join(k for k in keys if d.get(k, False)),)
+ + tuple(d.get(k, False) for k in keys)
+ for d in combinations
+ ],
+ id_="i" + ("a" * len(keys)),
+ argnames=",".join(keys)
+ )
+
+
+def metadata_fixture(ddl="function"):
+ """Provide MetaData for a pytest fixture."""
+
+ from sqlalchemy.testing import config
+ from . import fixture_functions
+
+ def decorate(fn):
+ def run_ddl(self):
+ from sqlalchemy import schema
+
+ metadata = self.metadata = schema.MetaData()
+ try:
+ result = fn(self, metadata)
+ metadata.create_all(config.db)
+ # TODO:
+ # somehow get a per-function dml erase fixture here
+ yield result
+ finally:
+ metadata.drop_all(config.db)
+
+ return fixture_functions.fixture(scope=ddl)(run_ddl)
+
+ return decorate
"""
import os
-import sqlalchemy
+import pytest
+
+pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
+
# ideally, SQLAlchemy would allow us to just import bootstrap,
# but for now we have to use its "load from a file" approach
+# use bootstrapping so that test plugins are loaded
+# without touching the main library before coverage starts
bootstrap_file = os.path.join(
- os.path.dirname(sqlalchemy.__file__), "testing", "plugin", "bootstrap.py"
+ os.path.dirname(__file__),
+ "..",
+ "alembic",
+ "testing",
+ "plugin",
+ "bootstrap.py",
)
+
with open(bootstrap_file) as f:
code = compile(f.read(), "bootstrap.py", "exec")
to_bootstrap = "pytest"
from sqlalchemy.types import VARBINARY
from alembic import autogenerate
+from alembic import testing
from alembic.migration import MigrationContext
from alembic.operations import ops
from alembic.testing import assert_raises_message
class CompareTypeSpecificityTest(TestBase):
- def _fixture(self):
+ @testing.fixture
+ def impl_fixture(self):
from alembic.ddl import impl
from sqlalchemy.engine import default
default.DefaultDialect(), None, False, True, None, {}
)
- def test_typedec_to_nonstandard(self):
+ def test_typedec_to_nonstandard(self, impl_fixture):
class PasswordType(TypeDecorator):
impl = VARBINARY
impl = VARBINARY(self.length)
return dialect.type_descriptor(impl)
- impl = self._fixture()
- impl.compare_type(
+ impl_fixture.compare_type(
Column("x", sqlite.NUMERIC(50)), Column("x", PasswordType(50))
)
- def test_string(self):
- t1 = String(30)
- t2 = String(40)
- t3 = VARCHAR(30)
- t4 = Integer
-
- impl = self._fixture()
- is_(impl.compare_type(Column("x", t3), Column("x", t1)), False)
- is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
- is_(impl.compare_type(Column("x", t3), Column("x", t4)), True)
-
- def test_numeric(self):
- t1 = Numeric(10, 5)
- t2 = Numeric(12, 5)
- t3 = DECIMAL(10, 5)
- t4 = DateTime
-
- impl = self._fixture()
- is_(impl.compare_type(Column("x", t3), Column("x", t1)), False)
- is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
- is_(impl.compare_type(Column("x", t3), Column("x", t4)), True)
-
- def test_numeric_noprecision(self):
- t1 = Numeric()
- t2 = Numeric(scale=5)
-
- impl = self._fixture()
- is_(impl.compare_type(Column("x", t1), Column("x", t2)), False)
-
- def test_integer(self):
- t1 = Integer()
- t2 = SmallInteger()
- t3 = BIGINT()
- t4 = String()
- t5 = INTEGER()
- t6 = BigInteger()
-
- impl = self._fixture()
- is_(impl.compare_type(Column("x", t5), Column("x", t1)), False)
- is_(impl.compare_type(Column("x", t3), Column("x", t1)), True)
- is_(impl.compare_type(Column("x", t3), Column("x", t6)), False)
- is_(impl.compare_type(Column("x", t3), Column("x", t2)), True)
- is_(impl.compare_type(Column("x", t5), Column("x", t2)), True)
- is_(impl.compare_type(Column("x", t1), Column("x", t4)), True)
-
- def test_datetime(self):
- t1 = DateTime()
- t2 = DateTime(timezone=False)
- t3 = DateTime(timezone=True)
-
- impl = self._fixture()
- is_(impl.compare_type(Column("x", t1), Column("x", t2)), False)
- is_(impl.compare_type(Column("x", t1), Column("x", t3)), True)
- is_(impl.compare_type(Column("x", t2), Column("x", t3)), True)
+ @testing.combinations(
+ (VARCHAR(30), String(30), False),
+ (VARCHAR(30), String(40), True),
+ (VARCHAR(30), Integer(), True),
+ (DECIMAL(10, 5), Numeric(10, 5), False),
+ (DECIMAL(10, 5), Numeric(12, 5), True),
+ (DECIMAL(10, 5), DateTime(), True),
+ (Numeric(), Numeric(scale=5), False),
+ (INTEGER(), Integer(), False),
+ (BIGINT(), Integer(), True),
+ (BIGINT(), BigInteger(), False),
+ (BIGINT(), SmallInteger(), True),
+ (INTEGER(), SmallInteger(), True),
+ (Integer(), String(), True),
+ (DateTime(), DateTime(timezone=False), False),
+ (DateTime(), DateTime(timezone=True), True),
+ (DateTime(timezone=False), DateTime(timezone=True), True),
+ id_="ssa",
+ argnames="compare_from,compare_to,expected",
+ )
+ def test_compare_type(
+ self, impl_fixture, compare_from, compare_to, expected
+ ):
+
+ is_(
+ impl_fixture.compare_type(
+ Column("x", compare_from), Column("x", compare_to)
+ ),
+ expected,
+ )
class AutogenSystemColTest(AutogenTest, TestBase):