git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Optimize fmean() weighted average (#102626)
author Raymond Hettinger <rhettinger@users.noreply.github.com>
Sun, 12 Mar 2023 17:48:25 +0000 (12:48 -0500)
committer GitHub <noreply@github.com>
Sun, 12 Mar 2023 17:48:25 +0000 (12:48 -0500)
Lib/statistics.py

index 07d1fd5ba6e98e90bd54abe697e3738a5bef89a5..7d5d750193a5abdc6c4a69ad9ee6d6b3a262b7c8 100644 (file)
@@ -136,9 +136,9 @@ from fractions import Fraction
 from decimal import Decimal
 from itertools import count, groupby, repeat
 from bisect import bisect_left, bisect_right
-from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
+from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
 from functools import reduce
-from operator import mul, itemgetter
+from operator import itemgetter
 from collections import Counter, namedtuple, defaultdict
 
 _SQRT2 = sqrt(2.0)
@@ -496,28 +496,26 @@ def fmean(data, weights=None):
     >>> fmean([3.5, 4.0, 5.25])
     4.25
     """
-    try:
-        n = len(data)
-    except TypeError:
-        # Handle iterators that do not define __len__().
-        n = 0
-        def count(iterable):
-            nonlocal n
-            for n, x in enumerate(iterable, start=1):
-                yield x
-        data = count(data)
     if weights is None:
+        try:
+            n = len(data)
+        except TypeError:
+            # Handle iterators that do not define __len__().
+            n = 0
+            def count(iterable):
+                nonlocal n
+                for n, x in enumerate(iterable, start=1):
+                    yield x
+            data = count(data)
         total = fsum(data)
         if not n:
             raise StatisticsError('fmean requires at least one data point')
         return total / n
-    try:
-        num_weights = len(weights)
-    except TypeError:
+    if not isinstance(weights, (list, tuple)):
         weights = list(weights)
-        num_weights = len(weights)
-    num = fsum(map(mul, data, weights))
-    if n != num_weights:
+    try:
+        num = sumprod(data, weights)
+    except ValueError:
         raise StatisticsError('data and weights must be the same length')
     den = fsum(weights)
     if not den: