]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-73468: Add math.fma() function (#116667)
authorVictor Stinner <vstinner@python.org>
Sun, 17 Mar 2024 13:58:26 +0000 (14:58 +0100)
committerGitHub <noreply@github.com>
Sun, 17 Mar 2024 13:58:26 +0000 (13:58 +0000)
Added new math.fma() function, wrapping C99's ``fma()`` operation:
fused multiply-add function.

Co-authored-by: Mark Dickinson <mdickinson@enthought.com>
Doc/library/math.rst
Doc/whatsnew/3.13.rst
Lib/test/test_math.py
Misc/NEWS.d/next/Library/2024-03-12-17-53-14.gh-issue-73468.z4ZzvJ.rst [new file with mode: 0644]
Modules/clinic/mathmodule.c.h
Modules/mathmodule.c

index 93755be717e2ef430464d31c8aa4a5e60e9ef128..1475d26486de5fb190cb0a34a0c2b46799045684 100644 (file)
@@ -82,6 +82,22 @@ Number-theoretic and representation functions
    should return an :class:`~numbers.Integral` value.
 
 
+.. function:: fma(x, y, z)
+
+   Fused multiply-add operation. Return ``(x * y) + z``, computed as though with
+   infinite precision and range followed by a single round to the ``float``
+   format. This operation often provides better accuracy than the direct
+   expression ``(x * y) + z``.
+
+   This function follows the specification of the fusedMultiplyAdd operation
+   described in the IEEE 754 standard. The standard leaves one case
+   implementation-defined, namely the result of ``fma(0, inf, nan)``
+   and ``fma(inf, 0, nan)``. In these cases, ``math.fma`` returns a NaN,
+   and does not raise any exception.
+
+   .. versionadded:: 3.13
+
+
 .. function:: fmod(x, y)
 
    Return ``fmod(x, y)``, as defined by the platform C library. Note that the
index 03ae601890538002ac6e11f5c39d1fa2edee0dc2..78b5868b6549c6049f12fb00b74f728ed2f18aae 100644 (file)
@@ -383,6 +383,16 @@ marshal
   code objects which are incompatible between Python versions.
   (Contributed by Serhiy Storchaka in :gh:`113626`.)
 
+math
+----
+
+A new function :func:`~math.fma` for fused multiply-add operations has been
+added. This function computes ``x * y + z`` with only a single round, and so
+avoids any intermediate loss of precision. It wraps the ``fma()`` function
+provided by C99, and follows the specification of the IEEE 754
+"fusedMultiplyAdd" operation for special cases.
+(Contributed by Mark Dickinson and Victor Stinner in :gh:`73468`.)
+
 mmap
 ----
 
index ad382fc2b5989145a33c449f7ad5d1c1770a02d2..aaa3b16d33fb7dc52ba3adbaacafa1987d174421 100644 (file)
@@ -2613,6 +2613,244 @@ class IsCloseTests(unittest.TestCase):
         self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
 
 
+class FMATests(unittest.TestCase):
+    """ Tests for math.fma. """
+
+    def test_fma_nan_results(self):
+        # Selected representative values.
+        values = [
+            -math.inf, -1e300, -2.3, -1e-300, -0.0,
+            0.0, 1e-300, 2.3, 1e300, math.inf, math.nan
+        ]
+
+        # If any input is a NaN, the result should be a NaN, too.
+        for a, b in itertools.product(values, repeat=2):
+            self.assertIsNaN(math.fma(math.nan, a, b))
+            self.assertIsNaN(math.fma(a, math.nan, b))
+            self.assertIsNaN(math.fma(a, b, math.nan))
+
+    def test_fma_infinities(self):
+        # Cases involving infinite inputs or results.
+        positives = [1e-300, 2.3, 1e300, math.inf]
+        finites = [-1e300, -2.3, -1e-300, -0.0, 0.0, 1e-300, 2.3, 1e300]
+        non_nans = [-math.inf, -2.3, -0.0, 0.0, 2.3, math.inf]
+
+        # ValueError due to inf * 0 computation.
+        for c in non_nans:
+            for infinity in [math.inf, -math.inf]:
+                for zero in [0.0, -0.0]:
+                    with self.assertRaises(ValueError):
+                        math.fma(infinity, zero, c)
+                    with self.assertRaises(ValueError):
+                        math.fma(zero, infinity, c)
+
+        # ValueError when a*b and c both infinite of opposite signs.
+        for b in positives:
+            with self.assertRaises(ValueError):
+                math.fma(math.inf, b, -math.inf)
+            with self.assertRaises(ValueError):
+                math.fma(math.inf, -b, math.inf)
+            with self.assertRaises(ValueError):
+                math.fma(-math.inf, -b, -math.inf)
+            with self.assertRaises(ValueError):
+                math.fma(-math.inf, b, math.inf)
+            with self.assertRaises(ValueError):
+                math.fma(b, math.inf, -math.inf)
+            with self.assertRaises(ValueError):
+                math.fma(-b, math.inf, math.inf)
+            with self.assertRaises(ValueError):
+                math.fma(-b, -math.inf, -math.inf)
+            with self.assertRaises(ValueError):
+                math.fma(b, -math.inf, math.inf)
+
+        # Infinite result when a*b and c both infinite of the same sign.
+        for b in positives:
+            self.assertEqual(math.fma(math.inf, b, math.inf), math.inf)
+            self.assertEqual(math.fma(math.inf, -b, -math.inf), -math.inf)
+            self.assertEqual(math.fma(-math.inf, -b, math.inf), math.inf)
+            self.assertEqual(math.fma(-math.inf, b, -math.inf), -math.inf)
+            self.assertEqual(math.fma(b, math.inf, math.inf), math.inf)
+            self.assertEqual(math.fma(-b, math.inf, -math.inf), -math.inf)
+            self.assertEqual(math.fma(-b, -math.inf, math.inf), math.inf)
+            self.assertEqual(math.fma(b, -math.inf, -math.inf), -math.inf)
+
+        # Infinite result when a*b finite, c infinite.
+        for a, b in itertools.product(finites, finites):
+            self.assertEqual(math.fma(a, b, math.inf), math.inf)
+            self.assertEqual(math.fma(a, b, -math.inf), -math.inf)
+
+        # Infinite result when a*b infinite, c finite.
+        for b, c in itertools.product(positives, finites):
+            self.assertEqual(math.fma(math.inf, b, c), math.inf)
+            self.assertEqual(math.fma(-math.inf, b, c), -math.inf)
+            self.assertEqual(math.fma(-math.inf, -b, c), math.inf)
+            self.assertEqual(math.fma(math.inf, -b, c), -math.inf)
+
+            self.assertEqual(math.fma(b, math.inf, c), math.inf)
+            self.assertEqual(math.fma(b, -math.inf, c), -math.inf)
+            self.assertEqual(math.fma(-b, -math.inf, c), math.inf)
+            self.assertEqual(math.fma(-b, math.inf, c), -math.inf)
+
+    # gh-73468: On WASI and FreeBSD, libc fma() doesn't implement IEE 754-2008
+    # properly: it doesn't use the right sign when the result is zero.
+    @unittest.skipIf(support.is_wasi,
+                     "WASI fma() doesn't implement IEE 754-2008 properly")
+    @unittest.skipIf(sys.platform.startswith('freebsd'),
+                     "FreeBSD fma() doesn't implement IEE 754-2008 properly")
+    def test_fma_zero_result(self):
+        nonnegative_finites = [0.0, 1e-300, 2.3, 1e300]
+
+        # Zero results from exact zero inputs.
+        for b in nonnegative_finites:
+            self.assertIsPositiveZero(math.fma(0.0, b, 0.0))
+            self.assertIsPositiveZero(math.fma(0.0, b, -0.0))
+            self.assertIsNegativeZero(math.fma(0.0, -b, -0.0))
+            self.assertIsPositiveZero(math.fma(0.0, -b, 0.0))
+            self.assertIsPositiveZero(math.fma(-0.0, -b, 0.0))
+            self.assertIsPositiveZero(math.fma(-0.0, -b, -0.0))
+            self.assertIsNegativeZero(math.fma(-0.0, b, -0.0))
+            self.assertIsPositiveZero(math.fma(-0.0, b, 0.0))
+
+            self.assertIsPositiveZero(math.fma(b, 0.0, 0.0))
+            self.assertIsPositiveZero(math.fma(b, 0.0, -0.0))
+            self.assertIsNegativeZero(math.fma(-b, 0.0, -0.0))
+            self.assertIsPositiveZero(math.fma(-b, 0.0, 0.0))
+            self.assertIsPositiveZero(math.fma(-b, -0.0, 0.0))
+            self.assertIsPositiveZero(math.fma(-b, -0.0, -0.0))
+            self.assertIsNegativeZero(math.fma(b, -0.0, -0.0))
+            self.assertIsPositiveZero(math.fma(b, -0.0, 0.0))
+
+        # Exact zero result from nonzero inputs.
+        self.assertIsPositiveZero(math.fma(2.0, 2.0, -4.0))
+        self.assertIsPositiveZero(math.fma(2.0, -2.0, 4.0))
+        self.assertIsPositiveZero(math.fma(-2.0, -2.0, -4.0))
+        self.assertIsPositiveZero(math.fma(-2.0, 2.0, 4.0))
+
+        # Underflow to zero.
+        tiny = 1e-300
+        self.assertIsPositiveZero(math.fma(tiny, tiny, 0.0))
+        self.assertIsNegativeZero(math.fma(tiny, -tiny, 0.0))
+        self.assertIsPositiveZero(math.fma(-tiny, -tiny, 0.0))
+        self.assertIsNegativeZero(math.fma(-tiny, tiny, 0.0))
+        self.assertIsPositiveZero(math.fma(tiny, tiny, -0.0))
+        self.assertIsNegativeZero(math.fma(tiny, -tiny, -0.0))
+        self.assertIsPositiveZero(math.fma(-tiny, -tiny, -0.0))
+        self.assertIsNegativeZero(math.fma(-tiny, tiny, -0.0))
+
+        # Corner case where rounding the multiplication would
+        # give the wrong result.
+        x = float.fromhex('0x1p-500')
+        y = float.fromhex('0x1p-550')
+        z = float.fromhex('0x1p-1000')
+        self.assertIsNegativeZero(math.fma(x-y, x+y, -z))
+        self.assertIsPositiveZero(math.fma(y-x, x+y, z))
+        self.assertIsNegativeZero(math.fma(y-x, -(x+y), -z))
+        self.assertIsPositiveZero(math.fma(x-y, -(x+y), z))
+
+    def test_fma_overflow(self):
+        a = b = float.fromhex('0x1p512')
+        c = float.fromhex('0x1p1023')
+        # Overflow from multiplication.
+        with self.assertRaises(OverflowError):
+            math.fma(a, b, 0.0)
+        self.assertEqual(math.fma(a, b/2.0, 0.0), c)
+        # Overflow from the addition.
+        with self.assertRaises(OverflowError):
+            math.fma(a, b/2.0, c)
+        # No overflow, even though a*b overflows a float.
+        self.assertEqual(math.fma(a, b, -c), c)
+
+        # Extreme case: a * b is exactly at the overflow boundary, so the
+        # tiniest offset makes a difference between overflow and a finite
+        # result.
+        a = float.fromhex('0x1.ffffffc000000p+511')
+        b = float.fromhex('0x1.0000002000000p+512')
+        c = float.fromhex('0x0.0000000000001p-1022')
+        with self.assertRaises(OverflowError):
+            math.fma(a, b, 0.0)
+        with self.assertRaises(OverflowError):
+            math.fma(a, b, c)
+        self.assertEqual(math.fma(a, b, -c),
+                         float.fromhex('0x1.fffffffffffffp+1023'))
+
+        # Another extreme case: here a*b is about as large as possible subject
+        # to math.fma(a, b, c) being finite.
+        a = float.fromhex('0x1.ae565943785f9p+512')
+        b = float.fromhex('0x1.3094665de9db8p+512')
+        c = float.fromhex('0x1.fffffffffffffp+1023')
+        self.assertEqual(math.fma(a, b, -c), c)
+
+    def test_fma_single_round(self):
+        a = float.fromhex('0x1p-50')
+        self.assertEqual(math.fma(a - 1.0, a + 1.0, 1.0), a*a)
+
+    def test_random(self):
+        # A collection of randomly generated inputs for which the naive FMA
+        # (with two rounds) gives a different result from a singly-rounded FMA.
+
+        # tuples (a, b, c, expected)
+        test_values = [
+            ('0x1.694adde428b44p-1', '0x1.371b0d64caed7p-1',
+             '0x1.f347e7b8deab8p-4', '0x1.19f10da56c8adp-1'),
+            ('0x1.605401ccc6ad6p-2', '0x1.ce3a40bf56640p-2',
+             '0x1.96e3bf7bf2e20p-2', '0x1.1af6d8aa83101p-1'),
+            ('0x1.e5abd653a67d4p-2', '0x1.a2e400209b3e6p-1',
+             '0x1.a90051422ce13p-1', '0x1.37d68cc8c0fbbp+0'),
+            ('0x1.f94e8efd54700p-2', '0x1.123065c812cebp-1',
+             '0x1.458f86fb6ccd0p-1', '0x1.ccdcee26a3ff3p-1'),
+            ('0x1.bd926f1eedc96p-1', '0x1.eee9ca68c5740p-1',
+             '0x1.960c703eb3298p-2', '0x1.3cdcfb4fdb007p+0'),
+            ('0x1.27348350fbccdp-1', '0x1.3b073914a53f1p-1',
+             '0x1.e300da5c2b4cbp-1', '0x1.4c51e9a3c4e29p+0'),
+            ('0x1.2774f00b3497bp-1', '0x1.7038ec336bff0p-2',
+             '0x1.2f6f2ccc3576bp-1', '0x1.99ad9f9c2688bp-1'),
+            ('0x1.51d5a99300e5cp-1', '0x1.5cd74abd445a1p-1',
+             '0x1.8880ab0bbe530p-1', '0x1.3756f96b91129p+0'),
+            ('0x1.73cb965b821b8p-2', '0x1.218fd3d8d5371p-1',
+             '0x1.d1ea966a1f758p-2', '0x1.5217b8fd90119p-1'),
+            ('0x1.4aa98e890b046p-1', '0x1.954d85dff1041p-1',
+             '0x1.122b59317ebdfp-1', '0x1.0bf644b340cc5p+0'),
+            ('0x1.e28f29e44750fp-1', '0x1.4bcc4fdcd18fep-1',
+             '0x1.fd47f81298259p-1', '0x1.9b000afbc9995p+0'),
+            ('0x1.d2e850717fe78p-3', '0x1.1dd7531c303afp-1',
+             '0x1.e0869746a2fc2p-2', '0x1.316df6eb26439p-1'),
+            ('0x1.cf89c75ee6fbap-2', '0x1.b23decdc66825p-1',
+             '0x1.3d1fe76ac6168p-1', '0x1.00d8ea4c12abbp+0'),
+            ('0x1.3265ae6f05572p-2', '0x1.16d7ec285f7a2p-1',
+             '0x1.0b8405b3827fbp-1', '0x1.5ef33c118a001p-1'),
+            ('0x1.c4d1bf55ec1a5p-1', '0x1.bc59618459e12p-2',
+             '0x1.ce5b73dc1773dp-1', '0x1.496cf6164f99bp+0'),
+            ('0x1.d350026ac3946p-1', '0x1.9a234e149a68cp-2',
+             '0x1.f5467b1911fd6p-2', '0x1.b5cee3225caa5p-1'),
+        ]
+        for a_hex, b_hex, c_hex, expected_hex in test_values:
+            a = float.fromhex(a_hex)
+            b = float.fromhex(b_hex)
+            c = float.fromhex(c_hex)
+            expected = float.fromhex(expected_hex)
+            self.assertEqual(math.fma(a, b, c), expected)
+            self.assertEqual(math.fma(b, a, c), expected)
+
+    # Custom assertions.
+    def assertIsNaN(self, value):
+        self.assertTrue(
+            math.isnan(value),
+            msg="Expected a NaN, got {!r}".format(value)
+        )
+
+    def assertIsPositiveZero(self, value):
+        self.assertTrue(
+            value == 0 and math.copysign(1, value) > 0,
+            msg="Expected a positive zero, got {!r}".format(value)
+        )
+
+    def assertIsNegativeZero(self, value):
+        self.assertTrue(
+            value == 0 and math.copysign(1, value) < 0,
+            msg="Expected a negative zero, got {!r}".format(value)
+        )
+
+
 def load_tests(loader, tests, pattern):
     from doctest import DocFileSuite
     tests.addTest(DocFileSuite(os.path.join("mathdata", "ieee754.txt")))
diff --git a/Misc/NEWS.d/next/Library/2024-03-12-17-53-14.gh-issue-73468.z4ZzvJ.rst b/Misc/NEWS.d/next/Library/2024-03-12-17-53-14.gh-issue-73468.z4ZzvJ.rst
new file mode 100644 (file)
index 0000000..c91f4eb
--- /dev/null
@@ -0,0 +1,2 @@
+Added new :func:`math.fma` function, wrapping C99's ``fma()`` operation:
+fused multiply-add function. Patch by Mark Dickinson and Victor Stinner.
index ca14c03f16f706acc41158f1150d7f261ac42378..666b6b3790dae5390512793b6f042bb9599e10fe 100644 (file)
@@ -204,6 +204,67 @@ PyDoc_STRVAR(math_log10__doc__,
 #define MATH_LOG10_METHODDEF    \
     {"log10", (PyCFunction)math_log10, METH_O, math_log10__doc__},
 
+PyDoc_STRVAR(math_fma__doc__,
+"fma($module, x, y, z, /)\n"
+"--\n"
+"\n"
+"Fused multiply-add operation.\n"
+"\n"
+"Compute (x * y) + z with a single round.");
+
+#define MATH_FMA_METHODDEF    \
+    {"fma", _PyCFunction_CAST(math_fma), METH_FASTCALL, math_fma__doc__},
+
+static PyObject *
+math_fma_impl(PyObject *module, double x, double y, double z);
+
+static PyObject *
+math_fma(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
+{
+    PyObject *return_value = NULL;
+    double x;
+    double y;
+    double z;
+
+    if (!_PyArg_CheckPositional("fma", nargs, 3, 3)) {
+        goto exit;
+    }
+    if (PyFloat_CheckExact(args[0])) {
+        x = PyFloat_AS_DOUBLE(args[0]);
+    }
+    else
+    {
+        x = PyFloat_AsDouble(args[0]);
+        if (x == -1.0 && PyErr_Occurred()) {
+            goto exit;
+        }
+    }
+    if (PyFloat_CheckExact(args[1])) {
+        y = PyFloat_AS_DOUBLE(args[1]);
+    }
+    else
+    {
+        y = PyFloat_AsDouble(args[1]);
+        if (y == -1.0 && PyErr_Occurred()) {
+            goto exit;
+        }
+    }
+    if (PyFloat_CheckExact(args[2])) {
+        z = PyFloat_AS_DOUBLE(args[2]);
+    }
+    else
+    {
+        z = PyFloat_AsDouble(args[2]);
+        if (z == -1.0 && PyErr_Occurred()) {
+            goto exit;
+        }
+    }
+    return_value = math_fma_impl(module, x, y, z);
+
+exit:
+    return return_value;
+}
+
 PyDoc_STRVAR(math_fmod__doc__,
 "fmod($module, x, y, /)\n"
 "--\n"
@@ -950,4 +1011,4 @@ math_ulp(PyObject *module, PyObject *arg)
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=6b2eeaed8d8a76d5 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=9fe3f007f474e015 input=a9049054013a1b77]*/
index a877bfcd6afb6872334d69e2f3f09735f5485cab..8ba0431f4a47b79826ff248e244b884e2203e009 100644 (file)
@@ -2321,6 +2321,48 @@ math_log10(PyObject *module, PyObject *x)
 }
 
 
+/*[clinic input]
+math.fma
+
+    x: double
+    y: double
+    z: double
+    /
+
+Fused multiply-add operation.
+
+Compute (x * y) + z with a single round.
+[clinic start generated code]*/
+
+static PyObject *
+math_fma_impl(PyObject *module, double x, double y, double z)
+/*[clinic end generated code: output=4fc8626dbc278d17 input=e3ad1f4a4c89626e]*/
+{
+    double r = fma(x, y, z);
+
+    /* Fast path: if we got a finite result, we're done. */
+    if (Py_IS_FINITE(r)) {
+        return PyFloat_FromDouble(r);
+    }
+
+    /* Non-finite result. Raise an exception if appropriate, else return r. */
+    if (Py_IS_NAN(r)) {
+        if (!Py_IS_NAN(x) && !Py_IS_NAN(y) && !Py_IS_NAN(z)) {
+            /* NaN result from non-NaN inputs. */
+            PyErr_SetString(PyExc_ValueError, "invalid operation in fma");
+            return NULL;
+        }
+    }
+    else if (Py_IS_FINITE(x) && Py_IS_FINITE(y) && Py_IS_FINITE(z)) {
+        /* Infinite result from finite inputs. */
+        PyErr_SetString(PyExc_OverflowError, "overflow in fma");
+        return NULL;
+    }
+
+    return PyFloat_FromDouble(r);
+}
+
+
 /*[clinic input]
 math.fmod
 
@@ -4094,6 +4136,7 @@ static PyMethodDef math_methods[] = {
     {"fabs",            math_fabs,      METH_O,         math_fabs_doc},
     MATH_FACTORIAL_METHODDEF
     MATH_FLOOR_METHODDEF
+    MATH_FMA_METHODDEF
     MATH_FMOD_METHODDEF
     MATH_FREXP_METHODDEF
     MATH_FSUM_METHODDEF