]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-105730: support more callables in ExceptionGroup.split() and subgroup() (#106035)
authorIrit Katriel <1055913+iritkatriel@users.noreply.github.com>
Fri, 23 Jun 2023 18:47:47 +0000 (19:47 +0100)
committerGitHub <noreply@github.com>
Fri, 23 Jun 2023 18:47:47 +0000 (19:47 +0100)
Doc/library/exceptions.rst
Lib/test/test_exception_group.py
Misc/NEWS.d/next/Core and Builtins/2023-06-23-16-51-02.gh-issue-105730.16haMe.rst [new file with mode: 0644]
Objects/exceptions.c

index 4c84e5f855431a6a855f58e56b6d7f8d920393c9..8e574b8334e445a613027abf357be35580f8af96 100644 (file)
@@ -912,10 +912,11 @@ their subgroups based on the types of the contained exceptions.
       Returns an exception group that contains only the exceptions from the
       current group that match *condition*, or ``None`` if the result is empty.
 
-      The condition can be either a function that accepts an exception and returns
-      true for those that should be in the subgroup, or it can be an exception type
-      or a tuple of exception types, which is used to check for a match using the
-      same check that is used in an ``except`` clause.
+      The condition can be an exception type or tuple of exception types, in which
+      case each exception is checked for a match using the same check that is used
+      in an ``except`` clause.  The condition can also be a callable (other than
+      a type object) that accepts an exception as its single argument and returns
+      true for the exceptions that should be in the subgroup.
 
       The nesting structure of the current exception is preserved in the result,
       as are the values of its :attr:`message`, :attr:`__traceback__`,
@@ -926,6 +927,9 @@ their subgroups based on the types of the contained exceptions.
       including the top-level and any nested exception groups. If the condition is
       true for such an exception group, it is included in the result in full.
 
+      .. versionadded:: 3.13
+         ``condition`` can be any callable which is not a type object.
+
    .. method:: split(condition)
 
       Like :meth:`subgroup`, but returns the pair ``(match, rest)`` where ``match``
index fa159a76ec1aff6604afe509cf3a8922d1fd803e..2658e027ff3e2b34042f4bbab24117329e54d8dd 100644 (file)
@@ -294,6 +294,15 @@ class ExceptionGroupTestBase(unittest.TestCase):
             self.assertEqual(type(exc), type(template))
             self.assertEqual(exc.args, template.args)
 
+class Predicate:
+    def __init__(self, func):
+        self.func = func
+
+    def __call__(self, e):
+        return self.func(e)
+
+    def method(self, e):
+        return self.func(e)
 
 class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
     def setUp(self):
@@ -301,10 +310,15 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
         self.eg_template = [ValueError(1), TypeError(int), ValueError(2)]
 
     def test_basics_subgroup_split__bad_arg_type(self):
+        class C:
+            pass
+
         bad_args = ["bad arg",
+                    C,
                     OSError('instance not type'),
                     [OSError, TypeError],
-                    (OSError, 42)]
+                    (OSError, 42),
+                   ]
         for arg in bad_args:
             with self.assertRaises(TypeError):
                 self.eg.subgroup(arg)
@@ -336,10 +350,14 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
                 self.assertMatchesTemplate(subeg, ExceptionGroup, template)
 
     def test_basics_subgroup_by_predicate__passthrough(self):
-        self.assertIs(self.eg, self.eg.subgroup(lambda e: True))
+        f = lambda e: True
+        for callable in [f, Predicate(f), Predicate(f).method]:
+            self.assertIs(self.eg, self.eg.subgroup(callable))
 
     def test_basics_subgroup_by_predicate__no_match(self):
-        self.assertIsNone(self.eg.subgroup(lambda e: False))
+        f = lambda e: False
+        for callable in [f, Predicate(f), Predicate(f).method]:
+            self.assertIsNone(self.eg.subgroup(callable))
 
     def test_basics_subgroup_by_predicate__match(self):
         eg = self.eg
@@ -350,9 +368,12 @@ class ExceptionGroupSubgroupTests(ExceptionGroupTestBase):
             ((ValueError, TypeError), self.eg_template)]
 
         for match_type, template in testcases:
-            subeg = eg.subgroup(lambda e: isinstance(e, match_type))
-            self.assertEqual(subeg.message, eg.message)
-            self.assertMatchesTemplate(subeg, ExceptionGroup, template)
+            f = lambda e: isinstance(e, match_type)
+            for callable in [f, Predicate(f), Predicate(f).method]:
+                with self.subTest(callable=callable):
+                    subeg = eg.subgroup(f)
+                    self.assertEqual(subeg.message, eg.message)
+                    self.assertMatchesTemplate(subeg, ExceptionGroup, template)
 
 
 class ExceptionGroupSplitTests(ExceptionGroupTestBase):
@@ -399,14 +420,18 @@ class ExceptionGroupSplitTests(ExceptionGroupTestBase):
                 self.assertIsNone(rest)
 
     def test_basics_split_by_predicate__passthrough(self):
-        match, rest = self.eg.split(lambda e: True)
-        self.assertMatchesTemplate(match, ExceptionGroup, self.eg_template)
-        self.assertIsNone(rest)
+        f = lambda e: True
+        for callable in [f, Predicate(f), Predicate(f).method]:
+            match, rest = self.eg.split(callable)
+            self.assertMatchesTemplate(match, ExceptionGroup, self.eg_template)
+            self.assertIsNone(rest)
 
     def test_basics_split_by_predicate__no_match(self):
-        match, rest = self.eg.split(lambda e: False)
-        self.assertIsNone(match)
-        self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template)
+        f = lambda e: False
+        for callable in [f, Predicate(f), Predicate(f).method]:
+            match, rest = self.eg.split(callable)
+            self.assertIsNone(match)
+            self.assertMatchesTemplate(rest, ExceptionGroup, self.eg_template)
 
     def test_basics_split_by_predicate__match(self):
         eg = self.eg
@@ -420,14 +445,16 @@ class ExceptionGroupSplitTests(ExceptionGroupTestBase):
         ]
 
         for match_type, match_template, rest_template in testcases:
-            match, rest = eg.split(lambda e: isinstance(e, match_type))
-            self.assertEqual(match.message, eg.message)
-            self.assertMatchesTemplate(
-                match, ExceptionGroup, match_template)
-            if rest_template is not None:
-                self.assertEqual(rest.message, eg.message)
+            f = lambda e: isinstance(e, match_type)
+            for callable in [f, Predicate(f), Predicate(f).method]:
+                match, rest = eg.split(callable)
+                self.assertEqual(match.message, eg.message)
                 self.assertMatchesTemplate(
-                    rest, ExceptionGroup, rest_template)
+                    match, ExceptionGroup, match_template)
+                if rest_template is not None:
+                    self.assertEqual(rest.message, eg.message)
+                    self.assertMatchesTemplate(
+                        rest, ExceptionGroup, rest_template)
 
 
 class DeepRecursionInSplitAndSubgroup(unittest.TestCase):
diff --git a/Misc/NEWS.d/next/Core and Builtins/2023-06-23-16-51-02.gh-issue-105730.16haMe.rst b/Misc/NEWS.d/next/Core and Builtins/2023-06-23-16-51-02.gh-issue-105730.16haMe.rst
new file mode 100644 (file)
index 0000000..fa70ee0
--- /dev/null
@@ -0,0 +1,2 @@
+Allow any callable other than type objects as the condition predicate in
+:meth:`BaseExceptionGroup.split` and :meth:`BaseExceptionGroup.subgroup`.
index 04ea22c2902df7df9d399177398f138a11a1fa72..f27e6f6c1431a01cb8df01c855577311e0e0d190 100644 (file)
@@ -992,7 +992,7 @@ get_matcher_type(PyObject *value,
 {
     assert(value);
 
-    if (PyFunction_Check(value)) {
+    if (PyCallable_Check(value) && !PyType_Check(value)) {
         *type = EXCEPTION_GROUP_MATCH_BY_PREDICATE;
         return 0;
     }
@@ -1016,7 +1016,7 @@ get_matcher_type(PyObject *value,
 error:
     PyErr_SetString(
         PyExc_TypeError,
-        "expected a function, exception type or tuple of exception types");
+        "expected an exception type, a tuple of exception types, or a callable (other than a class)");
     return -1;
 }
 
@@ -1032,7 +1032,7 @@ exceptiongroup_split_check_match(PyObject *exc,
         return PyErr_GivenExceptionMatches(exc, matcher_value);
     }
     case EXCEPTION_GROUP_MATCH_BY_PREDICATE: {
-        assert(PyFunction_Check(matcher_value));
+        assert(PyCallable_Check(matcher_value) && !PyType_Check(matcher_value));
         PyObject *exc_matches = PyObject_CallOneArg(matcher_value, exc);
         if (exc_matches == NULL) {
             return -1;