git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-112540: Support zero inputs in geometric_mean() (gh-112880)
author Raymond Hettinger <rhettinger@users.noreply.github.com>
Fri, 8 Dec 2023 18:05:56 +0000 (12:05 -0600)
committer GitHub <noreply@github.com>
Fri, 8 Dec 2023 18:05:56 +0000 (12:05 -0600)
Lib/statistics.py
Lib/test/test_statistics.py
Misc/NEWS.d/next/Library/2023-12-08-11-17-17.gh-issue-112540.Pm5egX.rst [new file with mode: 0644]

index 4da06889c6db4691c160a12929853fe0e2d8d749..83aaedb04515e02cf4f4296174da3c308c71d87c 100644 (file)
@@ -527,8 +527,10 @@ def fmean(data, weights=None):
 def geometric_mean(data):
     """Convert data to floats and compute the geometric mean.
 
-    Raises a StatisticsError if the input dataset is empty,
-    if it contains a zero, or if it contains a negative value.
+    Raises a StatisticsError if the input dataset is empty
+    or if it contains a negative value.
+
+    Returns zero if the product of inputs is zero.
 
     No special efforts are made to achieve exact results.
     (However, this may change in the future.)
@@ -536,11 +538,25 @@ def geometric_mean(data):
     >>> round(geometric_mean([54, 24, 36]), 9)
     36.0
     """
-    try:
-        return exp(fmean(map(log, data)))
-    except ValueError:
-        raise StatisticsError('geometric mean requires a non-empty dataset '
-                              'containing positive numbers') from None
+    n = 0
+    found_zero = False
+    def count_positive(iterable):
+        nonlocal n, found_zero
+        for n, x in enumerate(iterable, start=1):
+            if x > 0.0 or math.isnan(x):
+                yield x
+            elif x == 0.0:
+                found_zero = True
+            else:
+                raise StatisticsError('No negative inputs allowed', x)
+    total = fsum(map(log, count_positive(data)))
+    if not n:
+        raise StatisticsError('Must have a non-empty dataset')
+    if math.isnan(total):
+        return math.nan
+    if found_zero:
+        return math.nan if total == math.inf else 0.0
+    return exp(total / n)
 
 
 def harmonic_mean(data, weights=None):
index b24fc3c3d077fe4c99319d05539985149f1c6d4d..bf2c254c9ee7d9ab5da0c3831bd7ace10278e021 100644 (file)
@@ -2302,10 +2302,12 @@ class TestGeometricMean(unittest.TestCase):
         StatisticsError = statistics.StatisticsError
         with self.assertRaises(StatisticsError):
             geometric_mean([])                      # empty input
-        with self.assertRaises(StatisticsError):
-            geometric_mean([3.5, 0.0, 5.25])        # zero input
         with self.assertRaises(StatisticsError):
             geometric_mean([3.5, -4.0, 5.25])       # negative input
+        with self.assertRaises(StatisticsError):
+            geometric_mean([0.0, -4.0, 5.25])       # negative input with zero
+        with self.assertRaises(StatisticsError):
+            geometric_mean([3.5, -math.inf, 5.25])  # negative infinity
         with self.assertRaises(StatisticsError):
             geometric_mean(iter([]))                # empty iterator
         with self.assertRaises(TypeError):
@@ -2328,6 +2330,12 @@ class TestGeometricMean(unittest.TestCase):
         with self.assertRaises(ValueError):
             geometric_mean([Inf, -Inf])
 
+        # Cases with zero
+        self.assertEqual(geometric_mean([3, 0.0, 5]), 0.0)         # Any zero gives a zero
+        self.assertEqual(geometric_mean([3, -0.0, 5]), 0.0)        # Negative zero allowed
+        self.assertTrue(math.isnan(geometric_mean([0, NaN])))      # NaN beats zero
+        self.assertTrue(math.isnan(geometric_mean([0, Inf])))      # Because 0.0 * Inf -> NaN
+
     def test_mixed_int_and_float(self):
         # Regression test for b.p.o. issue #28327
         geometric_mean = statistics.geometric_mean
diff --git a/Misc/NEWS.d/next/Library/2023-12-08-11-17-17.gh-issue-112540.Pm5egX.rst b/Misc/NEWS.d/next/Library/2023-12-08-11-17-17.gh-issue-112540.Pm5egX.rst
new file mode 100644 (file)
index 0000000..263b13d
--- /dev/null
@@ -0,0 +1,2 @@
+The statistics.geometric_mean() function now returns zero for datasets
+containing a zero.  Formerly, it would raise an exception.