]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Itertool recipe improvements (GH-103399)
authorRaymond Hettinger <rhettinger@users.noreply.github.com>
Sun, 9 Apr 2023 19:17:37 +0000 (14:17 -0500)
committerGitHub <noreply@github.com>
Sun, 9 Apr 2023 19:17:37 +0000 (14:17 -0500)
Doc/library/itertools.rst

index 70e5b7905f20a98e9364fa9d4d71588cc2afd863..e57c393a6b370bcbb9fae5e560dc231abe2fc995 100644 (file)
@@ -769,8 +769,8 @@ well as with the built-in itertools such as ``map()``, ``filter()``,
 
 A secondary purpose of the recipes is to serve as an incubator.  The
 ``accumulate()``, ``compress()``, and ``pairwise()`` itertools started out as
-recipes.  Currently, the ``iter_index()`` recipe is being tested to see
-whether it proves its worth.
+recipes.  Currently, the ``sliding_window()`` and ``iter_index()`` recipes
+are being tested to see whether they prove their worth.
 
 Substantially all of these recipes and many, many others can be installed from
 the `more-itertools project <https://pypi.org/project/more-itertools/>`_ found
@@ -806,6 +806,23 @@ which incur interpreter overhead.
        "Return function(0), function(1), ..."
        return map(function, count(start))
 
+   def repeatfunc(func, times=None, *args):
+       """Repeat calls to func with specified arguments.
+
+       Example:  repeatfunc(random.random)
+       """
+       if times is None:
+           return starmap(func, repeat(args))
+       return starmap(func, repeat(args, times))
+
+   def flatten(list_of_lists):
+       "Flatten one level of nesting"
+       return chain.from_iterable(list_of_lists)
+
+   def ncycles(iterable, n):
+       "Returns the sequence elements n times"
+       return chain.from_iterable(repeat(tuple(iterable), n))
+
    def tail(n, iterable):
        "Return an iterator over the last n items"
        # tail(3, 'ABCDEFG') --> E F G
@@ -825,70 +842,27 @@ which incur interpreter overhead.
        "Returns the nth item or a default value"
        return next(islice(iterable, n, None), default)
 
-   def all_equal(iterable):
-       "Returns True if all the elements are equal to each other"
-       g = groupby(iterable)
-       return next(g, True) and not next(g, False)
-
    def quantify(iterable, pred=bool):
        "Count how many times the predicate is True"
        return sum(map(pred, iterable))
 
-   def ncycles(iterable, n):
-       "Returns the sequence elements n times"
-       return chain.from_iterable(repeat(tuple(iterable), n))
-
-   def sum_of_squares(it):
-       "Add up the squares of the input values."
-       # sum_of_squares([10, 20, 30]) -> 1400
-       return math.sumprod(*tee(it))
-
-   def transpose(it):
-       "Swap the rows and columns of the input."
-       # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33)
-       return zip(*it, strict=True)
-
-   def matmul(m1, m2):
-       "Multiply two matrices."
-       # matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]) --> (49, 80), (41, 60)
-       n = len(m2[0])
-       return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)
-
-   def convolve(signal, kernel):
-       # See:  https://betterexplained.com/articles/intuitive-convolution/
-       # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
-       # convolve(data, [1, -1]) --> 1st finite difference (1st derivative)
-       # convolve(data, [1, -2, 1]) --> 2nd finite difference (2nd derivative)
-       kernel = tuple(kernel)[::-1]
-       n = len(kernel)
-       window = collections.deque([0], maxlen=n) * n
-       for x in chain(signal, repeat(0, n-1)):
-           window.append(x)
-           yield math.sumprod(kernel, window)
+   def all_equal(iterable):
+       "Returns True if all the elements are equal to each other"
+       g = groupby(iterable)
+       return next(g, True) and not next(g, False)
 
-   def polynomial_from_roots(roots):
-       """Compute a polynomial's coefficients from its roots.
+   def first_true(iterable, default=False, pred=None):
+       """Returns the first true value in the iterable.
 
-          (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
-       """
-       # polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60]
-       expansion = [1]
-       for r in roots:
-           expansion = convolve(expansion, (1, -r))
-       return list(expansion)
+       If no true value is found, returns *default*
 
-   def polynomial_eval(coefficients, x):
-       """Evaluate a polynomial at a specific value.
+       If *pred* is not None, returns the first item
+       for which pred(item) is true.
 
-       Computes with better numeric stability than Horner's method.
        """
-       # Evaluate x³ -4x² -17x + 60 at x = 2.5
-       # polynomial_eval([1, -4, -17, 60], x=2.5) --> 8.125
-       n = len(coefficients)
-       if n == 0:
-           return x * 0  # coerce zero to the type of x
-       powers = map(pow, repeat(x), reversed(range(n)))
-       return math.sumprod(coefficients, powers)
+       # first_true([a,b,c], x) --> a or b or c or x
+       # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
+       return next(filter(pred, iterable), default)
 
    def iter_index(iterable, value, start=0):
        "Return indices where a value occurs in a sequence or iterable."
@@ -913,44 +887,28 @@ which incur interpreter overhead.
            except ValueError:
                pass
 
-   def sieve(n):
-       "Primes less than n"
-       # sieve(30) --> 2 3 5 7 11 13 17 19 23 29
-       data = bytearray((0, 1)) * (n // 2)
-       data[:3] = 0, 0, 0
-       limit = math.isqrt(n) + 1
-       for p in compress(range(limit), data):
-           data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
-       data[2] = 1
-       return iter_index(data, 1) if n > 2 else iter([])
-
-   def factor(n):
-       "Prime factors of n."
-       # factor(99) --> 3 3 11
-       for prime in sieve(math.isqrt(n) + 1):
-           while True:
-               quotient, remainder = divmod(n, prime)
-               if remainder:
-                   break
-               yield prime
-               n = quotient
-               if n == 1:
-                   return
-       if n > 1:
-           yield n
+   def iter_except(func, exception, first=None):
+       """ Call a function repeatedly until an exception is raised.
 
-   def flatten(list_of_lists):
-       "Flatten one level of nesting"
-       return chain.from_iterable(list_of_lists)
+       Converts a call-until-exception interface to an iterator interface.
+       Like builtins.iter(func, sentinel) but uses an exception instead
+       of a sentinel to end the loop.
 
-   def repeatfunc(func, times=None, *args):
-       """Repeat calls to func with specified arguments.
+       Examples:
+           iter_except(functools.partial(heappop, h), IndexError)   # priority queue iterator
+           iter_except(d.popitem, KeyError)                         # non-blocking dict iterator
+           iter_except(d.popleft, IndexError)                       # non-blocking deque iterator
+           iter_except(q.get_nowait, Queue.Empty)                   # loop over a producer Queue
+           iter_except(s.pop, KeyError)                             # non-blocking set iterator
 
-       Example:  repeatfunc(random.random)
        """
-       if times is None:
-           return starmap(func, repeat(args))
-       return starmap(func, repeat(args, times))
+       try:
+           if first is not None:
+               yield first()            # For database APIs needing an initial cast to db.first()
+           while True:
+               yield func()
+       except exception:
+           pass
 
    def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
        "Collect data into non-overlapping fixed-length chunks or blocks"
@@ -967,12 +925,6 @@ which incur interpreter overhead.
        else:
            raise ValueError('Expected fill, strict, or ignore')
 
-   def triplewise(iterable):
-       "Return overlapping triplets from an iterable"
-       # triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
-       for (a, _), (b, c) in pairwise(pairwise(iterable)):
-           yield a, b, c
-
    def sliding_window(iterable, n):
        # sliding_window('ABCDEFG', 4) --> ABCD BCDE CDEF DEFG
        it = iter(iterable)
@@ -1003,6 +955,12 @@ which incur interpreter overhead.
        t1, t2 = tee(iterable)
        return filterfalse(pred, t1), filter(pred, t2)
 
+   def subslices(seq):
+       "Return all contiguous non-empty subslices of a sequence"
+       # subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D
+       slices = starmap(slice, combinations(range(len(seq) + 1), 2))
+       return map(operator.getitem, repeat(seq), slices)
+
    def before_and_after(predicate, it):
        """ Variant of takewhile() that allows complete
            access to the remainder of the iterator.
@@ -1032,17 +990,6 @@ which incur interpreter overhead.
            yield from it
        return true_iterator(), remainder_iterator()
 
-   def subslices(seq):
-       "Return all contiguous non-empty subslices of a sequence"
-       # subslices('ABCD') --> A AB ABC ABCD B BC BCD C CD D
-       slices = starmap(slice, combinations(range(len(seq) + 1), 2))
-       return map(operator.getitem, repeat(seq), slices)
-
-   def powerset(iterable):
-       "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
-       s = list(iterable)
-       return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
-
    def unique_everseen(iterable, key=None):
        "List unique elements, preserving order. Remember all elements ever seen."
        # unique_everseen('AAAABBBCCDAABBB') --> A B C D
@@ -1072,41 +1019,96 @@ which incur interpreter overhead.
        # unique_justseen('ABBcCAD', str.lower) --> A B c A D
        return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
 
-   def iter_except(func, exception, first=None):
-       """ Call a function repeatedly until an exception is raised.
 
-       Converts a call-until-exception interface to an iterator interface.
-       Like builtins.iter(func, sentinel) but uses an exception instead
-       of a sentinel to end the loop.
+The following recipes have a more mathematical flavor:
 
-       Examples:
-           iter_except(functools.partial(heappop, h), IndexError)   # priority queue iterator
-           iter_except(d.popitem, KeyError)                         # non-blocking dict iterator
-           iter_except(d.popleft, IndexError)                       # non-blocking deque iterator
-           iter_except(q.get_nowait, Queue.Empty)                   # loop over a producer Queue
-           iter_except(s.pop, KeyError)                             # non-blocking set iterator
+.. testcode::
 
-       """
-       try:
-           if first is not None:
-               yield first()            # For database APIs needing an initial cast to db.first()
+   def powerset(iterable):
+       "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
+       s = list(iterable)
+       return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
+
+   def sieve(n):
+       "Primes less than n"
+       # sieve(30) --> 2 3 5 7 11 13 17 19 23 29
+       data = bytearray((0, 1)) * (n // 2)
+       data[:3] = 0, 0, 0
+       limit = math.isqrt(n) + 1
+       for p in compress(range(limit), data):
+           data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
+       data[2] = 1
+       return iter_index(data, 1) if n > 2 else iter([])
+
+   def factor(n):
+       "Prime factors of n."
+       # factor(99) --> 3 3 11
+       for prime in sieve(math.isqrt(n) + 1):
            while True:
-               yield func()
-       except exception:
-           pass
+               quotient, remainder = divmod(n, prime)
+               if remainder:
+                   break
+               yield prime
+               n = quotient
+               if n == 1:
+                   return
+       if n > 1:
+           yield n
 
-   def first_true(iterable, default=False, pred=None):
-       """Returns the first true value in the iterable.
+   def sum_of_squares(it):
+       "Add up the squares of the input values."
+       # sum_of_squares([10, 20, 30]) -> 1400
+       return math.sumprod(*tee(it))
 
-       If no true value is found, returns *default*
+   def transpose(it):
+       "Swap the rows and columns of the input."
+       # transpose([(1, 2, 3), (11, 22, 33)]) --> (1, 11) (2, 22) (3, 33)
+       return zip(*it, strict=True)
 
-       If *pred* is not None, returns the first item
-       for which pred(item) is true.
+   def matmul(m1, m2):
+       "Multiply two matrices."
+       # matmul([(7, 5), (3, 5)], [[2, 5], [7, 9]]) --> (49, 80), (41, 60)
+       n = len(m2[0])
+       return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)
+
+   def convolve(signal, kernel):
+       """Linear convolution of two iterables.
 
+       Article:  https://betterexplained.com/articles/intuitive-convolution/
+       Video:    https://www.youtube.com/watch?v=KuXjwB4LzSA
        """
-       # first_true([a,b,c], x) --> a or b or c or x
-       # first_true([a,b], x, f) --> a if f(a) else b if f(b) else x
-       return next(filter(pred, iterable), default)
+       # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
+       # convolve(data, [1, -1]) --> 1st finite difference (1st derivative)
+       # convolve(data, [1, -2, 1]) --> 2nd finite difference (2nd derivative)
+       kernel = tuple(kernel)[::-1]
+       n = len(kernel)
+       padded_signal = chain(repeat(0, n-1), signal, [0] * (n-1))
+       for window in sliding_window(padded_signal, n):
+           yield math.sumprod(kernel, window)
+
+   def polynomial_from_roots(roots):
+       """Compute a polynomial's coefficients from its roots.
+
+          (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
+       """
+       # polynomial_from_roots([5, -4, 3]) --> [1, -4, -17, 60]
+       expansion = [1]
+       for r in roots:
+           expansion = convolve(expansion, (1, -r))
+       return list(expansion)
+
+   def polynomial_eval(coefficients, x):
+       """Evaluate a polynomial at a specific value.
+
+       Computes with better numeric stability than Horner's method.
+       """
+       # Evaluate x³ -4x² -17x + 60 at x = 2.5
+       # polynomial_eval([1, -4, -17, 60], x=2.5) --> 8.125
+       n = len(coefficients)
+       if n == 0:
+           return x * 0  # coerce zero to the type of x
+       powers = map(pow, repeat(x), reversed(range(n)))
+       return math.sumprod(coefficients, powers)
 
    def nth_combination(iterable, r, index):
        "Equivalent to list(combinations(iterable, r))[index]"
@@ -1126,6 +1128,7 @@ which incur interpreter overhead.
            result.append(pool[-1-n])
        return tuple(result)
 
+
 .. doctest::
     :hide:
 
@@ -1401,9 +1404,6 @@ which incur interpreter overhead.
     >>> list(grouper('abcdefg', n=3, incomplete='ignore'))
     [('a', 'b', 'c'), ('d', 'e', 'f')]
 
-    >>> list(triplewise('ABCDEFG'))
-    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E'), ('D', 'E', 'F'), ('E', 'F', 'G')]
-
     >>> list(sliding_window('ABCDEFG', 4))
     [('A', 'B', 'C', 'D'), ('B', 'C', 'D', 'E'), ('C', 'D', 'E', 'F'), ('D', 'E', 'F', 'G')]
 
@@ -1485,3 +1485,45 @@ which incur interpreter overhead.
     >>> combos = list(combinations(iterable, r))
     >>> all(nth_combination(iterable, r, i) == comb for i, comb in enumerate(combos))
     True
+
+
+.. testcode::
+    :hide:
+
+    # Old recipes and their tests which are guaranteed to continue to work.
+
+    def sumprod(vec1, vec2):
+        "Compute a sum of products."
+        return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))
+
+    def dotproduct(vec1, vec2):
+        return sum(map(operator.mul, vec1, vec2))
+
+    def pad_none(iterable):
+        """Returns the sequence elements and then returns None indefinitely.
+
+        Useful for emulating the behavior of the built-in map() function.
+        """
+        return chain(iterable, repeat(None))
+
+    def triplewise(iterable):
+        "Return overlapping triplets from an iterable"
+        # triplewise('ABCDEFG') --> ABC BCD CDE DEF EFG
+        for (a, _), (b, c) in pairwise(pairwise(iterable)):
+            yield a, b, c
+
+
+.. doctest::
+    :hide:
+
+    >>> dotproduct([1,2,3], [4,5,6])
+    32
+
+    >>> sumprod([1,2,3], [4,5,6])
+    32
+
+    >>> list(islice(pad_none('abc'), 0, 6))
+    ['a', 'b', 'c', None, None, None]
+
+    >>> list(triplewise('ABCDEFG'))
+    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E'), ('D', 'E', 'F'), ('E', 'F', 'G')]