]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-100485: Tweaks to sumprod() (GH-100857)
authorRaymond Hettinger <rhettinger@users.noreply.github.com>
Sun, 8 Jan 2023 19:38:24 +0000 (13:38 -0600)
committerGitHub <noreply@github.com>
Sun, 8 Jan 2023 19:38:24 +0000 (13:38 -0600)
Doc/whatsnew/3.12.rst
Lib/test/test_math.py
Modules/mathmodule.c

index 2f50ece4dab3fb147f3e4a312083103671cbe8b5..b882bb607f911ce6490797bb297c2a8fe6567c47 100644 (file)
@@ -262,6 +262,12 @@ dis
   :data:`~dis.hasarg` collection instead.
   (Contributed by Irit Katriel in :gh:`94216`.)
 
+math
+----
+
+* Added :func:`math.sumprod` for computing a sum of products.
+  (Contributed by Raymond Hettinger in :gh:`100485`.)
+
 os
 --
 
index 65fe169671eae2d3ab6f216b4b2147aaa9a6a8d2..b8ac8f32055d33ea678fd74c2ced8c93b3f2c728 100644 (file)
@@ -1294,6 +1294,7 @@ class MathTests(unittest.TestCase):
         self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0)
         self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0)
 
+    @support.requires_resource('cpu')
     def test_sumprod_stress(self):
         sumprod = math.sumprod
         product = itertools.product
index 9545ad2a99eb43eec3f64c0dbb90f443462fd3f3..11e815c6f17e1d6bcffbee09b3afd88817167a49 100644 (file)
@@ -2832,7 +2832,7 @@ long_add_would_overflow(long a, long b)
 }
 
 /*
-Double length extended precision floating point arithmetic
+Double and triple length extended precision floating point arithmetic
 based on ideas from three sources:
 
   Improved Kahan–Babuška algorithm by Arnold Neumaier
@@ -2845,22 +2845,22 @@ based on ideas from three sources:
   Ultimately Fast Accurate Summation by Siegfried M. Rump
   https://www.tuhh.de/ti3/paper/rump/Ru08b.pdf
 
-The double length routines allow for quite a bit of instruction
-level parallelism.  On a 3.22 Ghz Apple M1 Max, the incremental
-cost of increasing the input vector size by one is 6.0 nsec.
+Double length functions:
+* dl_split() exact split of a C double into two half precision components.
+* dl_mul() exact multiplication of two C doubles.
 
-dl_zero() returns an extended precision zero
-dl_split() exactly splits a double into two half precision components.
-dl_add() performs compensated summation to keep a running total.
-dl_mul() implements lossless multiplication of doubles.
-dl_fma() implements an extended precision fused-multiply-add.
-dl_to_d() converts from extended precision to double precision.
+Triple length functions and constant:
+* tl_zero is a triple length zero for starting or resetting an accumulation.
+* tl_add() compensated addition of a C double to a triple length number.
+* tl_fma() performs a triple length fused-multiply-add.
+* tl_to_d() converts from triple length number back to a C double.
 
 */
 
 typedef struct{ double hi; double lo; } DoubleLength;
+typedef struct{ double hi; double lo; double tiny; } TripleLength;
 
-static const DoubleLength dl_zero = {0.0, 0.0};
+static const TripleLength tl_zero = {0.0, 0.0, 0.0};
 
 static inline DoubleLength
 twosum(double a, double b)
@@ -2874,11 +2874,20 @@ twosum(double a, double b)
     return  (DoubleLength) {s, t};
 }
 
-static inline DoubleLength
-dl_add(DoubleLength total, double x)
+static inline TripleLength
+tl_add(TripleLength total, double x)
 {
-    DoubleLength s = twosum(total.hi, x);
-    return (DoubleLength) {s.hi, total.lo + s.lo};
+    /* Input:       x     total.hi   total.lo    total.tiny
+                   |--- twosum ---|
+                    s.hi      s.lo
+                             |--- twosum ---|
+                              t.hi      t.lo
+                                       |--- single sum ---|
+       Output:      s.hi     t.hi       tiny
+     */
+    DoubleLength s = twosum(x, total.hi);
+    DoubleLength t = twosum(s.lo, total.lo);
+    return (TripleLength) {s.hi, t.hi, t.lo + total.tiny};
 }
 
 static inline DoubleLength
@@ -2902,18 +2911,18 @@ dl_mul(double x, double y)
     return (DoubleLength) {z, zz};
 }
 
-static inline DoubleLength
-dl_fma(DoubleLength total, double p, double q)
+static inline TripleLength
+tl_fma(TripleLength total, double p, double q)
 {
     DoubleLength product = dl_mul(p, q);
-    total = dl_add(total, product.hi);
-    return  dl_add(total, product.lo);
+    total = tl_add(total, product.hi);
+    return  tl_add(total, product.lo);
 }
 
 static inline double
-dl_to_d(DoubleLength total)
+tl_to_d(TripleLength total)
 {
-    return total.hi + total.lo;
+    return total.tiny + total.lo + total.hi;
 }
 
 /*[clinic input]
@@ -2944,7 +2953,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
     bool int_path_enabled = true, int_total_in_use = false;
     bool flt_path_enabled = true, flt_total_in_use = false;
     long int_total = 0;
-    DoubleLength flt_total = dl_zero;
+    TripleLength flt_total = tl_zero;
 
     p_it = PyObject_GetIter(p);
     if (p_it == NULL) {
@@ -3079,7 +3088,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
                 } else {
                     goto finalize_flt_path;
                 }
-                DoubleLength new_flt_total = dl_fma(flt_total, flt_p, flt_q);
+                TripleLength new_flt_total = tl_fma(flt_total, flt_p, flt_q);
                 if (isfinite(new_flt_total.hi)) {
                     flt_total = new_flt_total;
                     flt_total_in_use = true;
@@ -3093,7 +3102,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
             // We're finished, overflowed, have a non-float, or got a non-finite value
             flt_path_enabled = false;
             if (flt_total_in_use) {
-                term_i = PyFloat_FromDouble(dl_to_d(flt_total));
+                term_i = PyFloat_FromDouble(tl_to_d(flt_total));
                 if (term_i == NULL) {
                     goto err_exit;
                 }
@@ -3104,7 +3113,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
                 Py_SETREF(total, new_total);
                 new_total = NULL;
                 Py_CLEAR(term_i);
-                flt_total = dl_zero;
+                flt_total = tl_zero;
                 flt_total_in_use = false;
             }
         }