]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-46257: Convert statistics._ss() to a single pass algorithm (GH-30403)
authorRaymond Hettinger <rhettinger@users.noreply.github.com>
Wed, 5 Jan 2022 15:39:10 +0000 (07:39 -0800)
committerGitHub <noreply@github.com>
Wed, 5 Jan 2022 15:39:10 +0000 (09:39 -0600)
Lib/statistics.py
Misc/NEWS.d/next/Library/2022-01-04-11-04-20.bpo-46257._o2ADe.rst [new file with mode: 0644]

index c104571d39053d03bd0214b76760afb61c9dbcb3..eef2453bc7394b0d9fab948ff833aac1d1fa8d60 100644 (file)
@@ -138,7 +138,7 @@ from itertools import groupby, repeat
 from bisect import bisect_left, bisect_right
 from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
 from operator import mul
-from collections import Counter, namedtuple
+from collections import Counter, namedtuple, defaultdict
 
 _SQRT2 = sqrt(2.0)
 
@@ -202,6 +202,43 @@ def _sum(data):
     return (T, total, count)
 
 
+def _ss(data, c=None):
+    """Return sum of square deviations of sequence data.
+
+    If ``c`` is None, the mean is calculated in one pass, and the deviations
+    from the mean are calculated in a second pass. Otherwise, deviations are
+    calculated from ``c`` as given. Use the second case with care, as it can
+    lead to garbage results.
+    """
+    if c is not None:
+        T, total, count = _sum((d := x - c) * d for x in data)
+        return (T, total, count)
+    count = 0
+    sx_partials = defaultdict(int)
+    sxx_partials = defaultdict(int)
+    T = int
+    for typ, values in groupby(data, type):
+        T = _coerce(T, typ)  # or raise TypeError
+        for n, d in map(_exact_ratio, values):
+            count += 1
+            sx_partials[d] += n
+            sxx_partials[d] += n * n
+    if not count:
+        total = Fraction(0)
+    elif None in sx_partials:
+        # The sum will be a NAN or INF. We can ignore all the finite
+        # partials, and just look at this special one.
+        total = sx_partials[None]
+        assert not _isfinite(total)
+    else:
+        sx = sum(Fraction(n, d) for d, n in sx_partials.items())
+        sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items())
+        # This formula has poor numeric properties for floats,
+        # but with fractions it is exact.
+        total = (count * sxx - sx * sx) / count
+    return (T, total, count)
+
+
 def _isfinite(x):
     try:
         return x.is_finite()  # Likely a Decimal.
@@ -399,13 +436,9 @@ def mean(data):
 
     If ``data`` is empty, StatisticsError will be raised.
     """
-    if iter(data) is data:
-        data = list(data)
-    n = len(data)
+    T, total, n = _sum(data)
     if n < 1:
         raise StatisticsError('mean requires at least one data point')
-    T, total, count = _sum(data)
-    assert count == n
     return _convert(total / n, T)
 
 
@@ -776,41 +809,6 @@ def quantiles(data, *, n=4, method='exclusive'):
 
 # See http://mathworld.wolfram.com/Variance.html
 #     http://mathworld.wolfram.com/SampleVariance.html
-#     http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-#
-# Under no circumstances use the so-called "computational formula for
-# variance", as that is only suitable for hand calculations with a small
-# amount of low-precision data. It has terrible numeric properties.
-#
-# See a comparison of three computational methods here:
-# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
-
-def _ss(data, c=None):
-    """Return sum of square deviations of sequence data.
-
-    If ``c`` is None, the mean is calculated in one pass, and the deviations
-    from the mean are calculated in a second pass. Otherwise, deviations are
-    calculated from ``c`` as given. Use the second case with care, as it can
-    lead to garbage results.
-    """
-    if c is not None:
-        T, total, count = _sum((d := x - c) * d for x in data)
-        return (T, total)
-    T, total, count = _sum(data)
-    mean_n, mean_d = (total / count).as_integer_ratio()
-    partials = Counter()
-    for n, d in map(_exact_ratio, data):
-        diff_n = n * mean_d - d * mean_n
-        diff_d = d * mean_d
-        partials[diff_d * diff_d] += diff_n * diff_n
-    if None in partials:
-        # The sum will be a NAN or INF. We can ignore all the finite
-        # partials, and just look at this special one.
-        total = partials[None]
-        assert not _isfinite(total)
-    else:
-        total = sum(Fraction(n, d) for d, n in partials.items())
-    return (T, total)
 
 
 def variance(data, xbar=None):
@@ -851,12 +849,9 @@ def variance(data, xbar=None):
     Fraction(67, 108)
 
     """
-    if iter(data) is data:
-        data = list(data)
-    n = len(data)
+    T, ss, n = _ss(data, xbar)
     if n < 2:
         raise StatisticsError('variance requires at least two data points')
-    T, ss = _ss(data, xbar)
     return _convert(ss / (n - 1), T)
 
 
@@ -895,12 +890,9 @@ def pvariance(data, mu=None):
     Fraction(13, 72)
 
     """
-    if iter(data) is data:
-        data = list(data)
-    n = len(data)
+    T, ss, n = _ss(data, mu)
     if n < 1:
         raise StatisticsError('pvariance requires at least one data point')
-    T, ss = _ss(data, mu)
     return _convert(ss / n, T)
 
 
@@ -913,12 +905,9 @@ def stdev(data, xbar=None):
     1.0810874155219827
 
     """
-    if iter(data) is data:
-        data = list(data)
-    n = len(data)
+    T, ss, n = _ss(data, xbar)
     if n < 2:
         raise StatisticsError('stdev requires at least two data points')
-    T, ss = _ss(data, xbar)
     mss = ss / (n - 1)
     if issubclass(T, Decimal):
         return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
@@ -934,12 +923,9 @@ def pstdev(data, mu=None):
     0.986893273527251
 
     """
-    if iter(data) is data:
-        data = list(data)
-    n = len(data)
+    T, ss, n = _ss(data, mu)
     if n < 1:
         raise StatisticsError('pstdev requires at least one data point')
-    T, ss = _ss(data, mu)
     mss = ss / n
     if issubclass(T, Decimal):
         return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
diff --git a/Misc/NEWS.d/next/Library/2022-01-04-11-04-20.bpo-46257._o2ADe.rst b/Misc/NEWS.d/next/Library/2022-01-04-11-04-20.bpo-46257._o2ADe.rst
new file mode 100644 (file)
index 0000000..72ae56e
--- /dev/null
@@ -0,0 +1,4 @@
+Optimized the mean, variance, and stdev functions in the statistics module.
+If the input is an iterator, it is consumed in a single pass rather than
+eating memory by conversion to a list.  The single pass algorithm is about
+twice as fast as the previous two pass code.