]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.13] Small improvements to the itertools docs (GH-123885) (#125075)
authorRaymond Hettinger <rhettinger@users.noreply.github.com>
Mon, 7 Oct 2024 23:08:09 +0000 (18:08 -0500)
committerGitHub <noreply@github.com>
Mon, 7 Oct 2024 23:08:09 +0000 (23:08 +0000)
Doc/library/itertools.rst
Lib/test/test_itertools.py

index 43e665c3f0d5e377c70b32642f1a2a4e23e39053..ceaab2f3e724541a1737d312f54ed6df52191760 100644 (file)
@@ -474,7 +474,7 @@ loops that truncate the stream.
    If *start* is zero or ``None``, iteration starts at zero.  Otherwise,
    elements from the iterable are skipped until *start* is reached.
 
-   If *stop* is ``None``, iteration continues until the iterator is
+   If *stop* is ``None``, iteration continues until the iterable is
    exhausted, if at all.  Otherwise, it stops at the specified position.
 
    If *step* is ``None``, the step defaults to one.  Elements are returned
@@ -503,6 +503,10 @@ loops that truncate the stream.
                   yield element
                   next_i += step
 
+   If the input is an iterator, then fully consuming the *islice*
+   advances the input iterator by ``max(start, stop)`` steps regardless
+   of the *step* value.
+
 
 .. function:: pairwise(iterable)
 
@@ -601,6 +605,8 @@ loops that truncate the stream.
            # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
            # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111
 
+           if repeat < 0:
+               raise ValueError('repeat argument cannot be negative')
            pools = [tuple(pool) for pool in iterables] * repeat
 
            result = [[]]
@@ -684,6 +690,8 @@ loops that truncate the stream.
    Roughly equivalent to::
 
         def tee(iterable, n=2):
+            if n < 0:
+                raise ValueError('n must be >= 0')
             iterator = iter(iterable)
             shared_link = [None, None]
             return tuple(_tee(iterator, shared_link) for _ in range(n))
@@ -703,6 +711,12 @@ loops that truncate the stream.
    used anywhere else; otherwise, the *iterable* could get advanced without
    the tee objects being informed.
 
+   When the input *iterable* is already a tee iterator object, all
+   members of the return tuple are constructed as if they had been
+   produced by the upstream :func:`tee` call.  This "flattening step"
+   allows nested :func:`tee` calls to share the same underlying data
+   chain and to have a single update step rather than a chain of calls.
+
    ``tee`` iterators are not threadsafe. A :exc:`RuntimeError` may be
    raised when simultaneously using iterators returned by the same :func:`tee`
    call, even if the original *iterable* is threadsafe.
index 2c92d880c10cb31f3b9f8cc94918e749e711a4a1..2c8752d215dc696e4d443e7b5929bcd88b944d95 100644 (file)
@@ -1288,12 +1288,16 @@ class TestBasicOps(unittest.TestCase):
                 else:
                     return
 
-        def product2(*args, **kwds):
+        def product2(*iterables, repeat=1):
             'Pure python version used in docs'
-            pools = list(map(tuple, args)) * kwds.get('repeat', 1)
+            if repeat < 0:
+                raise ValueError('repeat argument cannot be negative')
+            pools = [tuple(pool) for pool in iterables] * repeat
+
             result = [[]]
             for pool in pools:
                 result = [x+[y] for x in result for y in pool]
+
             for prod in result:
                 yield tuple(prod)
 
@@ -2062,6 +2066,161 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
         self.assertEqual(next(c), 3)
 
 
+    def test_tee_recipe(self):
+
+        # Begin tee() recipe ###########################################
+
+        def tee(iterable, n=2):
+            if n < 0:
+                raise ValueError('n must be >= 0')
+            iterator = iter(iterable)
+            shared_link = [None, None]
+            return tuple(_tee(iterator, shared_link) for _ in range(n))
+
+        def _tee(iterator, link):
+            try:
+                while True:
+                    if link[1] is None:
+                        link[0] = next(iterator)
+                        link[1] = [None, None]
+                    value, link = link
+                    yield value
+            except StopIteration:
+                return
+
+        # End tee() recipe #############################################
+
+        n = 200
+
+        a, b = tee([])        # test empty iterator
+        self.assertEqual(list(a), [])
+        self.assertEqual(list(b), [])
+
+        a, b = tee(irange(n)) # test 100% interleaved
+        self.assertEqual(lzip(a,b), lzip(range(n), range(n)))
+
+        a, b = tee(irange(n)) # test 0% interleaved
+        self.assertEqual(list(a), list(range(n)))
+        self.assertEqual(list(b), list(range(n)))
+
+        a, b = tee(irange(n)) # test dealloc of leading iterator
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        del a
+        self.assertEqual(list(b), list(range(n)))
+
+        a, b = tee(irange(n)) # test dealloc of trailing iterator
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        del b
+        self.assertEqual(list(a), list(range(100, n)))
+
+        for j in range(5):   # test randomly interleaved
+            order = [0]*n + [1]*n
+            random.shuffle(order)
+            lists = ([], [])
+            its = tee(irange(n))
+            for i in order:
+                value = next(its[i])
+                lists[i].append(value)
+            self.assertEqual(lists[0], list(range(n)))
+            self.assertEqual(lists[1], list(range(n)))
+
+        # test argument format checking
+        self.assertRaises(TypeError, tee)
+        self.assertRaises(TypeError, tee, 3)
+        self.assertRaises(TypeError, tee, [1,2], 'x')
+        self.assertRaises(TypeError, tee, [1,2], 3, 'x')
+
+        # Tests not applicable to the tee() recipe
+        if False:
+            # tee object should be instantiable
+            a, b = tee('abc')
+            c = type(a)('def')
+            self.assertEqual(list(c), list('def'))
+
+        # test long-lagged and multi-way split
+        a, b, c = tee(range(2000), 3)
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        self.assertEqual(list(b), list(range(2000)))
+        self.assertEqual([next(c), next(c)], list(range(2)))
+        self.assertEqual(list(a), list(range(100,2000)))
+        self.assertEqual(list(c), list(range(2,2000)))
+
+        # test invalid values of n
+        self.assertRaises(TypeError, tee, 'abc', 'invalid')
+        self.assertRaises(ValueError, tee, [], -1)
+
+        for n in range(5):
+            result = tee('abc', n)
+            self.assertEqual(type(result), tuple)
+            self.assertEqual(len(result), n)
+            self.assertEqual([list(x) for x in result], [list('abc')]*n)
+
+
+        # Tests not applicable to the tee() recipe
+        if False:
+            # tee pass-through to copyable iterator
+            a, b = tee('abc')
+            c, d = tee(a)
+            self.assertTrue(a is c)
+
+            # test tee_new
+            t1, t2 = tee('abc')
+            tnew = type(t1)
+            self.assertRaises(TypeError, tnew)
+            self.assertRaises(TypeError, tnew, 10)
+            t3 = tnew(t1)
+            self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
+
+        # test that tee objects are weak referencable
+        a, b = tee(range(10))
+        p = weakref.proxy(a)
+        self.assertEqual(getattr(p, '__class__'), type(b))
+        del a
+        gc.collect()  # For PyPy or other GCs.
+        self.assertRaises(ReferenceError, getattr, p, '__class__')
+
+        ans = list('abc')
+        long_ans = list(range(10000))
+
+        # Tests not applicable to the tee() recipe
+        if False:
+            # check copy
+            a, b = tee('abc')
+            self.assertEqual(list(copy.copy(a)), ans)
+            self.assertEqual(list(copy.copy(b)), ans)
+            a, b = tee(list(range(10000)))
+            self.assertEqual(list(copy.copy(a)), long_ans)
+            self.assertEqual(list(copy.copy(b)), long_ans)
+
+            # check partially consumed copy
+            a, b = tee('abc')
+            take(2, a)
+            take(1, b)
+            self.assertEqual(list(copy.copy(a)), ans[2:])
+            self.assertEqual(list(copy.copy(b)), ans[1:])
+            self.assertEqual(list(a), ans[2:])
+            self.assertEqual(list(b), ans[1:])
+            a, b = tee(range(10000))
+            take(100, a)
+            take(60, b)
+            self.assertEqual(list(copy.copy(a)), long_ans[100:])
+            self.assertEqual(list(copy.copy(b)), long_ans[60:])
+            self.assertEqual(list(a), long_ans[100:])
+            self.assertEqual(list(b), long_ans[60:])
+
+        # Issue 13454: Crash when deleting backward iterator from tee()
+        forward, backward = tee(repeat(None, 2000)) # 20000000
+        try:
+            any(forward)  # exhaust the iterator
+            del backward
+        except:
+            del forward, backward
+            raise
+
+
 class TestGC(unittest.TestCase):
 
     def makecycle(self, iterator, container):