]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Add itertools.combinations().
authorRaymond Hettinger <python@rcn.com>
Tue, 26 Feb 2008 23:40:50 +0000 (23:40 +0000)
committerRaymond Hettinger <python@rcn.com>
Tue, 26 Feb 2008 23:40:50 +0000 (23:40 +0000)
Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS
Modules/itertoolsmodule.c

index c8f6e33f2866ce1bf2abaa35e9f049dce94d558a..6dc19a17a094504aac352a855ab5d8c84c13ef7b 100644 (file)
@@ -97,21 +97,21 @@ loops that truncate the stream.
 
         def combinations(iterable, r):
             pool = tuple(iterable)
-            if pool:
-                n = len(pool)
-                vec = range(r)
-                yield tuple(pool[i] for i in vec)
-                while 1:
-                    for i in reversed(range(r)):
-                        if vec[i] == i + n-r:
-                            continue
-                        vec[i] += 1
-                        for j in range(i+1, r):
-                            vec[j] = vec[j-1] + 1
-                        yield tuple(pool[i] for i in vec)
-                        break
-                    else:
-                        return
+            n = len(pool)
+            assert 0 <= r <= n
+            vec = range(r)
+            yield tuple(pool[i] for i in vec)
+            while 1:
+                for i in reversed(range(r)):
+                    if vec[i] == i + n-r:
+                        continue
+                    vec[i] += 1
+                    for j in range(i+1, r):
+                        vec[j] = vec[j-1] + 1
+                    yield tuple(pool[i] for i in vec)
+                    break
+                else:
+                    return
 
    .. versionadded:: 2.6
 
index dc9081ec5aacbb23b5514b0e85e947b663c36a42..98683254c7f81863cccbfcc4115ee3268f6cfe2d 100644 (file)
@@ -40,6 +40,10 @@ def take(n, seq):
     'Convenience function for partially consuming a long of infinite iterable'
     return list(islice(seq, n))
 
+def fact(n):
+    'Factorial'
+    return reduce(operator.mul, range(1, n+1), 1)
+
 class TestBasicOps(unittest.TestCase):
     def test_chain(self):
         self.assertEqual(list(chain('abc', 'def')), list('abcdef'))
@@ -48,6 +52,26 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
         self.assertRaises(TypeError, chain, 2, 3)
 
+    def test_combinations(self):
+        self.assertRaises(TypeError, combinations, 'abc')   # missing r argument
+        self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
+        self.assertRaises(ValueError, combinations, 'abc', -2)  # r is negative
+        self.assertRaises(ValueError, combinations, 'abc', 32)  # r is too big
+        self.assertEqual(list(combinations(range(4), 3)),
+                                           [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
+        for n in range(6):
+            values = [5*x-12 for x in range(n)]
+            for r in range(n+1):
+                result = list(combinations(values, r))
+                self.assertEqual(len(result), fact(n) / fact(r) / fact(n-r)) # right number of combs
+                self.assertEqual(len(result), len(set(result)))         # no repeats
+                self.assertEqual(result, sorted(result))                # lexicographic order
+                for c in result:
+                    self.assertEqual(len(c), r)                         # r-length combinations
+                    self.assertEqual(len(set(c)), r)                    # no duplicate elements
+                    self.assertEqual(list(c), sorted(c))                # keep original ordering
+                    self.assert_(all(e in values for e in c))           # elements taken from input iterable
+
     def test_count(self):
         self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
         self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
index d72fe745c4e9c73b8be273d738bae31246001425..c28fd856770ef80279cfe8585f582909de489dc1 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -670,6 +670,8 @@ Library
 - Added itertools.product() which forms the Cartesian product of
   the input iterables.
 
+- Added itertools.combinations().
+
 - Patch #1541463: optimize performance of cgi.FieldStorage operations.
 
 - Decimal is fully updated to the latest Decimal Specification (v1.66).
index d95376a581dd9f46c19c104cf864e19a79184482..10c5e0bdbd86a5e7bed5603cea6e668ed8cf6199 100644 (file)
@@ -1982,6 +1982,229 @@ static PyTypeObject product_type = {
 };
 
 
+/* combinations object ************************************************************/
+
+typedef struct {
+       PyObject_HEAD
+       PyObject *pool;                 /* input converted to a tuple */
+       Py_ssize_t *indices;            /* one index per result element */
+       PyObject *result;               /* most recently returned result tuple */
+       Py_ssize_t r;                   /* size of result tuple */
+       int stopped;                    /* set to 1 when the combinations iterator is exhausted */
+} combinationsobject;
+
+static PyTypeObject combinations_type;
+
+static PyObject *
+combinations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+       combinationsobject *co;
+       Py_ssize_t n;
+       Py_ssize_t r;
+       PyObject *pool = NULL;
+       PyObject *iterable = NULL;
+       Py_ssize_t *indices = NULL;
+       Py_ssize_t i;
+       static char *kwargs[] = {"iterable", "r", NULL};
+       if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations", kwargs, 
+                                        &iterable, &r))
+               return NULL;
+
+       pool = PySequence_Tuple(iterable);
+       if (pool == NULL)
+               goto error;
+       n = PyTuple_GET_SIZE(pool);
+       if (r < 0) {
+               PyErr_SetString(PyExc_ValueError, "r must be non-negative");
+               goto error;
+       }
+       if (r > n) {
+               PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
+               goto error;
+       }
+
+       indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
+       if (indices == NULL) {
+               PyErr_NoMemory();
+               goto error;
+       }
+
+       for (i=0 ; i<r ; i++)
+               indices[i] = i;
+
+       /* create combinationsobject structure */
+       co = (combinationsobject *)type->tp_alloc(type, 0);
+       if (co == NULL)
+               goto error;
+
+       co->pool = pool;
+       co->indices = indices;
+       co->result = NULL;
+       co->r = r;
+       co->stopped = 0;
+
+       return (PyObject *)co;
+
+error:
+       if (indices != NULL)
+               PyMem_Free(indices);
+       Py_XDECREF(pool);
+       return NULL;
+}
+
+static void
+combinations_dealloc(combinationsobject *co)
+{
+       PyObject_GC_UnTrack(co);
+       Py_XDECREF(co->pool);
+       Py_XDECREF(co->result);
+       PyMem_Free(co->indices);
+       Py_TYPE(co)->tp_free(co);
+}
+
+static int
+combinations_traverse(combinationsobject *co, visitproc visit, void *arg)
+{
+       Py_VISIT(co->pool);
+       Py_VISIT(co->result);
+       return 0;
+}
+
+static PyObject *
+combinations_next(combinationsobject *co)
+{
+       PyObject *elem;
+       PyObject *oldelem;
+       PyObject *pool = co->pool;
+       Py_ssize_t *indices = co->indices;
+       PyObject *result = co->result;
+       Py_ssize_t n = PyTuple_GET_SIZE(pool);
+       Py_ssize_t r = co->r;
+       Py_ssize_t i, j, index;
+
+       if (co->stopped)
+               return NULL;
+
+       if (result == NULL) {
+                /* On the first pass, initialize result tuple using the indices */
+               result = PyTuple_New(r);
+               if (result == NULL)
+                       goto empty;
+               co->result = result;
+               for (i=0; i<r ; i++) {
+                       index = indices[i];
+                       elem = PyTuple_GET_ITEM(pool, index);
+                       Py_INCREF(elem);
+                       PyTuple_SET_ITEM(result, i, elem);
+               }
+       } else {
+               /* Copy the previous result tuple or re-use it if available */
+               if (Py_REFCNT(result) > 1) {
+                       PyObject *old_result = result;
+                       result = PyTuple_New(r);
+                       if (result == NULL)
+                               goto empty;
+                       co->result = result;
+                       for (i=0; i<r ; i++) {
+                               elem = PyTuple_GET_ITEM(old_result, i);
+                               Py_INCREF(elem);
+                               PyTuple_SET_ITEM(result, i, elem);
+                       }
+                       Py_DECREF(old_result);
+               }
+               /* Now, we've got the only copy so we can update it in-place */
+               assert (Py_REFCNT(result) == 1);
+
+                /* Scan indices right-to-left until finding one that is not
+                   at its maximum (i + n - r). */
+               for (i=r-1 ; i >= 0 && indices[i] == i+n-r ; i--)
+                       ;
+
+               /* If i is negative, then the indices are all at
+                   their maximum value and we're done. */
+               if (i < 0)
+                       goto empty;
+
+               /* Increment the current index which we know is not at its
+                   maximum.  Then move back to the right setting each index
+                   to its lowest possible value (one higher than the index
+                   to its left -- this maintains the sort order invariant). */
+               indices[i]++;
+               for (j=i+1 ; j<r ; j++)
+                       indices[j] = indices[j-1] + 1;
+
+               /* Update the result tuple for the new indices
+                  starting with i, the leftmost index that changed */
+               for ( ; i<r ; i++) {
+                       index = indices[i];
+                       elem = PyTuple_GET_ITEM(pool, index);
+                       Py_INCREF(elem);
+                       oldelem = PyTuple_GET_ITEM(result, i);
+                       PyTuple_SET_ITEM(result, i, elem);
+                       Py_DECREF(oldelem);
+               }
+       }
+
+       Py_INCREF(result);
+       return result;
+
+empty:
+       co->stopped = 1;
+       return NULL;
+}
+
+PyDoc_STRVAR(combinations_doc,
+"combinations(iterables) --> combinations object\n\
+\n\
+Return successive r-length combinations of elements in the iterable.\n\n\
+combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)");
+
+static PyTypeObject combinations_type = {
+       PyVarObject_HEAD_INIT(NULL, 0)
+       "itertools.combinations",               /* tp_name */
+       sizeof(combinationsobject),     /* tp_basicsize */
+       0,                              /* tp_itemsize */
+       /* methods */
+       (destructor)combinations_dealloc,       /* tp_dealloc */
+       0,                              /* tp_print */
+       0,                              /* tp_getattr */
+       0,                              /* tp_setattr */
+       0,                              /* tp_compare */
+       0,                              /* tp_repr */
+       0,                              /* tp_as_number */
+       0,                              /* tp_as_sequence */
+       0,                              /* tp_as_mapping */
+       0,                              /* tp_hash */
+       0,                              /* tp_call */
+       0,                              /* tp_str */
+       PyObject_GenericGetAttr,        /* tp_getattro */
+       0,                              /* tp_setattro */
+       0,                              /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+               Py_TPFLAGS_BASETYPE,    /* tp_flags */
+       combinations_doc,                       /* tp_doc */
+       (traverseproc)combinations_traverse,    /* tp_traverse */
+       0,                              /* tp_clear */
+       0,                              /* tp_richcompare */
+       0,                              /* tp_weaklistoffset */
+       PyObject_SelfIter,              /* tp_iter */
+       (iternextfunc)combinations_next,        /* tp_iternext */
+       0,                              /* tp_methods */
+       0,                              /* tp_members */
+       0,                              /* tp_getset */
+       0,                              /* tp_base */
+       0,                              /* tp_dict */
+       0,                              /* tp_descr_get */
+       0,                              /* tp_descr_set */
+       0,                              /* tp_dictoffset */
+       0,                              /* tp_init */
+       0,                              /* tp_alloc */
+       combinations_new,                       /* tp_new */
+       PyObject_GC_Del,                /* tp_free */
+};
+
+
 /* ifilter object ************************************************************/
 
 typedef struct {
@@ -3026,6 +3249,7 @@ inititertools(void)
        PyObject *m;
        char *name;
        PyTypeObject *typelist[] = {
+               &combinations_type,
                &cycle_type,
                &dropwhile_type,
                &takewhile_type,
@@ -3038,7 +3262,7 @@ inititertools(void)
                &count_type,
                &izip_type,
                &iziplongest_type,
-               &product_type,          
+               &product_type,         
                &repeat_type,
                &groupby_type,
                NULL