]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.13] gh-117657: Fix itertools.count thread safety (GH-119268) (#120007)
authorSam Gross <colesbury@gmail.com>
Mon, 3 Jun 2024 22:47:34 +0000 (18:47 -0400)
committerGitHub <noreply@github.com>
Mon, 3 Jun 2024 22:47:34 +0000 (22:47 +0000)
Fix itertools.count in free-threading mode
(cherry picked from commit 87939bd5790accea77c5a81093f16f28d3f0b429)

Co-authored-by: Arnon Yaari <wiggin15@yahoo.com>
Lib/test/test_itertools.py
Modules/itertoolsmodule.c
Tools/tsan/suppressions_free_threading.txt

index e243da309f03d8b6264a17491c42a2badcfa3d60..2c92d880c10cb31f3b9f8cc94918e749e711a4a1 100644 (file)
@@ -644,7 +644,7 @@ class TestBasicOps(unittest.TestCase):
         count(1, maxsize+5); sys.exc_info()
 
     @pickle_deprecated
-    def test_count_with_stride(self):
+    def test_count_with_step(self):
         self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
         self.assertEqual(lzip('abc',count(start=2,step=3)),
                          [('a', 2), ('b', 5), ('c', 8)])
@@ -699,6 +699,28 @@ class TestBasicOps(unittest.TestCase):
                 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                     self.pickletest(proto, count(i, j))
 
+    @threading_helper.requires_working_threading()
+    def test_count_threading(self, step=1):
+        # this test verifies multithreading consistency, which is
+        # mostly for testing builds without GIL, but nice to test anyway
+        count_to = 10_000
+        num_threads = 10
+        c = count(step=step)
+        def counting_thread():
+            for i in range(count_to):
+                next(c)
+        threads = []
+        for i in range(num_threads):
+            thread = threading.Thread(target=counting_thread)
+            thread.start()
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+        self.assertEqual(next(c), count_to * num_threads * step)
+
+    def test_count_with_step_threading(self):
+        self.test_count_threading(step=5)
+
     def test_cycle(self):
         self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
         self.assertEqual(list(cycle('')), [])
index 8641c2f87e6db202a9f37fcc4d5044be3afb6f74..0d6ff20489aa2c18905b3186ed3f762171174f9f 100644 (file)
@@ -1,13 +1,14 @@
 #include "Python.h"
-#include "pycore_call.h"          // _PyObject_CallNoArgs()
-#include "pycore_ceval.h"         // _PyEval_GetBuiltin()
-#include "pycore_long.h"          // _PyLong_GetZero()
-#include "pycore_moduleobject.h"  // _PyModule_GetState()
-#include "pycore_typeobject.h"    // _PyType_GetModuleState()
-#include "pycore_object.h"        // _PyObject_GC_TRACK()
-#include "pycore_tuple.h"         // _PyTuple_ITEMS()
+#include "pycore_call.h"              // _PyObject_CallNoArgs()
+#include "pycore_ceval.h"             // _PyEval_GetBuiltin()
+#include "pycore_critical_section.h"  // Py_BEGIN_CRITICAL_SECTION()
+#include "pycore_long.h"              // _PyLong_GetZero()
+#include "pycore_moduleobject.h"      // _PyModule_GetState()
+#include "pycore_typeobject.h"        // _PyType_GetModuleState()
+#include "pycore_object.h"            // _PyObject_GC_TRACK()
+#include "pycore_tuple.h"             // _PyTuple_ITEMS()
 
-#include <stddef.h>               // offsetof()
+#include <stddef.h>                   // offsetof()
 
 /* Itertools module written and maintained
    by Raymond D. Hettinger <python@rcn.com>
@@ -4037,7 +4038,7 @@ fast_mode:  when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
 
     assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
     Advances with:  cnt += 1
-    When count hits Y_SSIZE_T_MAX, switch to slow_mode.
+    When count hits PY_SSIZE_T_MAX, switch to slow_mode.
 
 slow_mode:  when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
 
@@ -4186,9 +4187,30 @@ count_nextlong(countobject *lz)
 static PyObject *
 count_next(countobject *lz)
 {
+#ifndef Py_GIL_DISABLED
     if (lz->cnt == PY_SSIZE_T_MAX)
         return count_nextlong(lz);
     return PyLong_FromSsize_t(lz->cnt++);
+#else
+    // free-threading version
+    // fast mode uses compare-exchange loop
+    // slow mode uses a critical section
+    PyObject *returned;
+    Py_ssize_t cnt;
+
+    cnt = _Py_atomic_load_ssize_relaxed(&lz->cnt);
+    for (;;) {
+        if (cnt == PY_SSIZE_T_MAX) {
+            Py_BEGIN_CRITICAL_SECTION(lz);
+            returned = count_nextlong(lz);
+            Py_END_CRITICAL_SECTION();
+            return returned;
+        }
+        if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
+            return PyLong_FromSsize_t(cnt);
+        }
+    }
+#endif
 }
 
 static PyObject *
index 0bb0183147b72215f954f3d1b1ff4839ad2d3ea1..d5fcac61f0db042c66d9251af2c264ee0dbde708 100644 (file)
@@ -47,7 +47,6 @@ race_top:_PyImport_AcquireLock
 race_top:_Py_dict_lookup_threadsafe
 race_top:_imp_release_lock
 race_top:_multiprocessing_SemLock_acquire_impl
-race_top:count_next
 race_top:dictiter_new
 race_top:dictresize
 race_top:insert_to_emptydict