]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
SF #950057: itertools.chain doesn't "process" exceptions as they occur
authorRaymond Hettinger <python@rcn.com>
Sat, 8 May 2004 19:52:39 +0000 (19:52 +0000)
committerRaymond Hettinger <python@rcn.com>
Sat, 8 May 2004 19:52:39 +0000 (19:52 +0000)
Both cycle() and chain() were handling exceptions only when switching
input sources.  The patch makes the handle more immediate.

Lib/test/test_itertools.py
Modules/itertoolsmodule.c

index 2da82ed9c77bca639d116fe4c9fed72c7ba08b14..ba4a41b8a24fcf75e7f47b885ec57d8724cc6bb4 100644 (file)
@@ -458,6 +458,36 @@ class RegressionTests(unittest.TestCase):
         self.assertEqual(first, second)
 
 
+    def test_sf_950057(self):
+        # Make sure that chain() and cycle() catch exceptions immediately
+        # rather than when shifting between input sources
+
+        def gen1():
+            hist.append(0)
+            yield 1
+            hist.append(1)
+            assert False
+            hist.append(2)
+
+        def gen2(x):
+            hist.append(3)
+            yield 2
+            hist.append(4)
+            if x:
+                raise StopIteration
+
+        hist = []
+        self.assertRaises(AssertionError, list, chain(gen1(), gen2(False)))
+        self.assertEqual(hist, [0,1])
+
+        hist = []
+        self.assertRaises(AssertionError, list, chain(gen1(), gen2(True)))
+        self.assertEqual(hist, [0,1])
+
+        hist = []
+        self.assertRaises(AssertionError, list, cycle(gen1()))
+        self.assertEqual(hist, [0,1])
+
 libreftest = """ Doctest for examples in the library reference: libitertools.tex
 
 
index 68e176f23d4e71c0078d6e6f76887dadcc87d80f..edc6159c872ccb4e616ad15007ef2a1a22406d8d 100644 (file)
@@ -94,6 +94,12 @@ cycle_next(cycleobject *lz)
                                PyList_Append(lz->saved, item);
                        return item;
                }
+               if (PyErr_Occurred()) {
+                       if (PyErr_ExceptionMatches(PyExc_StopIteration))
+                               PyErr_Clear();
+                       else
+                               return NULL;
+               }
                if (PyList_Size(lz->saved) == 0) 
                        return NULL;
                it = PyObject_GetIter(lz->saved);
@@ -1049,6 +1055,12 @@ chain_next(chainobject *lz)
                item = PyIter_Next(it);
                if (item != NULL)
                        return item;
+               if (PyErr_Occurred()) {
+                       if (PyErr_ExceptionMatches(PyExc_StopIteration))
+                               PyErr_Clear();
+                       else
+                               return NULL;
+               }
                lz->iternum++;
        }
        return NULL;