]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Extend _sqrtprod() to cover the full range of inputs. Add tests. (GH-107855)
authorRaymond Hettinger <rhettinger@users.noreply.github.com>
Fri, 11 Aug 2023 16:19:19 +0000 (17:19 +0100)
committerGitHub <noreply@github.com>
Fri, 11 Aug 2023 16:19:19 +0000 (11:19 -0500)
Lib/statistics.py
Lib/test/test_statistics.py

index 93c44f1f13fab73514edbc947c8d1d46cb088235..a8036e9928c464472cc47c1145cfb767ff424354 100644 (file)
@@ -137,6 +137,7 @@ from decimal import Decimal
 from itertools import count, groupby, repeat
 from bisect import bisect_left, bisect_right
 from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
+from math import isfinite, isinf
 from functools import reduce
 from operator import itemgetter
 from collections import Counter, namedtuple, defaultdict
@@ -1005,14 +1006,25 @@ def _mean_stdev(data):
         return float(xbar), float(xbar) / float(ss)
 
 def _sqrtprod(x: float, y: float) -> float:
-    "Return sqrt(x * y) computed with high accuracy."
-    # Square root differential correction:
-    # https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0
+    "Return sqrt(x * y) computed with improved accuracy and without overflow/underflow."
     h = sqrt(x * y)
+    if not isfinite(h):
+        if isinf(h) and not isinf(x) and not isinf(y):
+            # Finite inputs overflowed, so scale down, and recompute.
+            scale = 2.0 ** -512  # sqrt(1 / sys.float_info.max)
+            return _sqrtprod(scale * x, scale * y) / scale
+        return h
     if not h:
-        return 0.0
-    x = sumprod((x, h), (y, -h))
-    return h + x / (2.0 * h)
+        if x and y:
+            # Non-zero inputs underflowed, so scale up, and recompute.
+            # Scale:  1 / sqrt(sys.float_info.min * sys.float_info.epsilon)
+            scale = 2.0 ** 537
+            return _sqrtprod(scale * x, scale * y) / scale
+        return h
+    # Improve accuracy with a differential correction.
+    # https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0
+    d = sumprod((x, h), (y, -h))
+    return h + d / (2.0 * h)
 
 
 # === Statistics for relations between two inputs ===
index f0fa6454b1f91a70d0548c49b4587c1029126e8c..aa2cf2b1edc5840a8405d4ae65a005406a4fdf38 100644 (file)
@@ -28,6 +28,12 @@ import statistics
 
 # === Helper functions and class ===
 
+# Test copied from Lib/test/test_math.py
+# detect evidence of double-rounding: fsum is not always correctly
+# rounded on machines that suffer from double rounding.
+x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer
+HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4)
+
 def sign(x):
     """Return -1.0 for negatives, including -0.0, otherwise +1.0."""
     return math.copysign(1, x)
@@ -2564,6 +2570,79 @@ class TestCorrelationAndCovariance(unittest.TestCase):
         self.assertAlmostEqual(statistics.correlation(x, y), 1)
         self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
 
+    def test_sqrtprod_helper_function_fundamentals(self):
+        # Verify that results are close to sqrt(x * y)
+        for i in range(100):
+            x = random.expovariate()
+            y = random.expovariate()
+            expected = math.sqrt(x * y)
+            actual = statistics._sqrtprod(x, y)
+            with self.subTest(x=x, y=y, expected=expected, actual=actual):
+                self.assertAlmostEqual(expected, actual)
+
+        x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
+        self.assertEqual(statistics._sqrtprod(x, y), target)
+        self.assertNotEqual(math.sqrt(x * y), target)
+
+        # Test that range extremes avoid underflow and overflow
+        smallest = sys.float_info.min * sys.float_info.epsilon
+        self.assertEqual(statistics._sqrtprod(smallest, smallest), smallest)
+        biggest = sys.float_info.max
+        self.assertEqual(statistics._sqrtprod(biggest, biggest), biggest)
+
+        # Check special values and the sign of the result
+        special_values = [0.0, -0.0, 1.0, -1.0, 4.0, -4.0,
+                          math.nan, -math.nan, math.inf, -math.inf]
+        for x, y in itertools.product(special_values, repeat=2):
+            try:
+                expected = math.sqrt(x * y)
+            except ValueError:
+                expected = 'ValueError'
+            try:
+                actual = statistics._sqrtprod(x, y)
+            except ValueError:
+                actual = 'ValueError'
+            with self.subTest(x=x, y=y, expected=expected, actual=actual):
+                if isinstance(expected, str) and expected == 'ValueError':
+                    self.assertEqual(actual, 'ValueError')
+                    continue
+                self.assertIsInstance(actual, float)
+                if math.isnan(expected):
+                    self.assertTrue(math.isnan(actual))
+                    continue
+                self.assertEqual(actual, expected)
+                self.assertEqual(sign(actual), sign(expected))
+
+    @requires_IEEE_754
+    @unittest.skipIf(HAVE_DOUBLE_ROUNDING,
+                     "accuracy not guaranteed on machines with double rounding")
+    @support.cpython_only    # Allow for a weaker sumprod() implmentation
+    def test_sqrtprod_helper_function_improved_accuracy(self):
+        # Test a known example where accuracy is improved
+        x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
+        self.assertEqual(statistics._sqrtprod(x, y), target)
+        self.assertNotEqual(math.sqrt(x * y), target)
+
+        def reference_value(x: float, y: float) -> float:
+            x = decimal.Decimal(x)
+            y = decimal.Decimal(y)
+            with decimal.localcontext() as ctx:
+                ctx.prec = 200
+                return float((x * y).sqrt())
+
+        # Verify that the new function with improved accuracy
+        # agrees with a reference value more often than old version.
+        new_agreements = 0
+        old_agreements = 0
+        for i in range(10_000):
+            x = random.expovariate()
+            y = random.expovariate()
+            new = statistics._sqrtprod(x, y)
+            old = math.sqrt(x * y)
+            ref = reference_value(x, y)
+            new_agreements += (new == ref)
+            old_agreements += (old == ref)
+        self.assertGreater(new_agreements, old_agreements)
 
     def test_correlation_spearman(self):
         # https://statistics.laerd.com/statistical-guides/spearmans-rank-order-correlation-statistical-guide-2.php