]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- expect_warnings was not expecting and neither was assert_warnings
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Feb 2015 17:01:48 +0000 (12:01 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Feb 2015 17:01:48 +0000 (12:01 -0500)
asserting.

lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/warnings.py
test/sql/test_text.py

index 635f6c5399239c9711a73ad604a2cc8736158035..91d0c6339314becca83c4a24ec9eb78470da36f5 100644 (file)
@@ -22,7 +22,7 @@ import contextlib
 from . import mock
 
 
-def expect_warnings(*messages):
+def expect_warnings(*messages, **kw):
     """Context manager which expects one or more warnings.
 
     With no arguments, squelches all SAWarnings emitted via
@@ -30,17 +30,21 @@ def expect_warnings(*messages):
     pass string expressions that will match selected warnings via regex;
     all non-matching warnings are sent through.
 
+    The expect version **asserts** that the warnings were in fact seen.
+
     Note that the test suite sets SAWarning warnings to raise exceptions.
 
     """
-    return _expect_warnings(sa_exc.SAWarning, messages)
+    return _expect_warnings(sa_exc.SAWarning, messages, **kw)
 
 
 @contextlib.contextmanager
-def expect_warnings_on(db, *messages):
+def expect_warnings_on(db, *messages, **kw):
     """Context manager which expects one or more warnings on specific
     dialects.
 
+    The expect version **asserts** that the warnings were in fact seen.
+
     """
     spec = db_spec(db)
 
@@ -49,23 +53,28 @@ def expect_warnings_on(db, *messages):
     elif not _is_excluded(*db):
         yield
     else:
-        with expect_warnings(*messages):
+        with expect_warnings(*messages, **kw):
             yield
 
 
 def emits_warning(*messages):
-    """Decorator form of expect_warnings()."""
+    """Decorator form of expect_warnings().
+
+    Note that emits_warning does **not** assert that the warnings
+    were in fact seen.
+
+    """
 
     @decorator
     def decorate(fn, *args, **kw):
-        with expect_warnings(*messages):
+        with expect_warnings(assert_=False, *messages):
             return fn(*args, **kw)
 
     return decorate
 
 
-def expect_deprecated(*messages):
-    return _expect_warnings(sa_exc.SADeprecationWarning, messages)
+def expect_deprecated(*messages, **kw):
+    return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
 
 
 def emits_warning_on(db, *messages):
@@ -74,6 +83,10 @@ def emits_warning_on(db, *messages):
     With no arguments, squelches all SAWarning failures.  Or pass one or more
     strings; these will be matched to the root of the warning description by
     warnings.filterwarnings().
+
+    Note that emits_warning_on does **not** assert that the warnings
+    were in fact seen.
+
     """
     @decorator
     def decorate(fn, *args, **kw):
@@ -93,19 +106,28 @@ def uses_deprecated(*messages):
     As a special case, you may pass a function name prefixed with //
     and it will be re-written as needed to match the standard warning
     verbiage emitted by the sqlalchemy.util.deprecated decorator.
+
+    Note that uses_deprecated does **not** assert that the warnings
+    were in fact seen.
+
     """
 
     @decorator
     def decorate(fn, *args, **kw):
-        with expect_deprecated(*messages):
+        with expect_deprecated(*messages, assert_=False):
             return fn(*args, **kw)
     return decorate
 
 
 @contextlib.contextmanager
-def _expect_warnings(exc_cls, messages):
+def _expect_warnings(exc_cls, messages, regex=True, assert_=True):
 
-    filters = [re.compile(msg, re.I) for msg in messages]
+    if regex:
+        filters = [re.compile(msg, re.I) for msg in messages]
+    else:
+        filters = messages
+
+    seen = set(filters)
 
     real_warn = warnings.warn
 
@@ -117,7 +139,9 @@ def _expect_warnings(exc_cls, messages):
             return
 
         for filter_ in filters:
-            if filter_.match(msg):
+            if (regex and filter_.match(msg)) or \
+                    (not regex and filter_ == msg):
+                seen.discard(filter_)
                 break
         else:
             real_warn(msg, exception, *arg, **kw)
@@ -125,6 +149,10 @@ def _expect_warnings(exc_cls, messages):
     with mock.patch("warnings.warn", our_warn):
         yield
 
+    if assert_:
+        assert not seen, "Warnings were not seen: %s" % \
+            ", ".join("%r" % (s.pattern if regex else s) for s in seen)
+
 
 def global_cleanup_assertions():
     """Check things that have to be finalized at the end of a test suite.
index 47f1e1404178967b496f5c76d9d4bdb940d15d40..640f02a78f800f39b2207404fd735bab0b814786 100644 (file)
@@ -9,7 +9,7 @@ from __future__ import absolute_import
 
 import warnings
 from .. import exc as sa_exc
-import re
+from . import assertions
 
 
 def setup_filters():
@@ -22,19 +22,13 @@ def setup_filters():
 
 
 def assert_warnings(fn, warning_msgs, regex=False):
-    """Assert that each of the given warnings are emitted by fn."""
-
-    from .assertions import eq_
-
-    with warnings.catch_warnings(record=True) as log:
-        # ensure that nothing is going into __warningregistry__
-        warnings.filterwarnings("always")
-
-        result = fn()
-    for warning in log:
-        popwarn = warning_msgs.pop(0)
-        if regex:
-            assert re.match(popwarn, str(warning.message))
-        else:
-            eq_(popwarn, str(warning.message))
-    return result
+    """Assert that each of the given warnings are emitted by fn.
+
+    Deprecated.  Please use assertions.expect_warnings().
+
+    """
+
+    with assertions._expect_warnings(
+            sa_exc.SAWarning, warning_msgs, regex=regex):
+        return fn()
+
index 60d90196e0cfa1c7f9d04fd0f7ffdcc7de30219c..4302dde48e645f4e308765a5b759217d1d1536e7 100644 (file)
@@ -496,6 +496,10 @@ class TextWarningsTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
 
     def _test(self, fn, arg, offending_clause, expected):
+        with expect_warnings("Textual "):
+            stmt = fn(arg)
+            self.assert_compile(stmt, expected)
+
         assert_raises_message(
             exc.SAWarning,
             r"Textual (?:SQL|column|SQL FROM) expression %(stmt)r should be "
@@ -505,10 +509,6 @@ class TextWarningsTest(fixtures.TestBase, AssertsCompiledSQL):
             fn, arg
         )
 
-        with expect_warnings("Textual "):
-            stmt = fn(arg)
-            self.assert_compile(stmt, expected)
-
     def test_where(self):
         self._test(
             select([table1.c.myid]).where, "myid == 5", "myid == 5",