]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-101410: support custom messages for domain errors in the math module (#124299)
authorSergey B Kirpichev <skirpichev@gmail.com>
Thu, 23 Jan 2025 13:55:25 +0000 (16:55 +0300)
committerGitHub <noreply@github.com>
Thu, 23 Jan 2025 13:55:25 +0000 (13:55 +0000)
This adds basic support to override default messages for domain errors
in the math_1() helper.  The sqrt(), atanh(), log2(), log10() and log()
functions were modified as examples.  New macro supports gradual
changing of error messages in other 1-arg functions.

Co-authored-by: CharlieZhao <zhaoyu_hit@qq.com>
Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Lib/test/test_math.py
Misc/NEWS.d/next/Library/2023-02-01-16-41-31.gh-issue-101410.Dt2aQE.rst [new file with mode: 0644]
Modules/mathmodule.c

index 6976a5d85da01994ae3a95afd2632b2d6ae24b25..2c57d288bc03ff6765176ab6aec07947a7c1dee1 100644 (file)
@@ -2503,6 +2503,46 @@ class MathTests(unittest.TestCase):
         self.assertRaises(TypeError, math.atan2, 1.0)
         self.assertRaises(TypeError, math.atan2, 1.0, 2.0, 3.0)
 
+    def test_exception_messages(self):
+        x = -1.1
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a nonnegative input, got {x}"):
+            math.sqrt(x)
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a positive input, got {x}"):
+            math.log(x)
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a positive input, got {x}"):
+            math.log(123, x)
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a positive input, got {x}"):
+            math.log(x, 123)
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a positive input, got {x}"):
+            math.log2(x)
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a positive input, got {x}"):
+            math.log10(x)
+        x = decimal.Decimal('-1.1')
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a positive input, got {x}"):
+            math.log(x)
+        x = fractions.Fraction(1, 10**400)
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a positive input, got {float(x)}"):
+            math.log(x)
+        x = -123
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a positive input, got {x}"):
+            math.log(x)
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a float or nonnegative integer, got {x}"):
+            math.gamma(x)
+        x = 1.0
+        with self.assertRaisesRegex(ValueError,
+                                    f"expected a number between -1 and 1, got {x}"):
+            math.atanh(x)
+
     # Custom assertions.
 
     def assertIsNaN(self, value):
diff --git a/Misc/NEWS.d/next/Library/2023-02-01-16-41-31.gh-issue-101410.Dt2aQE.rst b/Misc/NEWS.d/next/Library/2023-02-01-16-41-31.gh-issue-101410.Dt2aQE.rst
new file mode 100644 (file)
index 0000000..3149368
--- /dev/null
@@ -0,0 +1,3 @@
+Support custom messages for domain errors in the :mod:`math` module
+(:func:`math.sqrt`, :func:`math.log` and :func:`math.atanh` were modified as
+examples).  Patch by Charlie Zhao and Sergey B Kirpichev.
index 29638114dd94a9108053c4cf3c8b629a5ddf3edf..b4c15a143f9838e091ebd9b23037e8009494646d 100644 (file)
@@ -858,12 +858,15 @@ math_lcm_impl(PyObject *module, PyObject * const *args,
  * true (1), but may return false (0) without setting up an exception.
  */
 static int
-is_error(double x)
+is_error(double x, int raise_edom)
 {
     int result = 1;     /* presumption of guilt */
     assert(errno);      /* non-zero errno is a precondition for calling */
-    if (errno == EDOM)
-        PyErr_SetString(PyExc_ValueError, "math domain error");
+    if (errno == EDOM) {
+        if (raise_edom) {
+            PyErr_SetString(PyExc_ValueError, "math domain error");
+        }
+    }
 
     else if (errno == ERANGE) {
         /* ANSI C generally requires libm functions to set ERANGE
@@ -928,7 +931,8 @@ is_error(double x)
 */
 
 static PyObject *
-math_1(PyObject *arg, double (*func) (double), int can_overflow)
+math_1(PyObject *arg, double (*func) (double), int can_overflow,
+       const char *err_msg)
 {
     double x, r;
     x = PyFloat_AsDouble(arg);
@@ -936,25 +940,34 @@ math_1(PyObject *arg, double (*func) (double), int can_overflow)
         return NULL;
     errno = 0;
     r = (*func)(x);
-    if (isnan(r) && !isnan(x)) {
-        PyErr_SetString(PyExc_ValueError,
-                        "math domain error"); /* invalid arg */
-        return NULL;
-    }
+    if (isnan(r) && !isnan(x))
+        goto domain_err; /* domain error */
     if (isinf(r) && isfinite(x)) {
         if (can_overflow)
             PyErr_SetString(PyExc_OverflowError,
                             "math range error"); /* overflow */
         else
-            PyErr_SetString(PyExc_ValueError,
-                            "math domain error"); /* singularity */
+            goto domain_err; /* singularity */
         return NULL;
     }
-    if (isfinite(r) && errno && is_error(r))
+    if (isfinite(r) && errno && is_error(r, 1))
         /* this branch unnecessary on most platforms */
         return NULL;
 
     return PyFloat_FromDouble(r);
+
+domain_err:
+    if (err_msg) {
+        char *buf = PyOS_double_to_string(x, 'r', 0, Py_DTSF_ADD_DOT_0, NULL);
+        if (buf) {
+            PyErr_Format(PyExc_ValueError, err_msg, buf);
+            PyMem_Free(buf);
+        }
+       }
+       else {
+               PyErr_SetString(PyExc_ValueError, "math domain error");
+       }
+    return NULL;
 }
 
 /* variant of math_1, to be used when the function being wrapped is known to
@@ -962,7 +975,7 @@ math_1(PyObject *arg, double (*func) (double), int can_overflow)
    errno = ERANGE for overflow). */
 
 static PyObject *
-math_1a(PyObject *arg, double (*func) (double))
+math_1a(PyObject *arg, double (*func) (double), const char *err_msg)
 {
     double x, r;
     x = PyFloat_AsDouble(arg);
@@ -970,8 +983,17 @@ math_1a(PyObject *arg, double (*func) (double))
         return NULL;
     errno = 0;
     r = (*func)(x);
-    if (errno && is_error(r))
+    if (errno && is_error(r, err_msg ? 0 : 1)) {
+        if (err_msg && errno == EDOM) {
+            assert(!PyErr_Occurred()); /* exception is not set by is_error() */
+            char *buf = PyOS_double_to_string(x, 'r', 0, Py_DTSF_ADD_DOT_0, NULL);
+            if (buf) {
+                PyErr_Format(PyExc_ValueError, err_msg, buf);
+                PyMem_Free(buf);
+            }
+        }
         return NULL;
+    }
     return PyFloat_FromDouble(r);
 }
 
@@ -1031,7 +1053,7 @@ math_2(PyObject *const *args, Py_ssize_t nargs,
         else
             errno = 0;
     }
-    if (errno && is_error(r))
+    if (errno && is_error(r, 1))
         return NULL;
     else
         return PyFloat_FromDouble(r);
@@ -1039,13 +1061,25 @@ math_2(PyObject *const *args, Py_ssize_t nargs,
 
 #define FUNC1(funcname, func, can_overflow, docstring)                  \
     static PyObject * math_##funcname(PyObject *self, PyObject *args) { \
-        return math_1(args, func, can_overflow);                            \
+        return math_1(args, func, can_overflow, NULL);                  \
+    }\
+    PyDoc_STRVAR(math_##funcname##_doc, docstring);
+
+#define FUNC1D(funcname, func, can_overflow, docstring, err_msg)        \
+    static PyObject * math_##funcname(PyObject *self, PyObject *args) { \
+        return math_1(args, func, can_overflow, err_msg);               \
     }\
     PyDoc_STRVAR(math_##funcname##_doc, docstring);
 
 #define FUNC1A(funcname, func, docstring)                               \
     static PyObject * math_##funcname(PyObject *self, PyObject *args) { \
-        return math_1a(args, func);                                     \
+        return math_1a(args, func, NULL);                               \
+    }\
+    PyDoc_STRVAR(math_##funcname##_doc, docstring);
+
+#define FUNC1AD(funcname, func, docstring, err_msg)                     \
+    static PyObject * math_##funcname(PyObject *self, PyObject *args) { \
+        return math_1a(args, func, err_msg);                            \
     }\
     PyDoc_STRVAR(math_##funcname##_doc, docstring);
 
@@ -1077,9 +1111,10 @@ FUNC2(atan2, atan2,
       "atan2($module, y, x, /)\n--\n\n"
       "Return the arc tangent (measured in radians) of y/x.\n\n"
       "Unlike atan(y/x), the signs of both x and y are considered.")
-FUNC1(atanh, atanh, 0,
+FUNC1D(atanh, atanh, 0,
       "atanh($module, x, /)\n--\n\n"
-      "Return the inverse hyperbolic tangent of x.")
+      "Return the inverse hyperbolic tangent of x.",
+      "expected a number between -1 and 1, got %s")
 FUNC1(cbrt, cbrt, 0,
       "cbrt($module, x, /)\n--\n\n"
       "Return the cube root of x.")
@@ -1190,9 +1225,10 @@ math_floor(PyObject *module, PyObject *number)
     return PyLong_FromDouble(floor(x));
 }
 
-FUNC1A(gamma, m_tgamma,
+FUNC1AD(gamma, m_tgamma,
       "gamma($module, x, /)\n--\n\n"
-      "Gamma function at x.")
+      "Gamma function at x.",
+      "expected a float or nonnegative integer, got %s")
 FUNC1A(lgamma, m_lgamma,
       "lgamma($module, x, /)\n--\n\n"
       "Natural logarithm of absolute value of Gamma function at x.")
@@ -1212,9 +1248,10 @@ FUNC1(sin, sin, 0,
 FUNC1(sinh, sinh, 1,
       "sinh($module, x, /)\n--\n\n"
       "Return the hyperbolic sine of x.")
-FUNC1(sqrt, sqrt, 0,
+FUNC1D(sqrt, sqrt, 0,
       "sqrt($module, x, /)\n--\n\n"
-      "Return the square root of x.")
+      "Return the square root of x.",
+      "expected a nonnegative input, got %s")
 FUNC1(tan, tan, 0,
       "tan($module, x, /)\n--\n\n"
       "Return the tangent of x (measured in radians).")
@@ -2141,7 +2178,7 @@ math_ldexp_impl(PyObject *module, double x, PyObject *i)
             errno = ERANGE;
     }
 
-    if (errno && is_error(r))
+    if (errno && is_error(r, 1))
         return NULL;
     return PyFloat_FromDouble(r);
 }
@@ -2195,8 +2232,8 @@ loghelper(PyObject* arg, double (*func)(double))
 
         /* Negative or zero inputs give a ValueError. */
         if (!_PyLong_IsPositive((PyLongObject *)arg)) {
-            PyErr_SetString(PyExc_ValueError,
-                            "math domain error");
+            PyErr_Format(PyExc_ValueError,
+                         "expected a positive input, got %S", arg);
             return NULL;
         }
 
@@ -2220,7 +2257,7 @@ loghelper(PyObject* arg, double (*func)(double))
     }
 
     /* Else let libm handle it by itself. */
-    return math_1(arg, func, 0);
+    return math_1(arg, func, 0, "expected a positive input, got %s");
 }
 
 
@@ -2369,7 +2406,7 @@ math_fmod_impl(PyObject *module, double x, double y)
         else
             errno = 0;
     }
-    if (errno && is_error(r))
+    if (errno && is_error(r, 1))
         return NULL;
     else
         return PyFloat_FromDouble(r);
@@ -3010,7 +3047,7 @@ math_pow_impl(PyObject *module, double x, double y)
         }
     }
 
-    if (errno && is_error(r))
+    if (errno && is_error(r, 1))
         return NULL;
     else
         return PyFloat_FromDouble(r);