]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rework the exclusions system to have much better support for compound
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Jul 2014 22:26:22 +0000 (18:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Jul 2014 22:26:22 +0000 (18:26 -0400)
rules, better message formatting

lib/sqlalchemy/engine/url.py
lib/sqlalchemy/testing/exclusions.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/schema.py
test/orm/test_naturalpks.py
test/requirements.py

index e3629613fe7aba375b0f92721efd30194be5e279..6544cfbf3e0cf2be04147e61ceea18e4cea94391 100644 (file)
@@ -105,6 +105,18 @@ class URL(object):
             self.database == other.database and \
             self.query == other.query
 
+    def get_backend_name(self):
+        if '+' not in self.drivername:
+            return self.drivername
+        else:
+            return self.drivername.split('+')[0]
+
+    def get_driver_name(self):
+        if '+' not in self.drivername:
+            return self.get_dialect().driver
+        else:
+            return self.drivername.split('+')[1]
+
     def get_dialect(self):
         """Return the SQLAlchemy database dialect class corresponding
         to this URL's driver name.
index fd43865aa578d246d6638c529fe616d7fbd9c06b..f6ef72408c6c3787a9b51f8835371d04fe718840 100644 (file)
@@ -11,83 +11,137 @@ from .plugin.plugin_base import SkipTest
 from ..util import decorator
 from . import config
 from .. import util
-import contextlib
 import inspect
+import contextlib
+
+
+def skip_if(predicate, reason=None):
+    rule = compound()
+    pred = _as_predicate(predicate, reason)
+    rule.skips.add(pred)
+    return rule
 
 
-class skip_if(object):
-    def __init__(self, predicate, reason=None):
-        self.predicate = _as_predicate(predicate)
-        self.reason = reason
+def fails_if(predicate, reason=None):
+    rule = compound()
+    pred = _as_predicate(predicate, reason)
+    rule.fails.add(pred)
+    return rule
 
-    _fails_on = None
+
+class compound(object):
+    def __init__(self):
+        self.fails = set()
+        self.skips = set()
 
     def __add__(self, other):
-        def decorate(fn):
-            return other(self(fn))
-        return decorate
+        return self.add(other)
+
+    def add(self, *others):
+        copy = compound()
+        copy.fails.update(self.fails)
+        copy.skips.update(self.skips)
+        for other in others:
+            copy.fails.update(other.fails)
+            copy.skips.update(other.skips)
+        return copy
+
+    def not_(self):
+        copy = compound()
+        copy.fails.update(NotPredicate(fail) for fail in self.fails)
+        copy.skips.update(NotPredicate(skip) for skip in self.skips)
+        return copy
 
     @property
     def enabled(self):
         return self.enabled_for_config(config._current)
 
     def enabled_for_config(self, config):
-        return not self.predicate(config)
-
-    @contextlib.contextmanager
-    def fail_if(self, name='block'):
-        try:
-            yield
-        except Exception as ex:
-            if self.predicate(config._current):
-                print(("%s failed as expected (%s): %s " % (
-                    name, self.predicate, str(ex))))
-            else:
-                raise
+        for predicate in self.skips.union(self.fails):
+            if predicate(config):
+                return False
         else:
-            if self.predicate(config._current):
-                raise AssertionError(
-                    "Unexpected success for '%s' (%s)" %
-                    (name, self.predicate))
+            return True
+
+    def matching_config_reasons(self, config):
+        return [
+            predicate._as_string(config) for predicate
+            in self.skips.union(self.fails)
+            if predicate(config)
+        ]
 
     def __call__(self, fn):
+        if hasattr(fn, '_sa_exclusion_extend'):
+            fn._sa_exclusion_extend(self)
+            return fn
+
+        def extend(other):
+            self.skips.update(other.skips)
+            self.fails.update(other.fails)
+
         @decorator
         def decorate(fn, *args, **kw):
-            if self.predicate(config._current):
-                if self.reason:
-                    msg = "'%s' : %s" % (
-                        fn.__name__,
-                        self.reason
-                    )
-                else:
-                    msg = "'%s': %s" % (
-                        fn.__name__, self.predicate
-                    )
-                raise SkipTest(msg)
-            else:
-                if self._fails_on:
-                    with self._fails_on.fail_if(name=fn.__name__):
-                        return fn(*args, **kw)
-                else:
-                    return fn(*args, **kw)
-        return decorate(fn)
+            return self._do(config._current, fn, *args, **kw)
+        decorated = decorate(fn)
+        decorated._sa_exclusion_extend = extend
+        return decorated
 
-    def fails_on(self, other, reason=None):
-        self._fails_on = skip_if(other, reason)
-        return self
 
-    def fails_on_everything_except(self, *dbs):
-        self._fails_on = skip_if(fails_on_everything_except(*dbs))
-        return self
+    @contextlib.contextmanager
+    def fail_if(self):
+        all_fails = compound()
+        all_fails.fails.update(self.skips.union(self.fails))
 
+        try:
+            yield
+        except Exception as ex:
+            all_fails._expect_failure(config._current, ex)
+        else:
+            all_fails._expect_success(config._current)
+
+    def _do(self, config, fn, *args, **kw):
+        for skip in self.skips:
+            if skip(config):
+                msg = "'%s' : %s" % (
+                    fn.__name__,
+                    skip._as_string(config)
+                )
+                raise SkipTest(msg)
 
-class fails_if(skip_if):
-    def __call__(self, fn):
-        @decorator
-        def decorate(fn, *args, **kw):
-            with self.fail_if(name=fn.__name__):
-                return fn(*args, **kw)
-        return decorate(fn)
+        try:
+            return_value = fn(*args, **kw)
+        except Exception as ex:
+            self._expect_failure(config, ex, name=fn.__name__)
+        else:
+            self._expect_success(config, name=fn.__name__)
+            return return_value
+
+    def _expect_failure(self, config, ex, name='block'):
+        for fail in self.fails:
+            if fail(config):
+                print(("%s failed as expected (%s): %s " % (
+                    name, fail._as_string(config), str(ex))))
+                break
+        else:
+            raise ex
+
+    def _expect_success(self, config, name='block'):
+        if not self.fails:
+            return
+        for fail in self.fails:
+            if not fail(config):
+                break
+        else:
+            raise AssertionError(
+                "Unexpected success for '%s' (%s)" %
+                (
+                    name,
+                    " and ".join(
+                        fail._as_string(config)
+                        for fail in self.fails
+                    )
+                )
+            )
 
 
 def only_if(predicate, reason=None):
@@ -102,13 +156,18 @@ def succeeds_if(predicate, reason=None):
 
 class Predicate(object):
     @classmethod
-    def as_predicate(cls, predicate):
-        if isinstance(predicate, skip_if):
-            return NotPredicate(predicate.predicate)
+    def as_predicate(cls, predicate, description=None):
+        if isinstance(predicate, compound):
+            return cls.as_predicate(predicate.fails.union(predicate.skips))
+
         elif isinstance(predicate, Predicate):
+            if description and predicate.description is None:
+                predicate.description = description
             return predicate
-        elif isinstance(predicate, list):
-            return OrPredicate([cls.as_predicate(pred) for pred in predicate])
+        elif isinstance(predicate, (list, set)):
+            return OrPredicate(
+                [cls.as_predicate(pred) for pred in predicate],
+                description)
         elif isinstance(predicate, tuple):
             return SpecPredicate(*predicate)
         elif isinstance(predicate, util.string_types):
@@ -119,12 +178,26 @@ class Predicate(object):
                 op = tokens.pop(0)
             if tokens:
                 spec = tuple(int(d) for d in tokens.pop(0).split("."))
-            return SpecPredicate(db, op, spec)
+            return SpecPredicate(db, op, spec, description=description)
         elif util.callable(predicate):
-            return LambdaPredicate(predicate)
+            return LambdaPredicate(predicate, description)
         else:
             assert False, "unknown predicate type: %s" % predicate
 
+    def _format_description(self, config, negate=False):
+        bool_ = self(config)
+        if negate:
+            bool_ = not negate
+        return self.description % {
+            "driver": config.db.url.get_driver_name(),
+            "database": config.db.url.get_backend_name(),
+            "doesnt_support": "doesn't support" if bool_ else "does support",
+            "does_support": "does support" if bool_ else "doesn't support"
+        }
+
+    def _as_string(self, config=None, negate=False):
+        raise NotImplementedError()
+
 
 class BooleanPredicate(Predicate):
     def __init__(self, value, description=None):
@@ -134,14 +207,8 @@ class BooleanPredicate(Predicate):
     def __call__(self, config):
         return self.value
 
-    def _as_string(self, negate=False):
-        if negate:
-            return "not " + self.description
-        else:
-            return self.description
-
-    def __str__(self):
-        return self._as_string()
+    def _as_string(self, config, negate=False):
+        return self._format_description(config, negate=negate)
 
 
 class SpecPredicate(Predicate):
@@ -185,9 +252,9 @@ class SpecPredicate(Predicate):
         else:
             return True
 
-    def _as_string(self, negate=False):
+    def _as_string(self, config, negate=False):
         if self.description is not None:
-            return self.description
+            return self._format_description(config)
         elif self.op is None:
             if negate:
                 return "not %s" % self.db
@@ -207,9 +274,6 @@ class SpecPredicate(Predicate):
                     self.spec
                 )
 
-    def __str__(self):
-        return self._as_string()
-
 
 class LambdaPredicate(Predicate):
     def __init__(self, lambda_, description=None, args=None, kw=None):
@@ -230,25 +294,23 @@ class LambdaPredicate(Predicate):
     def __call__(self, config):
         return self.lambda_(config)
 
-    def _as_string(self, negate=False):
-        if negate:
-            return "not " + self.description
-        else:
-            return self.description
-
-    def __str__(self):
-        return self._as_string()
+    def _as_string(self, config, negate=False):
+        return self._format_description(config)
 
 
 class NotPredicate(Predicate):
-    def __init__(self, predicate):
+    def __init__(self, predicate, description=None):
         self.predicate = predicate
+        self.description = description
 
     def __call__(self, config):
         return not self.predicate(config)
 
-    def __str__(self):
-        return self.predicate._as_string(True)
+    def _as_string(self, config, negate=False):
+        if self.description:
+            return self._format_description(config, not negate)
+        else:
+            return self.predicate._as_string(config, not negate)
 
 
 class OrPredicate(Predicate):
@@ -259,40 +321,32 @@ class OrPredicate(Predicate):
     def __call__(self, config):
         for pred in self.predicates:
             if pred(config):
-                self._str = pred
                 return True
         return False
 
-    _str = None
-
-    def _eval_str(self, negate=False):
-        if self._str is None:
-            if negate:
-                conjunction = " and "
-            else:
-                conjunction = " or "
-            return conjunction.join(p._as_string(negate=negate)
-                                    for p in self.predicates)
+    def _eval_str(self, config, negate=False):
+        if negate:
+            conjunction = " and "
         else:
-            return self._str._as_string(negate=negate)
+            conjunction = " or "
+        return conjunction.join(p._as_string(config, negate=negate)
+                                for p in self.predicates)
 
-    def _negation_str(self):
+    def _negation_str(self, config):
         if self.description is not None:
-            return "Not " + (self.description % {"spec": self._str})
+            return "Not " + self._format_description(config)
         else:
-            return self._eval_str(negate=True)
+            return self._eval_str(config, negate=True)
 
-    def _as_string(self, negate=False):
+    def _as_string(self, config, negate=False):
         if negate:
-            return self._negation_str()
+            return self._negation_str(config)
         else:
             if self.description is not None:
-                return self.description % {"spec": self._str}
+                return self._format_description(config)
             else:
-                return self._eval_str()
+                return self._eval_str(config)
 
-    def __str__(self):
-        return self._as_string()
 
 _as_predicate = Predicate.as_predicate
 
@@ -341,8 +395,8 @@ def fails_on(db, reason=None):
 def fails_on_everything_except(*dbs):
     return succeeds_if(
         OrPredicate([
-                    SpecPredicate(db) for db in dbs
-                    ])
+            SpecPredicate(db) for db in dbs
+        ])
     )
 
 
index 2590f3b1ee57ddaecf6ee4e9427815560315cbaf..fd7cb08f4a64350fffda64bc5e6c54b8148adc38 100644 (file)
@@ -356,7 +356,7 @@ def generate_sub_tests(cls, module):
                 (cls, ),
                 {
                     "__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)),
-                    "__backend__": False}
+                }
             )
             setattr(module, name, subcls)
             yield subcls
@@ -407,7 +407,7 @@ def after_test(test):
     warnings.resetwarnings()
 
 
-def _possible_configs_for_cls(cls):
+def _possible_configs_for_cls(cls, reasons=None):
     all_configs = set(config.Config.all_configs())
     if cls.__unsupported_on__:
         spec = exclusions.db_spec(*cls.__unsupported_on__)
@@ -419,12 +419,6 @@ def _possible_configs_for_cls(cls):
         for config_obj in list(all_configs):
             if not spec(config_obj):
                 all_configs.remove(config_obj)
-    return all_configs
-
-
-def _do_skips(cls):
-    all_configs = _possible_configs_for_cls(cls)
-    reasons = []
 
     if hasattr(cls, '__requires__'):
         requirements = config.requirements
@@ -432,10 +426,11 @@ def _do_skips(cls):
             for requirement in cls.__requires__:
                 check = getattr(requirements, requirement)
 
-                if check.predicate(config_obj):
+                skip_reasons = check.matching_config_reasons(config_obj)
+                if skip_reasons:
                     all_configs.remove(config_obj)
-                    if check.reason:
-                        reasons.append(check.reason)
+                    if reasons is not None:
+                        reasons.extend(skip_reasons)
                     break
 
     if hasattr(cls, '__prefer_requires__'):
@@ -445,11 +440,25 @@ def _do_skips(cls):
             for requirement in cls.__prefer_requires__:
                 check = getattr(requirements, requirement)
 
-                if check.predicate(config_obj):
+                if not check.enabled_for_config(config_obj):
                     non_preferred.add(config_obj)
         if all_configs.difference(non_preferred):
             all_configs.difference_update(non_preferred)
 
+    for db_spec, op, spec in getattr(cls, '__excluded_on__', ()):
+        for config_obj in list(all_configs):
+            if not exclusions.skip_if(
+                    exclusions.SpecPredicate(db_spec, op, spec)
+            ).enabled_for_config(config_obj):
+                all_configs.remove(config_obj)
+
+    return all_configs
+
+
+def _do_skips(cls):
+    reasons = []
+    all_configs = _possible_configs_for_cls(cls, reasons)
+
     if getattr(cls, '__skip_if__', False):
         for c in getattr(cls, '__skip_if__'):
             if c():
@@ -457,24 +466,26 @@ def _do_skips(cls):
                     cls.__name__, c.__name__)
                 )
 
-    for db_spec, op, spec in getattr(cls, '__excluded_on__', ()):
-        for config_obj in list(all_configs):
-            if exclusions.skip_if(
-                    exclusions.SpecPredicate(db_spec, op, spec)
-            ).predicate(config_obj):
-                all_configs.remove(config_obj)
     if not all_configs:
-        raise SkipTest(
-            "'%s' unsupported on DB implementation %s%s" % (
+        if getattr(cls, '__backend__', False):
+            msg = "'%s' unsupported for implementation '%s'" % (
+                cls.__name__, cls.__only_on__)
+        else:
+            msg = "'%s' unsupported on any DB implementation %s%s" % (
                 cls.__name__,
-                ", ".join("'%s' = %s"
-                          % (config_obj.db.name,
-                             config_obj.db.dialect.server_version_info)
-                          for config_obj in config.Config.all_configs()
-                          ),
+                ", ".join(
+                    "'%s(%s)+%s'" % (
+                        config_obj.db.name,
+                        ".".join(
+                            str(dig) for dig in
+                            config_obj.db.dialect.server_version_info),
+                        config_obj.db.driver
+                    )
+                  for config_obj in config.Config.all_configs()
+                ),
                 ", ".join(reasons)
             )
-        )
+        raise SkipTest(msg)
     elif hasattr(cls, '__prefer_backends__'):
         non_preferred = set()
         spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
index 8fe7a509013a7bc98b5c7b9cb8e1bc7956c1ea63..fbb0d63e2b1a73472e183f2e7257ca72a4216934 100644 (file)
@@ -187,7 +187,7 @@ class SuiteRequirements(Requirements):
 
         return exclusions.only_if(
             lambda config: config.db.dialect.implicit_returning,
-            "'returning' not supported by database"
+            "%(database)s %(does_support)s 'returning'"
         )
 
     @property
index 1cb356dd71552252e570ed773696ef0fbba12b6d..9561b1f1ef9c445761acac2c076425f2af8ff4eb 100644 (file)
@@ -67,7 +67,7 @@ def Column(*args, **kw):
     test_opts = dict([(k, kw.pop(k)) for k in list(kw)
                       if k.startswith('test_')])
 
-    if config.requirements.foreign_key_ddl.predicate(config):
+    if not config.requirements.foreign_key_ddl.enabled_for_config(config):
         args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
 
     col = schema.Column(*args, **kw)
index ac5e723c702e6868565818a2ff79b87107d522b1..53b661a49b6a26e4e664ac15dafbaadff05faac0 100644 (file)
@@ -394,7 +394,7 @@ class NaturalPKTest(fixtures.MappedTest):
         self._test_manytomany(True)
 
     @testing.requires.non_updating_cascade
-    @testing.requires.sane_multi_rowcount
+    @testing.requires.sane_multi_rowcount.not_()
     def test_manytomany_nonpassive(self):
         self._test_manytomany(False)
 
index 024f32c5428dfd1575ea3d0abc61be51643f78ec..24984b062c7977fea7a0489b1abf45a5a60053a8 100644 (file)
@@ -60,7 +60,7 @@ class DefaultRequirements(SuiteRequirements):
 
         return skip_if(
                     ['sqlite', 'oracle'],
-                    'target backend does not support ON UPDATE CASCADE'
+                    'target backend %(doesnt_support)s ON UPDATE CASCADE'
                 )
 
     @property
@@ -68,7 +68,8 @@ class DefaultRequirements(SuiteRequirements):
         """target database must *not* support ON UPDATE..CASCADE behavior in
         foreign keys."""
 
-        return fails_on_everything_except('sqlite', 'oracle', '+zxjdbc') + skip_if('mssql')
+        return fails_on_everything_except('sqlite', 'oracle', '+zxjdbc') + \
+            skip_if('mssql')
 
     @property
     def deferrable_fks(self):
@@ -208,7 +209,7 @@ class DefaultRequirements(SuiteRequirements):
         return only_on(
                     ('postgresql', 'sqlite', 'mysql'),
                     "DBAPI has no isolation level support"
-                ).fails_on('postgresql+pypostgresql',
+                ) + fails_on('postgresql+pypostgresql',
                           'pypostgresql bombs on multiple isolation level calls')
 
     @property
@@ -463,9 +464,9 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def sane_multi_rowcount(self):
         return fails_if(
-                    lambda config: not config.db.dialect.supports_sane_multi_rowcount,
-                    "driver doesn't support 'sane' multi row count"
-                )
+            lambda config: not config.db.dialect.supports_sane_multi_rowcount,
+            "driver %(driver)s %(doesnt_support)s 'sane' multi row count"
+        )
 
     @property
     def nullsordering(self):
@@ -717,12 +718,14 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def percent_schema_names(self):
         return skip_if(
-                [
-                    ("+psycopg2", None, None,
-                            "psycopg2 2.4 no longer accepts % in bind placeholders"),
-                    ("mysql", None, None, "executemany() doesn't work here")
-                ]
-            )
+            [
+                (
+                    "+psycopg2", None, None,
+                    "psycopg2 2.4 no longer accepts percent "
+                    "sign in bind placeholders"),
+                ("mysql", None, None, "executemany() doesn't work here")
+            ]
+        )
 
     @property
     def order_by_label_with_expression(self):