]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Add optional *func* argument to itertools.accumulate().
authorRaymond Hettinger <python@rcn.com>
Mon, 28 Mar 2011 01:52:10 +0000 (18:52 -0700)
committerRaymond Hettinger <python@rcn.com>
Mon, 28 Mar 2011 01:52:10 +0000 (18:52 -0700)
Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS
Modules/itertoolsmodule.c

index 757823d9f1d022a66deeab30a49e8944e8ea5fdf..07378d1da5da47f75979f5a359a13c8f8bd6ee98 100644 (file)
@@ -46,7 +46,7 @@ Iterator            Arguments               Results
 ====================    ============================    =================================================   =============================================================
 Iterator                Arguments                       Results                                             Example
 ====================    ============================    =================================================   =============================================================
-:func:`accumulate`      p                               p0, p0+p1, p0+p1+p2, ...                            ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
+:func:`accumulate`      p [,func]                       p0, p0+p1, p0+p1+p2, ...                            ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
 :func:`chain`           p, q, ...                       p0, p1, ... plast, q0, q1, ...                      ``chain('ABC', 'DEF') --> A B C D E F``
 :func:`compress`        data, selectors                 (d[0] if s[0]), (d[1] if s[1]), ...                 ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F``
 :func:`dropwhile`       pred, seq                       seq[n], seq[n+1], starting when pred fails          ``dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1``
@@ -84,23 +84,46 @@ The following module functions all construct and return iterators. Some provide
 streams of infinite length, so they should only be accessed by functions or
 loops that truncate the stream.
 
-.. function:: accumulate(iterable)
+.. function:: accumulate(iterable[, func])
 
     Make an iterator that returns accumulated sums. Elements may be any addable
-    type including :class:`Decimal` or :class:`Fraction`.  Equivalent to::
+    type including :class:`Decimal` or :class:`Fraction`.  If the optional
+    *func* argument is supplied, it should be a function of two arguments
+    and it will be used instead of addition.
 
-        def accumulate(iterable):
+    Equivalent to::
+
+        def accumulate(iterable, func=operator.add):
             'Return running totals'
             # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
+            # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
             it = iter(iterable)
             total = next(it)
             yield total
             for element in it:
-                total = total + element
+                total = func(total, element)
                 yield total
 
+    Uses for the *func* argument include :func:`min` for a running minimum,
+    :func:`max` for a running maximum, and :func:`operator.mul` for a running
+    product::
+
+      >>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
+      >>> list(accumulate(data, operator.mul))     # running product
+      [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
+      >>> list(accumulate(data, max))              # running maximum
+      [3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
+
+      # Amortize a 5% loan of 1000 with 4 annual payments of 90
+      >>> cashflows = [1000, -90, -90, -90, -90]
+      >>> list(accumulate(cashflows, lambda bal, pmt: bal*1.05 + pmt))
+      [1000, 960.0, 918.0, 873.9000000000001, 827.5950000000001]
+
     .. versionadded:: 3.2
 
+    .. versionchanged:: 3.3
+       Added the optional *func* parameter.
+
 .. function:: chain(*iterables)
 
    Make an iterator that returns elements from the first iterable until it is
index 5e4bf1ba8917c025cbb2e8dba5adf9f3ebe22293..acbb00a9ff2f017105c797f5965ece559eba67ea 100644 (file)
@@ -69,11 +69,21 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(list(accumulate('abc')), ['a', 'ab', 'abc'])   # works with non-numeric
         self.assertEqual(list(accumulate([])), [])                  # empty iterable
         self.assertEqual(list(accumulate([7])), [7])                # iterable of length one
-        self.assertRaises(TypeError, accumulate, range(10), 5)      # too many args
+        self.assertRaises(TypeError, accumulate, range(10), 5, 6)   # too many args
         self.assertRaises(TypeError, accumulate)                    # too few args
         self.assertRaises(TypeError, accumulate, x=range(10))       # unexpected kwd arg
         self.assertRaises(TypeError, list, accumulate([1, []]))     # args that don't add
 
+        s = [2, 8, 9, 5, 7, 0, 3, 4, 1, 6]
+        self.assertEqual(list(accumulate(s, min)),
+                         [2, 2, 2, 2, 2, 0, 0, 0, 0, 0])
+        self.assertEqual(list(accumulate(s, max)),
+                         [2, 8, 9, 9, 9, 9, 9, 9, 9, 9])
+        self.assertEqual(list(accumulate(s, operator.mul)),
+                         [2, 16, 144, 720, 5040, 0, 0, 0, 0, 0])
+        with self.assertRaises(TypeError):
+            list(accumulate(s, chr))                                # unary-operation
+
     def test_chain(self):
 
         def chain2(*iterables):
index 0c0e136db4f0ce1d5b48c497c5360df5cddd5687..0341dd61305e4b3acdc7d433a4945983100fa4ea 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -89,6 +89,9 @@ Library
 
 - Issue #11696: Fix ID generation in msilib.
 
+- itertools.accumulate now supports an optional *func* argument for
+  a user-supplied binary function.
+
 - Issue #11692: Remove unnecessary demo functions in subprocess module.
 
 - Issue #9696: Fix exception incorrectly raised by xdrlib.Packer.pack_int when
index b202e5262bab463e6abe240c0974b6548265b3d4..4f58d573fb1b66a2faa64ce8a5b7cdaca6a62f85 100644 (file)
@@ -2590,6 +2590,7 @@ typedef struct {
     PyObject_HEAD
     PyObject *total;
     PyObject *it;
+    PyObject *binop;
 } accumulateobject;
 
 static PyTypeObject accumulate_type;
@@ -2597,12 +2598,14 @@ static PyTypeObject accumulate_type;
 static PyObject *
 accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
-    static char *kwargs[] = {"iterable", NULL};
+    static char *kwargs[] = {"iterable", "func", NULL};
     PyObject *iterable;
     PyObject *it;
+    PyObject *binop = NULL;
     accumulateobject *lz;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable))
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
+                                     kwargs, &iterable, &binop))
         return NULL;
 
     /* Get iterator. */
@@ -2617,6 +2620,8 @@ accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
         return NULL;
     }
 
+    Py_XINCREF(binop);
+    lz->binop = binop;
     lz->total = NULL;
     lz->it = it;
     return (PyObject *)lz;
@@ -2626,6 +2631,7 @@ static void
 accumulate_dealloc(accumulateobject *lz)
 {
     PyObject_GC_UnTrack(lz);
+    Py_XDECREF(lz->binop);
     Py_XDECREF(lz->total);
     Py_XDECREF(lz->it);
     Py_TYPE(lz)->tp_free(lz);
@@ -2634,6 +2640,7 @@ accumulate_dealloc(accumulateobject *lz)
 static int
 accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
 {
+    Py_VISIT(lz->binop);
     Py_VISIT(lz->it);
     Py_VISIT(lz->total);
     return 0;
@@ -2653,8 +2660,11 @@ accumulate_next(accumulateobject *lz)
         lz->total = val;
         return lz->total;
     }
-   
-    newtotal = PyNumber_Add(lz->total, val);
+
+    if (lz->binop == NULL) 
+        newtotal = PyNumber_Add(lz->total, val);
+    else
+        newtotal = PyObject_CallFunctionObjArgs(lz->binop, lz->total, val, NULL);
     Py_DECREF(val);
     if (newtotal == NULL)
         return NULL;