]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Update tests for the itertools docs rough equivalents (#120509)
authorRaymond Hettinger <rhettinger@users.noreply.github.com>
Fri, 14 Jun 2024 16:00:46 +0000 (11:00 -0500)
committerGitHub <noreply@github.com>
Fri, 14 Jun 2024 16:00:46 +0000 (11:00 -0500)
Lib/test/test_itertools.py

index 53b8064c3cfe82012545848e37d5a7edd3860571..5fd6ecf37427f7d0e03ad34fea30834cad6dbd68 100644 (file)
@@ -1587,27 +1587,169 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
                 self.assertEqual(r1, r2)
                 self.assertEqual(e1, e2)
 
+
+    def test_groupby_recipe(self):
+
+        # Begin groupby() recipe #######################################
+
+        def groupby(iterable, key=None):
+            # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
+            # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D
+
+            keyfunc = (lambda x: x) if key is None else key
+            iterator = iter(iterable)
+            exhausted = False
+
+            def _grouper(target_key):
+                nonlocal curr_value, curr_key, exhausted
+                yield curr_value
+                for curr_value in iterator:
+                    curr_key = keyfunc(curr_value)
+                    if curr_key != target_key:
+                        return
+                    yield curr_value
+                exhausted = True
+
+            try:
+                curr_value = next(iterator)
+            except StopIteration:
+                return
+            curr_key = keyfunc(curr_value)
+
+            while not exhausted:
+                target_key = curr_key
+                curr_group = _grouper(target_key)
+                yield curr_key, curr_group
+                if curr_key == target_key:
+                    for _ in curr_group:
+                        pass
+
+        # End groupby() recipe #########################################
+
+        # Check whether it accepts arguments correctly
+        self.assertEqual([], list(groupby([])))
+        self.assertEqual([], list(groupby([], key=id)))
+        self.assertRaises(TypeError, list, groupby('abc', []))
+        if False:
+            # Test not applicable to the recipe
+            self.assertRaises(TypeError, list, groupby('abc', None))
+        self.assertRaises(TypeError, groupby, 'abc', lambda x:x, 10)
+
+        # Check normal input
+        s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22),
+             (2,15,22), (3,16,23), (3,17,23)]
+        dup = []
+        for k, g in groupby(s, lambda r:r[0]):
+            for elem in g:
+                self.assertEqual(k, elem[0])
+                dup.append(elem)
+        self.assertEqual(s, dup)
+
+        # Check nested case
+        dup = []
+        for k, g in groupby(s, testR):
+            for ik, ig in groupby(g, testR2):
+                for elem in ig:
+                    self.assertEqual(k, elem[0])
+                    self.assertEqual(ik, elem[2])
+                    dup.append(elem)
+        self.assertEqual(s, dup)
+
+        # Check case where inner iterator is not used
+        keys = [k for k, g in groupby(s, testR)]
+        expectedkeys = set([r[0] for r in s])
+        self.assertEqual(set(keys), expectedkeys)
+        self.assertEqual(len(keys), len(expectedkeys))
+
+        # Check case where inner iterator is used after advancing the groupby
+        # iterator
+        s = list(zip('AABBBAAAA', range(9)))
+        it = groupby(s, testR)
+        _, g1 = next(it)
+        _, g2 = next(it)
+        _, g3 = next(it)
+        self.assertEqual(list(g1), [])
+        self.assertEqual(list(g2), [])
+        self.assertEqual(next(g3), ('A', 5))
+        list(it)  # exhaust the groupby iterator
+        self.assertEqual(list(g3), [])
+
+        # Exercise pipes and filters style
+        s = 'abracadabra'
+        # sort s | uniq
+        r = [k for k, g in groupby(sorted(s))]
+        self.assertEqual(r, ['a', 'b', 'c', 'd', 'r'])
+        # sort s | uniq -d
+        r = [k for k, g in groupby(sorted(s)) if list(islice(g,1,2))]
+        self.assertEqual(r, ['a', 'b', 'r'])
+        # sort s | uniq -c
+        r = [(len(list(g)), k) for k, g in groupby(sorted(s))]
+        self.assertEqual(r, [(5, 'a'), (2, 'b'), (1, 'c'), (1, 'd'), (2, 'r')])
+        # sort s | uniq -c | sort -rn | head -3
+        r = sorted([(len(list(g)) , k) for k, g in groupby(sorted(s))], reverse=True)[:3]
+        self.assertEqual(r, [(5, 'a'), (2, 'r'), (2, 'b')])
+
+        # iter.__next__ failure
+        class ExpectedError(Exception):
+            pass
+        def delayed_raise(n=0):
+            for i in range(n):
+                yield 'yo'
+            raise ExpectedError
+        def gulp(iterable, keyp=None, func=list):
+            return [func(g) for k, g in groupby(iterable, keyp)]
+
+        # iter.__next__ failure on outer object
+        self.assertRaises(ExpectedError, gulp, delayed_raise(0))
+        # iter.__next__ failure on inner object
+        self.assertRaises(ExpectedError, gulp, delayed_raise(1))
+
+        # __eq__ failure
+        class DummyCmp:
+            def __eq__(self, dst):
+                raise ExpectedError
+        s = [DummyCmp(), DummyCmp(), None]
+
+        # __eq__ failure on outer object
+        self.assertRaises(ExpectedError, gulp, s, func=id)
+        # __eq__ failure on inner object
+        self.assertRaises(ExpectedError, gulp, s)
+
+        # keyfunc failure
+        def keyfunc(obj):
+            if keyfunc.skip > 0:
+                keyfunc.skip -= 1
+                return obj
+            else:
+                raise ExpectedError
+
+        # keyfunc failure on outer object
+        keyfunc.skip = 0
+        self.assertRaises(ExpectedError, gulp, [None], keyfunc)
+        keyfunc.skip = 1
+        self.assertRaises(ExpectedError, gulp, [None, None], keyfunc)
+
+
     @staticmethod
     def islice(iterable, *args):
+        # islice('ABCDEFG', 2) → A B
+        # islice('ABCDEFG', 2, 4) → C D
+        # islice('ABCDEFG', 2, None) → C D E F G
+        # islice('ABCDEFG', 0, None, 2) → A C E G
+
         s = slice(*args)
-        start, stop, step = s.start or 0, s.stop or sys.maxsize, s.step or 1
-        it = iter(range(start, stop, step))
-        try:
-            nexti = next(it)
-        except StopIteration:
-            # Consume *iterable* up to the *start* position.
-            for i, element in zip(range(start), iterable):
-                pass
-            return
-        try:
-            for i, element in enumerate(iterable):
-                if i == nexti:
-                    yield element
-                    nexti = next(it)
-        except StopIteration:
-            # Consume to *stop*.
-            for i, element in zip(range(i + 1, stop), iterable):
-                pass
+        start = 0 if s.start is None else s.start
+        stop = s.stop
+        step = 1 if s.step is None else s.step
+        if start < 0 or (stop is not None and stop < 0) or step <= 0:
+            raise ValueError
+
+        indices = count() if stop is None else range(max(start, stop))
+        next_i = start
+        for i, element in zip(indices, iterable):
+            if i == next_i:
+                yield element
+                next_i += step
 
     def test_islice_recipe(self):
         self.assertEqual(list(self.islice('ABCDEFG', 2)), list('AB'))
@@ -1627,6 +1769,161 @@ class TestPurePythonRoughEquivalents(unittest.TestCase):
         self.assertEqual(next(c), 3)
 
 
+    def test_tee_recipe(self):
+
+        # Begin tee() recipe ###########################################
+
+        def tee(iterable, n=2):
+            iterator = iter(iterable)
+            shared_link = [None, None]
+            return tuple(_tee(iterator, shared_link) for _ in range(n))
+
+        def _tee(iterator, link):
+            try:
+                while True:
+                    if link[1] is None:
+                        link[0] = next(iterator)
+                        link[1] = [None, None]
+                    value, link = link
+                    yield value
+            except StopIteration:
+                return
+
+        # End tee() recipe #############################################
+
+        n = 200
+
+        a, b = tee([])        # test empty iterator
+        self.assertEqual(list(a), [])
+        self.assertEqual(list(b), [])
+
+        a, b = tee(irange(n)) # test 100% interleaved
+        self.assertEqual(lzip(a,b), lzip(range(n), range(n)))
+
+        a, b = tee(irange(n)) # test 0% interleaved
+        self.assertEqual(list(a), list(range(n)))
+        self.assertEqual(list(b), list(range(n)))
+
+        a, b = tee(irange(n)) # test dealloc of leading iterator
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        del a
+        self.assertEqual(list(b), list(range(n)))
+
+        a, b = tee(irange(n)) # test dealloc of trailing iterator
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        del b
+        self.assertEqual(list(a), list(range(100, n)))
+
+        for j in range(5):   # test randomly interleaved
+            order = [0]*n + [1]*n
+            random.shuffle(order)
+            lists = ([], [])
+            its = tee(irange(n))
+            for i in order:
+                value = next(its[i])
+                lists[i].append(value)
+            self.assertEqual(lists[0], list(range(n)))
+            self.assertEqual(lists[1], list(range(n)))
+
+        # test argument format checking
+        self.assertRaises(TypeError, tee)
+        self.assertRaises(TypeError, tee, 3)
+        self.assertRaises(TypeError, tee, [1,2], 'x')
+        self.assertRaises(TypeError, tee, [1,2], 3, 'x')
+
+        # Tests not applicable to the tee() recipe
+        if False:
+            # tee object should be instantiable
+            a, b = tee('abc')
+            c = type(a)('def')
+            self.assertEqual(list(c), list('def'))
+
+        # test long-lagged and multi-way split
+        a, b, c = tee(range(2000), 3)
+        for i in range(100):
+            self.assertEqual(next(a), i)
+        self.assertEqual(list(b), list(range(2000)))
+        self.assertEqual([next(c), next(c)], list(range(2)))
+        self.assertEqual(list(a), list(range(100,2000)))
+        self.assertEqual(list(c), list(range(2,2000)))
+
+        # Tests not applicable to the tee() recipe
+        if False:
+            # test invalid values of n
+            self.assertRaises(TypeError, tee, 'abc', 'invalid')
+            self.assertRaises(ValueError, tee, [], -1)
+
+        for n in range(5):
+            result = tee('abc', n)
+            self.assertEqual(type(result), tuple)
+            self.assertEqual(len(result), n)
+            self.assertEqual([list(x) for x in result], [list('abc')]*n)
+
+
+        # Tests not applicable to the tee() recipe
+        if False:
+            # tee pass-through to copyable iterator
+            a, b = tee('abc')
+            c, d = tee(a)
+            self.assertTrue(a is c)
+
+            # test tee_new
+            t1, t2 = tee('abc')
+            tnew = type(t1)
+            self.assertRaises(TypeError, tnew)
+            self.assertRaises(TypeError, tnew, 10)
+            t3 = tnew(t1)
+            self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
+
+        # test that tee objects are weak referencable
+        a, b = tee(range(10))
+        p = weakref.proxy(a)
+        self.assertEqual(getattr(p, '__class__'), type(b))
+        del a
+        gc.collect()  # For PyPy or other GCs.
+        self.assertRaises(ReferenceError, getattr, p, '__class__')
+
+        ans = list('abc')
+        long_ans = list(range(10000))
+
+        # Tests not applicable to the tee() recipe
+        if False:
+            # check copy
+            a, b = tee('abc')
+            self.assertEqual(list(copy.copy(a)), ans)
+            self.assertEqual(list(copy.copy(b)), ans)
+            a, b = tee(list(range(10000)))
+            self.assertEqual(list(copy.copy(a)), long_ans)
+            self.assertEqual(list(copy.copy(b)), long_ans)
+
+            # check partially consumed copy
+            a, b = tee('abc')
+            take(2, a)
+            take(1, b)
+            self.assertEqual(list(copy.copy(a)), ans[2:])
+            self.assertEqual(list(copy.copy(b)), ans[1:])
+            self.assertEqual(list(a), ans[2:])
+            self.assertEqual(list(b), ans[1:])
+            a, b = tee(range(10000))
+            take(100, a)
+            take(60, b)
+            self.assertEqual(list(copy.copy(a)), long_ans[100:])
+            self.assertEqual(list(copy.copy(b)), long_ans[60:])
+            self.assertEqual(list(a), long_ans[100:])
+            self.assertEqual(list(b), long_ans[60:])
+
+        # Issue 13454: Crash when deleting backward iterator from tee()
+        forward, backward = tee(repeat(None, 2000)) # 20000000
+        try:
+            any(forward)  # exhaust the iterator
+            del backward
+        except:
+            del forward, backward
+            raise
+
+
 class TestGC(unittest.TestCase):
 
     def makecycle(self, iterator, container):