]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Generalize map() to work with iterators.
authorTim Peters <tim.peters@gmail.com>
Thu, 3 May 2001 23:54:49 +0000 (23:54 +0000)
committerTim Peters <tim.peters@gmail.com>
Thu, 3 May 2001 23:54:49 +0000 (23:54 +0000)
NEEDS DOC CHANGES.
Possibly contentious:  The first time s.next() yields StopIteration (for
a given map argument s) is the last time map() *tries* s.next().  That
is, if other sequence args are longer, s will never again contribute
anything but None values to the result, even if trying s.next() again
could yield another result.  This is the same behavior map() used to have
wrt IndexError, so it's the only way to be wholly backward-compatible.
I'm not a fan of letting StopIteration mean "try again later" anyway.

Lib/test/test_iter.py
Misc/NEWS
Python/bltinmodule.c

index 3563661b6414412e1f8ac12e7df88e666e19a79d..c87f5ec29d497fce7d7048cb37adefec1bbce05c 100644 (file)
@@ -351,4 +351,39 @@ class TestCase(unittest.TestCase):
             except OSError:
                 pass
 
+    # Test map()'s use of iterators.
+    def test_builtin_map(self):
+        self.assertEqual(map(None, SequenceClass(5)), range(5))
+        self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))
+
+        d = {"one": 1, "two": 2, "three": 3}
+        self.assertEqual(map(None, d), d.keys())
+        self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
+        dkeys = d.keys()
+        expected = [(i < len(d) and dkeys[i] or None,
+                     i,
+                     i < len(d) and dkeys[i] or None)
+                    for i in range(5)]
+        self.assertEqual(map(None, d,
+                                   SequenceClass(5),
+                                   iter(d.iterkeys())),
+                         expected) 
+
+        f = open(TESTFN, "w")
+        try:
+            for i in range(10):
+                f.write("xy" * i + "\n") # line i has len 2*i+1
+        finally:
+            f.close()
+        f = open(TESTFN, "r")
+        try:
+            self.assertEqual(map(len, f), range(1, 21, 2))
+            f.seek(0, 0)
+        finally:
+            f.close()
+            try:
+                unlink(TESTFN)
+            except OSError:
+                pass
+
 run_unittest(TestCase)
index 9d84845744edc15c46dde565c8affc3a9f4655fe..617146c9719b95fd79a6ed634a32fd5e64ecd1ff 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -19,6 +19,7 @@ Core
   arguments:
     filter()
     list()
+    map()
     max()
     min()
 
index 9e8a2279c0db7aa4729108d6d5d791a42aa7886c..0c20d1005892d5ef0886b041b4bcf400024e93c7 100644 (file)
@@ -936,9 +936,8 @@ static PyObject *
 builtin_map(PyObject *self, PyObject *args)
 {
        typedef struct {
-               PyObject *seq;
-               PySequenceMethods *sqf;
-               int saw_IndexError;
+               PyObject *it;   /* the iterator object */
+               int saw_StopIteration;  /* bool:  did the iterator end? */
        } sequence;
 
        PyObject *func, *result;
@@ -961,104 +960,105 @@ builtin_map(PyObject *self, PyObject *args)
                return PySequence_List(PyTuple_GetItem(args, 1));
        }
 
+       /* Get space for sequence descriptors.  Must NULL out the iterator
+        * pointers so that jumping to Fail_2 later doesn't see trash.
+        */
        if ((seqs = PyMem_NEW(sequence, n)) == NULL) {
                PyErr_NoMemory();
-               goto Fail_2;
+               return NULL;
+       }
+       for (i = 0; i < n; ++i) {
+               seqs[i].it = (PyObject*)NULL;
+               seqs[i].saw_StopIteration = 0;
        }
 
-       /* Do a first pass to (a) verify the args are sequences; (b) set
-        * len to the largest of their lengths; (c) initialize the seqs
-        * descriptor vector.
+       /* Do a first pass to obtain iterators for the arguments, and set len
+        * to the largest of their lengths.
         */
-       for (len = 0, i = 0, sqp = seqs; i < n; ++i, ++sqp) {
+       len = 0;
+       for (i = 0, sqp = seqs; i < n; ++i, ++sqp) {
+               PyObject *curseq;
                int curlen;
-               PySequenceMethods *sqf;
 
-               if ((sqp->seq = PyTuple_GetItem(args, i + 1)) == NULL)
-                       goto Fail_2;
-
-               sqp->saw_IndexError = 0;
-
-               sqp->sqf = sqf = sqp->seq->ob_type->tp_as_sequence;
-               if (sqf == NULL ||
-                   sqf->sq_item == NULL)
-               {
+               /* Get iterator. */
+               curseq = PyTuple_GetItem(args, i+1);
+               sqp->it = PyObject_GetIter(curseq);
+               if (sqp->it == NULL) {
                        static char errmsg[] =
-                           "argument %d to map() must be a sequence object";
+                           "argument %d to map() must support iteration";
                        char errbuf[sizeof(errmsg) + 25];
-
                        sprintf(errbuf, errmsg, i+2);
                        PyErr_SetString(PyExc_TypeError, errbuf);
                        goto Fail_2;
                }
 
-               if (sqf->sq_length == NULL)
-                       /* doesn't matter -- make something up */
-                       curlen = 8;
-               else
-                       curlen = (*sqf->sq_length)(sqp->seq);
+               /* Update len. */
+               curlen = -1;  /* unknown */
+               if (PySequence_Check(curseq) &&
+                   curseq->ob_type->tp_as_sequence->sq_length) {
+                       curlen = PySequence_Size(curseq);
+                       if (curlen < 0)
+                               PyErr_Clear();
+               }
                if (curlen < 0)
-                       goto Fail_2;
+                       curlen = 8;  /* arbitrary */
                if (curlen > len)
                        len = curlen;
        }
 
+       /* Get space for the result list. */
        if ((result = (PyObject *) PyList_New(len)) == NULL)
                goto Fail_2;
 
-       /* Iterate over the sequences until all have raised IndexError. */
+       /* Iterate over the sequences until all have stopped. */
        for (i = 0; ; ++i) {
                PyObject *alist, *item=NULL, *value;
-               int any = 0;
+               int numactive = 0;
 
                if (func == Py_None && n == 1)
                        alist = NULL;
-               else {
-                       if ((alist = PyTuple_New(n)) == NULL)
-                               goto Fail_1;
-               }
+               else if ((alist = PyTuple_New(n)) == NULL)
+                       goto Fail_1;
 
                for (j = 0, sqp = seqs; j < n; ++j, ++sqp) {
-                       if (sqp->saw_IndexError) {
+                       if (sqp->saw_StopIteration) {
                                Py_INCREF(Py_None);
                                item = Py_None;
                        }
                        else {
-                               item = (*sqp->sqf->sq_item)(sqp->seq, i);
-                               if (item == NULL) {
-                                       if (PyErr_ExceptionMatches(
-                                               PyExc_IndexError))
-                                       {
-                                               PyErr_Clear();
-                                               Py_INCREF(Py_None);
-                                               item = Py_None;
-                                               sqp->saw_IndexError = 1;
-                                       }
-                                       else {
-                                               goto Fail_0;
+                               item = PyIter_Next(sqp->it);
+                               if (item)
+                                       ++numactive;
+                               else {
+                                       /* StopIteration is *implied* by a
+                                        * NULL return from PyIter_Next() if
+                                        * PyErr_Occurred() is false.
+                                        */
+                                       if (PyErr_Occurred()) {
+                                               if (PyErr_ExceptionMatches(
+                                                   PyExc_StopIteration))
+                                                       PyErr_Clear();
+                                               else {
+                                                       Py_XDECREF(alist);
+                                                       goto Fail_1;
+                                               }
                                        }
+                                       Py_INCREF(Py_None);
+                                       item = Py_None;
+                                       sqp->saw_StopIteration = 1;
                                }
-                               else
-                                       any = 1;
 
                        }
-                       if (!alist)
+                       if (alist)
+                               PyTuple_SET_ITEM(alist, j, item);
+                       else
                                break;
-                       if (PyTuple_SetItem(alist, j, item) < 0) {
-                               Py_DECREF(item);
-                               goto Fail_0;
-                       }
-                       continue;
-
-               Fail_0:
-                       Py_XDECREF(alist);
-                       goto Fail_1;
                }
 
                if (!alist)
                        alist = item;
 
-               if (!any) {
+               if (numactive == 0) {
                        Py_DECREF(alist);
                        break;
                }
@@ -1077,23 +1077,25 @@ builtin_map(PyObject *self, PyObject *args)
                        if (status < 0)
                                goto Fail_1;
                }
-               else {
-                       if (PyList_SetItem(result, i, value) < 0)
-                               goto Fail_1;
-               }
+               else if (PyList_SetItem(result, i, value) < 0)
+                       goto Fail_1;
        }
 
        if (i < len && PyList_SetSlice(result, i, len, NULL) < 0)
                goto Fail_1;
 
-       PyMem_DEL(seqs);
-       return result;
+       goto Succeed;
 
 Fail_1:
        Py_DECREF(result);
 Fail_2:
-       if (seqs) PyMem_DEL(seqs);
-       return NULL;
+       result = NULL;
+Succeed:
+       assert(seqs);
+       for (i = 0; i < n; ++i)
+               Py_XDECREF(seqs[i].it);
+       PyMem_DEL(seqs);
+       return result;
 }
 
 static char map_doc[] =