]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refine this a bit to better check for exception type
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 31 Aug 2014 22:00:49 +0000 (18:00 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 31 Aug 2014 22:00:49 +0000 (18:00 -0400)
lib/sqlalchemy/testing/assertions.py

index dbe365ad5cd63e447bfd63ac1b85e3617acf2ae6..bf7c27a890681d3f4951f5139a91b618dc09eb1f 100644 (file)
@@ -11,7 +11,7 @@ from . import util as testutil
 from sqlalchemy import pool, orm, util
 from sqlalchemy.engine import default, url
 from sqlalchemy.util import decorator
-from sqlalchemy import types as sqltypes, schema
+from sqlalchemy import types as sqltypes, schema, exc as sa_exc
 import warnings
 import re
 from .exclusions import db_spec, _is_excluded
@@ -33,8 +33,7 @@ def expect_warnings(*messages):
     Note that the test suite sets SAWarning warnings to raise exceptions.
 
     """
-    return _expect_warnings(
-        "sqlalchemy.util.deprecations.warnings.warn", messages)
+    return _expect_warnings(sa_exc.SAWarning, messages)
 
 
 @contextlib.contextmanager
@@ -66,8 +65,7 @@ def emits_warning(*messages):
 
 
 def expect_deprecated(*messages):
-    return _expect_warnings(
-        "sqlalchemy.util.deprecations.warnings.warn", messages)
+    return _expect_warnings(sa_exc.SADeprecationWarning, messages)
 
 
 def emits_warning_on(db, *messages):
@@ -105,13 +103,16 @@ def uses_deprecated(*messages):
 
 
 @contextlib.contextmanager
-def _expect_warnings(to_patch, messages):
+def _expect_warnings(exc_cls, messages):
 
     filters = [re.compile(msg, re.I) for msg in messages]
 
     real_warn = warnings.warn
 
     def our_warn(msg, exception, *arg, **kw):
+        if not issubclass(exception, exc_cls):
+            return real_warn(msg, exception, *arg, **kw)
+
         if not filters:
             return
 
@@ -121,7 +122,7 @@ def _expect_warnings(to_patch, messages):
         else:
             real_warn(msg, exception, *arg, **kw)
 
-    with mock.patch(to_patch, our_warn):
+    with mock.patch("warnings.warn", our_warn):
         yield