git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-117657: Fix itertools.count thread safety (#119268)
author Arnon Yaari <wiggin15@yahoo.com>
Tue, 21 May 2024 17:16:34 +0000 (13:16 -0400)
committer GitHub <noreply@github.com>
Tue, 21 May 2024 17:16:34 +0000 (10:16 -0700)
Fix itertools.count in free-threading mode

Lib/test/test_itertools.py
Modules/itertoolsmodule.c
Tools/tsan/suppressions_free_threading.txt

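diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 4d2c01886724a80c07d62aa2ebbb9c913f985996..53b8064c3cfe82012545848e37d5a7edd3860571 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py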
@@ -546,7 +546,7 @@ class TestBasicOps(unittest.TestCase):
         #check proper internal error handling for large "step' sizes
         count(1, maxsize+5); sys.exc_info()
 
-    def test_count_with_stride(self):
+    def test_count_with_step(self):
         self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
         self.assertEqual(lzip('abc',count(start=2,step=3)),
                          [('a', 2), ('b', 5), ('c', 8)])
@@ -590,6 +590,28 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(type(next(c)), int)
         self.assertEqual(type(next(c)), float)
 
+    @threading_helper.requires_working_threading()
+    def test_count_threading(self, step=1):
+        # this test verifies multithreading consistency, which is
+        # mostly for testing builds without GIL, but nice to test anyway
+        count_to = 10_000
+        num_threads = 10
+        c = count(step=step)
+        def counting_thread():
+            for i in range(count_to):
+                next(c)
+        threads = []
+        for i in range(num_threads):
+            thread = threading.Thread(target=counting_thread)
+            thread.start()
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+        self.assertEqual(next(c), count_to * num_threads * step)
+
+    def test_count_with_step_threading(self):
+        self.test_count_threading(step=5)
+
     def test_cycle(self):
         self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
         self.assertEqual(list(cycle('')), [])
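diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index ae316d9e369d458fa93dcced2b0a44baf711c1fe..e740ec4d7625c3bc555a3f2c9f41f4dc33994531 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c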
@@ -1,13 +1,14 @@
 #include "Python.h"
-#include "pycore_call.h"          // _PyObject_CallNoArgs()
-#include "pycore_ceval.h"         // _PyEval_GetBuiltin()
-#include "pycore_long.h"          // _PyLong_GetZero()
-#include "pycore_moduleobject.h"  // _PyModule_GetState()
-#include "pycore_typeobject.h"    // _PyType_GetModuleState()
-#include "pycore_object.h"        // _PyObject_GC_TRACK()
-#include "pycore_tuple.h"         // _PyTuple_ITEMS()
+#include "pycore_call.h"              // _PyObject_CallNoArgs()
+#include "pycore_ceval.h"             // _PyEval_GetBuiltin()
+#include "pycore_critical_section.h"  // Py_BEGIN_CRITICAL_SECTION()
+#include "pycore_long.h"              // _PyLong_GetZero()
+#include "pycore_moduleobject.h"      // _PyModule_GetState()
+#include "pycore_typeobject.h"        // _PyType_GetModuleState()
+#include "pycore_object.h"            // _PyObject_GC_TRACK()
+#include "pycore_tuple.h"             // _PyTuple_ITEMS()
 
-#include <stddef.h>               // offsetof()
+#include <stddef.h>                   // offsetof()
 
 /* Itertools module written and maintained
    by Raymond D. Hettinger <python@rcn.com>
@@ -3254,7 +3255,7 @@ fast_mode:  when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
 
     assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
     Advances with:  cnt += 1
-    When count hits Y_SSIZE_T_MAX, switch to slow_mode.
+    When count hits PY_SSIZE_T_MAX, switch to slow_mode.
 
 slow_mode:  when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
 
@@ -3403,9 +3404,30 @@ count_nextlong(countobject *lz)
 static PyObject *
 count_next(countobject *lz)
 {
+#ifndef Py_GIL_DISABLED
     if (lz->cnt == PY_SSIZE_T_MAX)
         return count_nextlong(lz);
     return PyLong_FromSsize_t(lz->cnt++);
+#else
+    // free-threading version
+    // fast mode uses compare-exchange loop
+    // slow mode uses a critical section
+    PyObject *returned;
+    Py_ssize_t cnt;
+
+    cnt = _Py_atomic_load_ssize_relaxed(&lz->cnt);
+    for (;;) {
+        if (cnt == PY_SSIZE_T_MAX) {
+            Py_BEGIN_CRITICAL_SECTION(lz);
+            returned = count_nextlong(lz);
+            Py_END_CRITICAL_SECTION();
+            return returned;
+        }
+        if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
+            return PyLong_FromSsize_t(cnt);
+        }
+    }
+#endif
 }
 
 static PyObject *
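diff --git a/Tools/tsan/suppressions_free_threading.txt b/Tools/tsan/suppressions_free_threading.txt
index dfa4a1fe9ca43819e978b1e9afdb4038fb4aa34f..cda57d78067bb32da615c37332b57058bba5765b 100644
--- a/Tools/tsan/suppressions_free_threading.txt
+++ b/Tools/tsan/suppressions_free_threading.txt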
@@ -56,7 +56,6 @@ race_top:_Py_dict_lookup_threadsafe
 race_top:_imp_release_lock
 race_top:_multiprocessing_SemLock_acquire_impl
 race_top:builtin_compile_impl
-race_top:count_next
 race_top:dictiter_new
 race_top:dictresize
 race_top:insert_to_emptydict