]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.12] Tee of tee was not producing n independent iterators (gh-123884) (gh-125153)
authorRaymond Hettinger <rhettinger@users.noreply.github.com>
Tue, 8 Oct 2024 20:16:18 +0000 (15:16 -0500)
committerGitHub <noreply@github.com>
Tue, 8 Oct 2024 20:16:18 +0000 (20:16 +0000)
Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst [new file with mode: 0644]
Modules/itertoolsmodule.c

index 3fab46c3c0a5b4db56ed86fd9329d4c7690b30d2..047d805eda628a62bf13e8eb61d692ad548d6cea 100644 (file)
@@ -676,24 +676,37 @@ loops that truncate the stream.
    Roughly equivalent to::
 
         def tee(iterable, n=2):
-            iterator = iter(iterable)
-            shared_link = [None, None]
-            return tuple(_tee(iterator, shared_link) for _ in range(n))
-
-        def _tee(iterator, link):
-            try:
-                while True:
-                    if link[1] is None:
-                        link[0] = next(iterator)
-                        link[1] = [None, None]
-                    value, link = link
-                    yield value
-            except StopIteration:
-                return
-
-   Once a :func:`tee` has been created, the original *iterable* should not be
-   used anywhere else; otherwise, the *iterable* could get advanced without
-   the tee objects being informed.
+            if n < 0:
+                raise ValueError
+            if n == 0:
+                return ()
+            iterator = _tee(iterable)
+            result = [iterator]
+            for _ in range(n - 1):
+                result.append(_tee(iterator))
+            return tuple(result)
+
+        class _tee:
+
+            def __init__(self, iterable):
+                it = iter(iterable)
+                if isinstance(it, _tee):
+                    self.iterator = it.iterator
+                    self.link = it.link
+                else:
+                    self.iterator = it
+                    self.link = [None, None]
+
+            def __iter__(self):
+                return self
+
+            def __next__(self):
+                link = self.link
+                if link[1] is None:
+                    link[0] = next(self.iterator)
+                    link[1] = [None, None]
+                value, self.link = link
+                return value
 
    ``tee`` iterators are not threadsafe. A :exc:`RuntimeError` may be
    raised when simultaneously using iterators returned by the same :func:`tee`
index 3d20e70fc1b63f84bea6558d870c2ea13987e3c2..b6404f4366ca0e979e084e29f52c4186c845c9da 100644 (file)
@@ -1612,10 +1612,11 @@ class TestBasicOps(unittest.TestCase):
             self.assertEqual(len(result), n)
             self.assertEqual([list(x) for x in result], [list('abc')]*n)
 
-        # tee pass-through to copyable iterator
+        # tee objects are independent (see bug gh-123884)
         a, b = tee('abc')
         c, d = tee(a)
-        self.assertTrue(a is c)
+        e, f = tee(c)
+        self.assertTrue(len({a, b, c, d, e, f}) == 6)
 
         # test tee_new
         t1, t2 = tee('abc')
@@ -2029,6 +2030,172 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
         self.assertEqual(next(c), 3)
 
 
+    def test_tee_recipe(self):
+
+        # Begin tee() recipe ###########################################
+
+        def tee(iterable, n=2):
+            if n < 0:
+                raise ValueError
+            if n == 0:
+                return ()
+            iterator = _tee(iterable)
+            result = [iterator]
+            for _ in range(n - 1):
+                result.append(_tee(iterator))
+            return tuple(result)
+
+        class _tee:
+
+            def __init__(self, iterable):
+                it = iter(iterable)
+                if isinstance(it, _tee):
+                    self.iterator = it.iterator
+                    self.link = it.link
+                else:
+                    self.iterator = it
+                    self.link = [None, None]
+
+            def __iter__(self):
+                return self
+
+            def __next__(self):
+                link = self.link
+                if link[1] is None:
+                    link[0] = next(self.iterator)
+                    link[1] = [None, None]
+                value, self.link = link
+                return value
+
+        # End tee() recipe #############################################
+
+        n = 200
+
+        a, b = tee([])        # test empty iterator
+        self.assertEqual(list(a), [])
+        self.assertEqual(list(b), [])
+
+        a, b = tee(irange(n)) # test 100% interleaved
+        self.assertEqual(lzip(a,b), lzip(range(n), range(n)))
+
+        a, b = tee(irange(n)) # test 0% interleaved
+        self.assertEqual(list(a), list(range(n)))
+        self.assertEqual(list(b), list(range(n)))
+
+        a, b = tee(irange(n)) # test dealloc of leading iterator
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        del a
+        self.assertEqual(list(b), list(range(n)))
+
+        a, b = tee(irange(n)) # test dealloc of trailing iterator
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        del b
+        self.assertEqual(list(a), list(range(100, n)))
+
+        for j in range(5):   # test randomly interleaved
+            order = [0]*n + [1]*n
+            random.shuffle(order)
+            lists = ([], [])
+            its = tee(irange(n))
+            for i in order:
+                value = next(its[i])
+                lists[i].append(value)
+            self.assertEqual(lists[0], list(range(n)))
+            self.assertEqual(lists[1], list(range(n)))
+
+        # test argument format checking
+        self.assertRaises(TypeError, tee)
+        self.assertRaises(TypeError, tee, 3)
+        self.assertRaises(TypeError, tee, [1,2], 'x')
+        self.assertRaises(TypeError, tee, [1,2], 3, 'x')
+
+        # tee object should be instantiable
+        a, b = tee('abc')
+        c = type(a)('def')
+        self.assertEqual(list(c), list('def'))
+
+        # test long-lagged and multi-way split
+        a, b, c = tee(range(2000), 3)
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        self.assertEqual(list(b), list(range(2000)))
+        self.assertEqual([next(c), next(c)], list(range(2)))
+        self.assertEqual(list(a), list(range(100,2000)))
+        self.assertEqual(list(c), list(range(2,2000)))
+
+        # test invalid values of n
+        self.assertRaises(TypeError, tee, 'abc', 'invalid')
+        self.assertRaises(ValueError, tee, [], -1)
+
+        for n in range(5):
+            result = tee('abc', n)
+            self.assertEqual(type(result), tuple)
+            self.assertEqual(len(result), n)
+            self.assertEqual([list(x) for x in result], [list('abc')]*n)
+
+        # tee objects are independent (see bug gh-123884)
+        a, b = tee('abc')
+        c, d = tee(a)
+        e, f = tee(c)
+        self.assertTrue(len({a, b, c, d, e, f}) == 6)
+
+        # test tee_new
+        t1, t2 = tee('abc')
+        tnew = type(t1)
+        self.assertRaises(TypeError, tnew)
+        self.assertRaises(TypeError, tnew, 10)
+        t3 = tnew(t1)
+        self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
+
+        # test that tee objects are weak referencable
+        a, b = tee(range(10))
+        p = weakref.proxy(a)
+        self.assertEqual(getattr(p, '__class__'), type(b))
+        del a
+        gc.collect()  # For PyPy or other GCs.
+        self.assertRaises(ReferenceError, getattr, p, '__class__')
+
+        ans = list('abc')
+        long_ans = list(range(10000))
+
+        # Tests not applicable to the tee() recipe
+        if False:
+            # check copy
+            a, b = tee('abc')
+            self.assertEqual(list(copy.copy(a)), ans)
+            self.assertEqual(list(copy.copy(b)), ans)
+            a, b = tee(list(range(10000)))
+            self.assertEqual(list(copy.copy(a)), long_ans)
+            self.assertEqual(list(copy.copy(b)), long_ans)
+
+            # check partially consumed copy
+            a, b = tee('abc')
+            take(2, a)
+            take(1, b)
+            self.assertEqual(list(copy.copy(a)), ans[2:])
+            self.assertEqual(list(copy.copy(b)), ans[1:])
+            self.assertEqual(list(a), ans[2:])
+            self.assertEqual(list(b), ans[1:])
+            a, b = tee(range(10000))
+            take(100, a)
+            take(60, b)
+            self.assertEqual(list(copy.copy(a)), long_ans[100:])
+            self.assertEqual(list(copy.copy(b)), long_ans[60:])
+            self.assertEqual(list(a), long_ans[100:])
+            self.assertEqual(list(b), long_ans[60:])
+
+        # Issue 13454: Crash when deleting backward iterator from tee()
+        forward, backward = tee(repeat(None, 2000)) # 20000000
+        try:
+            any(forward)  # exhaust the iterator
+            del backward
+        except:
+            del forward, backward
+            raise
+
+
 class TestGC(unittest.TestCase):
 
     def makecycle(self, iterator, container):
diff --git a/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst b/Misc/NEWS.d/next/Library/2024-09-24-22-38-51.gh-issue-123884.iEPTK4.rst
new file mode 100644 (file)
index 0000000..55f1d4b
--- /dev/null
@@ -0,0 +1,4 @@
+Fixed bug in itertools.tee() handling of other tee inputs (a tee in a tee).
+The output now has the promised *n* independent new iterators.  Formerly,
+the first iterator was identical (not independent) to the input iterator.
+This would sometimes give surprising results.
index d42f9dd07686584d5b35b38d42a3671d87550c99..e87c753113563fb306b85eb7ece0c57d51b0f889 100644 (file)
@@ -1137,7 +1137,7 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
 /*[clinic end generated code: output=1c64519cd859c2f0 input=c99a1472c425d66d]*/
 {
     Py_ssize_t i;
-    PyObject *it, *copyable, *copyfunc, *result;
+    PyObject *it, *to, *result;
 
     if (n < 0) {
         PyErr_SetString(PyExc_ValueError, "n must be >= 0");
@@ -1154,41 +1154,24 @@ itertools_tee_impl(PyObject *module, PyObject *iterable, Py_ssize_t n)
         return NULL;
     }
 
-    if (_PyObject_LookupAttr(it, &_Py_ID(__copy__), &copyfunc) < 0) {
-        Py_DECREF(it);
+    (void)&_Py_ID(__copy__); // Retain a reference to __copy__
+    itertools_state *state = get_module_state(module);
+    to = tee_fromiterable(state, it);
+    Py_DECREF(it);
+    if (to == NULL) {
         Py_DECREF(result);
         return NULL;
     }
-    if (copyfunc != NULL) {
-        copyable = it;
-    }
-    else {
-        itertools_state *state = get_module_state(module);
-        copyable = tee_fromiterable(state, it);
-        Py_DECREF(it);
-        if (copyable == NULL) {
-            Py_DECREF(result);
-            return NULL;
-        }
-        copyfunc = PyObject_GetAttr(copyable, &_Py_ID(__copy__));
-        if (copyfunc == NULL) {
-            Py_DECREF(copyable);
-            Py_DECREF(result);
-            return NULL;
-        }
-    }
 
-    PyTuple_SET_ITEM(result, 0, copyable);
+    PyTuple_SET_ITEM(result, 0, to);
     for (i = 1; i < n; i++) {
-        copyable = _PyObject_CallNoArgs(copyfunc);
-        if (copyable == NULL) {
-            Py_DECREF(copyfunc);
+        to = tee_copy((teeobject *)to, NULL);
+        if (to == NULL) {
             Py_DECREF(result);
             return NULL;
         }
-        PyTuple_SET_ITEM(result, i, copyable);
+        PyTuple_SET_ITEM(result, i, to);
     }
-    Py_DECREF(copyfunc);
     return result;
 }