]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support exclusion rules in combinations
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 9 Nov 2019 17:33:16 +0000 (12:33 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 9 Nov 2019 22:57:44 +0000 (17:57 -0500)
Like py.test we need to be able to mark certain combination
elements with exclusion rules.   Add additional logic
to pytestlplugin and exclusions so that the exclusion decorators
can be added to the combination tuples, where they will be applied
to the decorated function along with a qualifier that the test
arguments need to match what's given.

Change-Id: I15d2839954d77a252bab5aaf6e3fd9f388c99dd5
(cherry picked from commit bbe754784ae4630dd0ebf30d3bc2be566f8a8fef)

lib/sqlalchemy/testing/exclusions.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
test/orm/test_defaults.py

index 02014e4b2eac36b05b0ae7d9c6b8d7609f2309bd..cf61091e3f10d868bb63cb5b5f4d40cdc38ce90e 100644 (file)
@@ -35,10 +35,20 @@ class compound(object):
         self.fails = set()
         self.skips = set()
         self.tags = set()
+        self.combinations = {}
 
     def __add__(self, other):
         return self.add(other)
 
+    def with_combination(self, **kw):
+        copy = compound()
+        copy.fails.update(self.fails)
+        copy.skips.update(self.skips)
+        copy.tags.update(self.tags)
+        copy.combinations.update((f, kw) for f in copy.fails)
+        copy.combinations.update((s, kw) for s in copy.skips)
+        return copy
+
     def add(self, *others):
         copy = compound()
         copy.fails.update(self.fails)
@@ -85,6 +95,7 @@ class compound(object):
         self.skips.update(other.skips)
         self.fails.update(other.fails)
         self.tags.update(other.tags)
+        self.combinations.update(other.combinations)
 
     def __call__(self, fn):
         if hasattr(fn, "_sa_exclusion_extend"):
@@ -107,43 +118,63 @@ class compound(object):
         try:
             yield
         except Exception as ex:
-            all_fails._expect_failure(config._current, ex)
+            all_fails._expect_failure(config._current, ex, None)
         else:
-            all_fails._expect_success(config._current)
+            all_fails._expect_success(config._current, None)
+
+    def _check_combinations(self, combination, predicate):
+        if predicate in self.combinations:
+            for k, v in combination:
+                if (
+                    k in self.combinations[predicate]
+                    and self.combinations[predicate][k] != v
+                ):
+                    return False
+        return True
 
     def _do(self, cfg, fn, *args, **kw):
+        if len(args) > 1:
+            insp = inspect_getfullargspec(fn)
+            combination = list(zip(insp.args[1:], args[1:]))
+        else:
+            combination = None
+
         for skip in self.skips:
-            if skip(cfg):
+            if self._check_combinations(combination, skip) and skip(cfg):
                 msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg))
                 config.skip_test(msg)
 
         try:
             return_value = fn(*args, **kw)
         except Exception as ex:
-            self._expect_failure(cfg, ex, name=fn.__name__)
+            self._expect_failure(cfg, ex, combination, name=fn.__name__)
         else:
-            self._expect_success(cfg, name=fn.__name__)
+            self._expect_success(cfg, combination, name=fn.__name__)
             return return_value
 
-    def _expect_failure(self, config, ex, name="block"):
+    def _expect_failure(self, config, ex, combination, name="block"):
         for fail in self.fails:
-            if fail(config):
+            if self._check_combinations(combination, fail) and fail(config):
+                if util.py2k:
+                    str_ex = unicode(ex).encode("utf-8", errors="ignore")
+                else:
+                    str_ex = str(ex)
                 print(
                     (
                         "%s failed as expected (%s): %s "
-                        % (name, fail._as_string(config), str(ex))
+                        % (name, fail._as_string(config), str_ex)
                     )
                 )
                 break
         else:
             util.raise_from_cause(ex)
 
-    def _expect_success(self, config, name="block"):
+    def _expect_success(self, config, combination, name="block"):
         if not self.fails:
             return
 
         for fail in self.fails:
-            if fail(config):
+            if self._check_combinations(combination, fail) and fail(config):
                 raise AssertionError(
                     "Unexpected success for '%s' (%s)"
                     % (
index 3c47cbce840cafbac0bbf434c619b1001bbbc5b0..5cb6d1b4cdcb987765bfcbd431edf81179716dfd 100644 (file)
@@ -317,6 +317,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
         ids for parameter sets are derived using an optional template.
 
         """
+        from sqlalchemy.testing import exclusions
 
         if sys.version_info.major == 3:
             if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
@@ -327,6 +328,22 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
 
         argnames = kw.pop("argnames", None)
 
+        exclusion_combinations = []
+
+        def _filter_exclusions(args):
+            result = []
+            gathered_exclusions = []
+            for a in args:
+                if isinstance(a, exclusions.compound):
+                    gathered_exclusions.append(a)
+                else:
+                    result.append(a)
+
+            exclusion_combinations.extend(
+                [(exclusion, result) for exclusion in gathered_exclusions]
+            )
+            return result
+
         id_ = kw.pop("id_", None)
 
         if id_:
@@ -350,7 +367,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
             ]
             arg_sets = [
                 pytest.param(
-                    *_arg_getter(arg)[1:],
+                    *_arg_getter(_filter_exclusions(arg))[1:],
                     id="-".join(
                         comb_fn(getter(arg)) for getter, comb_fn in fns
                     )
@@ -361,7 +378,9 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
             # ensure using pytest.param so that even a 1-arg paramset
             # still needs to be a tuple.  otherwise paramtrize tries to
             # interpret a single arg differently than tuple arg
-            arg_sets = [pytest.param(*arg) for arg in arg_sets]
+            arg_sets = [
+                pytest.param(*_filter_exclusions(arg)) for arg in arg_sets
+            ]
 
         def decorate(fn):
             if inspect.isclass(fn):
@@ -374,6 +393,17 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
                     _argnames = getargspec(fn).args[1:]
                 else:
                     _argnames = argnames
+
+                if exclusion_combinations:
+                    for exclusion, combination in exclusion_combinations:
+                        combination_by_kw = {
+                            argname: val
+                            for argname, val in zip(_argnames, combination)
+                        }
+                        exclusion = exclusion.with_combination(
+                            **combination_by_kw
+                        )
+                        fn = exclusion(fn)
                 return pytest.mark.parametrize(_argnames, arg_sets)(fn)
 
         return decorate
index 9f37dbf4da8226a1dc2b684c79a43e71143a7737..07b82b9dab9d1e73f26a1f50a45d35c0e0142b7f 100644 (file)
@@ -262,14 +262,10 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest):
                 ),
             )
 
-    @testing.requires.computed_columns_on_update_returning
-    def test_update_computed_eager(self):
-        self._test_update_computed(True)
-
-    def test_update_computed_noneager(self):
-        self._test_update_computed(False)
-
-    def _test_update_computed(self, eager):
+    @testing.combinations(
+        (True, testing.requires.computed_columns_on_update_returning), (False,)
+    )
+    def test_update_computed(self, eager):
         if eager:
             Thing = self.classes.Thing
         else: