]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-146358: Fix warnings.catch_warnings on Free Threading (#146374)
authorVictor Stinner <vstinner@python.org>
Wed, 25 Mar 2026 13:34:50 +0000 (14:34 +0100)
committerGitHub <noreply@github.com>
Wed, 25 Mar 2026 13:34:50 +0000 (14:34 +0100)
catch_warnings now also overrides warnings.showwarning() on Free
Threading to support custom warnings.showwarning().

Lib/_py_warnings.py
Lib/test/test_warnings/__init__.py

index d5a9cec86f3674458eb934deec209d0b664fcf07..81a386c4487d953d8c97cb7101b85580feb5efc9 100644 (file)
@@ -703,8 +703,8 @@ class catch_warnings(object):
                 context = None
                 self._filters = self._module.filters
                 self._module.filters = self._filters[:]
-                self._showwarning = self._module.showwarning
                 self._showwarnmsg_impl = self._module._showwarnmsg_impl
+            self._showwarning = self._module.showwarning
             self._module._filters_mutated_lock_held()
             if self._record:
                 if _use_context:
@@ -712,9 +712,9 @@ class catch_warnings(object):
                 else:
                     log = []
                     self._module._showwarnmsg_impl = log.append
-                    # Reset showwarning() to the default implementation to make sure
-                    # that _showwarnmsg() calls _showwarnmsg_impl()
-                    self._module.showwarning = self._module._showwarning_orig
+                # Reset showwarning() to the default implementation to make sure
+                # that _showwarnmsg() calls _showwarnmsg_impl()
+                self._module.showwarning = self._module._showwarning_orig
             else:
                 log = None
         if self._filter is not None:
@@ -729,8 +729,8 @@ class catch_warnings(object):
                 self._module._warnings_context.set(self._saved_context)
             else:
                 self._module.filters = self._filters
-                self._module.showwarning = self._showwarning
                 self._module._showwarnmsg_impl = self._showwarnmsg_impl
+            self._module.showwarning = self._showwarning
             self._module._filters_mutated_lock_held()
 
 
index a6af5057cc8968882e458abf0d003db2dfe5ddf2..d86844c1a29a9a7b7a1ca5ba8f4e7c6e2b6f1934 100644 (file)
@@ -1,5 +1,6 @@
 from contextlib import contextmanager
 import linecache
+import logging
 import os
 import importlib
 import inspect
@@ -509,6 +510,47 @@ class FilterTests(BaseTest):
                     stderr = stderr.getvalue()
                     self.assertIn(error_msg, stderr)
 
+    def test_catchwarnings_with_showwarning(self):
+        # gh-146358: catch_warnings must override warnings.showwarning()
+        # if it's not the default implementation.
+
+        warns = []
+        def custom_showwarning(message, category, filename, lineno,
+                               file=None, line=None):
+            warns.append(message)
+
+        with self.module.catch_warnings():
+            self.module.resetwarnings()
+
+            with support.swap_attr(self.module, 'showwarning',
+                                   custom_showwarning):
+                with self.module.catch_warnings(record=True) as recorded:
+                    self.module.warn("recorded")
+                self.assertEqual(len(recorded), 1)
+                self.assertEqual(str(recorded[0].message), 'recorded')
+                self.assertIs(self.module.showwarning, custom_showwarning)
+
+                self.module.warn("custom")
+
+        self.assertEqual(len(warns), 1)
+        self.assertEqual(str(warns[0]), "custom")
+
+    def test_catchwarnings_logging(self):
+        # gh-146358: catch_warnings(record=True) must replace the
+        # showwarning() function set by logging.captureWarnings(True).
+
+        with self.module.catch_warnings():
+            self.module.resetwarnings()
+            logging.captureWarnings(True)
+
+            with self.module.catch_warnings(record=True) as recorded:
+                self.module.warn("recorded")
+            self.assertEqual(len(recorded), 1)
+            self.assertEqual(str(recorded[0].message), 'recorded')
+
+            logging.captureWarnings(False)
+
+
 class CFilterTests(FilterTests, unittest.TestCase):
     module = c_warnings