]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support warnings in exclusions
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Oct 2025 21:19:26 +0000 (17:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Oct 2025 21:19:26 +0000 (17:19 -0400)
this adds a new feature to exclusions ``warns_if()`` which applies
the expect_warnings() context manager to a test method.  Additionally,
at the class level these requirements can be extracted from a
``__requirements__`` directive and also added to global Python warnings
filter using catch_warnings().

Change-Id: Ibe28d169106309a930731c77e201402152a38810

lib/sqlalchemy/testing/exclusions.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/plugin/pytestplugin.py

index d28e9d85e0cf9495b5b07ab92964eca646819168..1a6f88bf723fc6feecf77ab3a4db56357b10b782 100644 (file)
@@ -31,10 +31,18 @@ def fails_if(predicate, reason=None):
     return rule
 
 
+def warns_if(predicate, expression, assert_):
+    rule = compound()
+    pred = _as_predicate(predicate)
+    rule.warns[pred] = (expression, assert_)
+    return rule
+
+
 class compound:
     def __init__(self):
         self.fails = set()
         self.skips = set()
+        self.warns = {}
 
     def __add__(self, other):
         return self.add(other)
@@ -49,16 +57,24 @@ class compound:
         copy = compound()
         copy.fails.update(self.fails)
         copy.skips.update(self.skips)
+        copy.warns.update(self.warns)
 
         for other in others:
             copy.fails.update(other.fails)
             copy.skips.update(other.skips)
+            copy.warns.update(other.warns)
         return copy
 
     def not_(self):
         copy = compound()
         copy.fails.update(NotPredicate(fail) for fail in self.fails)
         copy.skips.update(NotPredicate(skip) for skip in self.skips)
+        copy.warns.update(
+            {
+                NotPredicate(warn): element
+                for warn, element in self.warns.items()
+            }
+        )
         return copy
 
     @property
@@ -72,6 +88,13 @@ class compound:
         else:
             return True
 
+    def matching_warnings(self, config):
+        return [
+            message
+            for predicate, (message, assert_) in self.warns.items()
+            if predicate(config)
+        ]
+
     def matching_config_reasons(self, config):
         return [
             predicate._as_string(config)
@@ -82,6 +105,7 @@ class compound:
     def _extend(self, other):
         self.skips.update(other.skips)
         self.fails.update(other.fails)
+        self.warns.update(other.warns)
 
     def __call__(self, fn):
         if hasattr(fn, "_sa_exclusion_extend"):
@@ -117,8 +141,25 @@ class compound:
                 )
                 config.skip_test(msg)
 
+        if self.warns:
+            from .assertions import expect_warnings
+
+            @contextlib.contextmanager
+            def _expect_warnings():
+                with contextlib.ExitStack() as stack:
+                    for expression, assert_ in self.warns.values():
+                        stack.enter_context(
+                            expect_warnings(expression, assert_=assert_)
+                        )
+                    yield
+
+            ctx = _expect_warnings()
+        else:
+            ctx = contextlib.nullcontext()
+
         try:
-            return_value = fn(*args, **kw)
+            with ctx:
+                return_value = fn(*args, **kw)
         except Exception as ex:
             self._expect_failure(cfg, ex, name=fn.__name__)
         else:
index 2dfa441413df816c38ac12195fe88c4755b8847a..96057e09422f2e0e9555d7d36e3ed18fce48f678 100644 (file)
@@ -660,6 +660,12 @@ def _possible_configs_for_cls(cls, reasons=None, sparse=False):
                         reasons.extend(skip_reasons)
                     break
 
+                warnings = check.matching_warnings(config_obj)
+                if warnings:
+                    cls.__warnings__ = getattr(
+                        cls, "__warnings__", ()
+                    ) + tuple(warnings)
+
     if hasattr(cls, "__prefer_requires__"):
         non_preferred = set()
         requirements = config.requirements
index 5b82c14bc479a38b4fc01bd322d68b1cb710d0c8..2071e6c3b0bd41d9c8506f42588f20e944d4f43f 100644 (file)
@@ -435,6 +435,8 @@ def _parametrize_cls(module, cls):
 
 _current_class = None
 
+_current_warning_context = None
+
 
 def pytest_runtest_setup(item):
     from sqlalchemy.testing import asyncio
@@ -445,7 +447,7 @@ def pytest_runtest_setup(item):
     # databases, so we run this outside of the pytest fixture system altogether
     # and ensure asyncio greenlet if any engines are async
 
-    global _current_class
+    global _current_class, _current_warning_context
 
     if isinstance(item, pytest.Function) and _current_class is None:
         asyncio._maybe_async_provisioning(
@@ -454,6 +456,14 @@ def pytest_runtest_setup(item):
         )
         _current_class = item.getparent(pytest.Class)
 
+        if hasattr(_current_class.cls, "__warnings__"):
+            import warnings
+
+            _current_warning_context = warnings.catch_warnings()
+            _current_warning_context.__enter__()
+            for warning_message in _current_class.cls.__warnings__:
+                warnings.filterwarnings("ignore", warning_message)
+
 
 @pytest.hookimpl(hookwrapper=True)
 def pytest_runtest_teardown(item, nextitem):
@@ -470,13 +480,19 @@ def pytest_runtest_teardown(item, nextitem):
     # pytest_runtest_setup since the class has not yet been setup at that
     # time.
     # See https://github.com/pytest-dev/pytest/issues/9343
-    global _current_class, _current_report
+
+    global _current_class, _current_report, _current_warning_context
 
     if _current_class is not None and (
         # last test or a new class
         nextitem is None
         or nextitem.getparent(pytest.Class) is not _current_class
     ):
+
+        if _current_warning_context is not None:
+            _current_warning_context.__exit__(None, None, None)
+            _current_warning_context = None
+
         _current_class = None
 
         try:
@@ -673,7 +689,8 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
 
     def mark_base_test_class(self):
         return pytest.mark.usefixtures(
-            "setup_class_methods", "setup_test_methods"
+            "setup_class_methods",
+            "setup_test_methods",
         )
 
     _combination_id_fns = {