]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-109786: Fix leaks and crash when re-enter itertools.pairwise.__next__() (GH-109788)
authorSerhiy Storchaka <storchaka@gmail.com>
Mon, 4 Dec 2023 11:47:55 +0000 (13:47 +0200)
committerGitHub <noreply@github.com>
Mon, 4 Dec 2023 11:47:55 +0000 (11:47 +0000)
Lib/test/test_itertools.py
Misc/NEWS.d/next/Library/2023-09-23-14-40-51.gh-issue-109786.UX3pKv.rst [new file with mode: 0644]
Modules/itertoolsmodule.c

index 512745e45350d1094b0c6a08e7ae8ea5e31df3a9..705e880d98685e5b3523c1daabacecd4cf239e74 100644 (file)
@@ -1152,6 +1152,78 @@ class TestBasicOps(unittest.TestCase):
         with self.assertRaises(TypeError):
             pairwise(None)                                  # non-iterable argument
 
+    def test_pairwise_reenter(self):
+        def check(reenter_at, expected):
+            class I:
+                count = 0
+                def __iter__(self):
+                    return self
+                def __next__(self):
+                    self.count +=1
+                    if self.count in reenter_at:
+                        return next(it)
+                    return [self.count]  # new object
+
+            it = pairwise(I())
+            for item in expected:
+                self.assertEqual(next(it), item)
+
+        check({1}, [
+            (([2], [3]), [4]),
+            ([4], [5]),
+        ])
+        check({2}, [
+            ([1], ([1], [3])),
+            (([1], [3]), [4]),
+            ([4], [5]),
+        ])
+        check({3}, [
+            ([1], [2]),
+            ([2], ([2], [4])),
+            (([2], [4]), [5]),
+            ([5], [6]),
+        ])
+        check({1, 2}, [
+            ((([3], [4]), [5]), [6]),
+            ([6], [7]),
+        ])
+        check({1, 3}, [
+            (([2], ([2], [4])), [5]),
+            ([5], [6]),
+        ])
+        check({1, 4}, [
+            (([2], [3]), (([2], [3]), [5])),
+            ((([2], [3]), [5]), [6]),
+            ([6], [7]),
+        ])
+        check({2, 3}, [
+            ([1], ([1], ([1], [4]))),
+            (([1], ([1], [4])), [5]),
+            ([5], [6]),
+        ])
+
+    def test_pairwise_reenter2(self):
+        def check(maxcount, expected):
+            class I:
+                count = 0
+                def __iter__(self):
+                    return self
+                def __next__(self):
+                    if self.count >= maxcount:
+                        raise StopIteration
+                    self.count +=1
+                    if self.count == 1:
+                        return next(it, None)
+                    return [self.count]  # new object
+
+            it = pairwise(I())
+            self.assertEqual(list(it), expected)
+
+        check(1, [])
+        check(2, [])
+        check(3, [])
+        check(4, [(([2], [3]), [4])])
+
     def test_product(self):
         for args, result in [
             ([], [()]),                     # zero iterables
diff --git a/Misc/NEWS.d/next/Library/2023-09-23-14-40-51.gh-issue-109786.UX3pKv.rst b/Misc/NEWS.d/next/Library/2023-09-23-14-40-51.gh-issue-109786.UX3pKv.rst
new file mode 100644 (file)
index 0000000..07222fa
--- /dev/null
@@ -0,0 +1,2 @@
+Fix possible reference leaks and crash when re-enter the ``__next__()`` method of
+:class:`itertools.pairwise`.
index 4ed34aa5bde827bf293f09e9b313ecb91589998a..ab99fa4d873bf510eb377bf5615303ec16340c31 100644 (file)
@@ -330,21 +330,30 @@ pairwise_next(pairwiseobject *po)
         return NULL;
     }
     if (old == NULL) {
-        po->old = old = (*Py_TYPE(it)->tp_iternext)(it);
+        old = (*Py_TYPE(it)->tp_iternext)(it);
+        Py_XSETREF(po->old, old);
         if (old == NULL) {
             Py_CLEAR(po->it);
             return NULL;
         }
+        it = po->it;
+        if (it == NULL) {
+            Py_CLEAR(po->old);
+            return NULL;
+        }
     }
+    Py_INCREF(old);
     new = (*Py_TYPE(it)->tp_iternext)(it);
     if (new == NULL) {
         Py_CLEAR(po->it);
         Py_CLEAR(po->old);
+        Py_DECREF(old);
         return NULL;
     }
     /* Future optimization: Reuse the result tuple as we do in enumerate() */
     result = PyTuple_Pack(2, old, new);
-    Py_SETREF(po->old, new);
+    Py_XSETREF(po->old, new);
+    Py_DECREF(old);
     return result;
 }