]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-114271: Make `_thread.ThreadHandle` thread-safe in free-threaded builds (GH-115190)
authormpage <mpage@meta.com>
Fri, 1 Mar 2024 21:43:12 +0000 (13:43 -0800)
committerGitHub <noreply@github.com>
Fri, 1 Mar 2024 21:43:12 +0000 (13:43 -0800)
Make `_thread.ThreadHandle` thread-safe in free-threaded builds

We protect the mutable state of `ThreadHandle` using a `_PyOnceFlag`.
Concurrent operations (i.e. `join` or `detach`) on `ThreadHandle` block
until it is their turn to execute or an earlier operation succeeds.
Once an operation has been applied successfully all future operations
complete immediately.

The `join()` method is now idempotent. It may be called multiple times
but the underlying OS thread will only be joined once. After `join()`
succeeds, any future calls to `join()` will succeed immediately.

The internal thread handle `detach()` method has been removed.

Include/internal/pycore_lock.h
Lib/test/test_thread.py
Lib/threading.py
Modules/_threadmodule.c
Python/lock.c

index c89159b55e130f6bc3e9f8defd8c0db45f58a195..f648be496ea4af39c98e53d833d384e975e882ed 100644 (file)
@@ -136,6 +136,10 @@ typedef struct {
     uint8_t v;
 } PyEvent;
 
+// Check if the event is set without blocking. Returns 1 if the event is set or
+// 0 otherwise.
+PyAPI_FUNC(int) _PyEvent_IsSet(PyEvent *evt);
+
 // Set the event and notify any waiting threads.
 // Export for '_testinternalcapi' shared extension
 PyAPI_FUNC(void) _PyEvent_Notify(PyEvent *evt);
@@ -149,6 +153,15 @@ PyAPI_FUNC(void) PyEvent_Wait(PyEvent *evt);
 // and 0 if the timeout expired or thread was interrupted.
 PyAPI_FUNC(int) PyEvent_WaitTimed(PyEvent *evt, PyTime_t timeout_ns);
 
+// A one-time event notification with reference counting.
+typedef struct _PyEventRc {
+    PyEvent event;
+    Py_ssize_t refcount;
+} _PyEventRc;
+
+_PyEventRc *_PyEventRc_New(void);
+void _PyEventRc_Incref(_PyEventRc *erc);
+void _PyEventRc_Decref(_PyEventRc *erc);
 
 // _PyRawMutex implements a word-sized mutex that that does not depend on the
 // parking lot API, and therefore can be used in the parking lot
index 931cb4b797e0b21e80b2cfe5669b1d0872162680..83235230d5c11206f606abb5273765db5eabffcd 100644 (file)
@@ -189,8 +189,8 @@ class ThreadRunningTests(BasicThreadTest):
         with threading_helper.wait_threads_exit():
             handle = thread.start_joinable_thread(task)
             handle.join()
-            with self.assertRaisesRegex(ValueError, "not joinable"):
-                handle.join()
+            # Subsequent join() calls should succeed
+            handle.join()
 
     def test_joinable_not_joined(self):
         handle_destroyed = thread.allocate_lock()
@@ -233,58 +233,61 @@ class ThreadRunningTests(BasicThreadTest):
         with self.assertRaisesRegex(RuntimeError, "Cannot join current thread"):
             raise errors[0]
 
-    def test_detach_from_self(self):
-        errors = []
-        handles = []
-        start_joinable_thread_returned = thread.allocate_lock()
-        start_joinable_thread_returned.acquire()
-        thread_detached = thread.allocate_lock()
-        thread_detached.acquire()
+    def test_join_then_self_join(self):
+        # make sure we can't deadlock in the following scenario with
+        # threads t0 and t1 (see comment in `ThreadHandle_join()` for more
+        # details):
+        #
+        # - t0 joins t1
+        # - t1 self joins
+        def make_lock():
+            lock = thread.allocate_lock()
+            lock.acquire()
+            return lock
+
+        error = None
+        self_joiner_handle = None
+        self_joiner_started = make_lock()
+        self_joiner_barrier = make_lock()
+        def self_joiner():
+            nonlocal error
+
+            self_joiner_started.release()
+            self_joiner_barrier.acquire()
 
-        def task():
-            start_joinable_thread_returned.acquire()
             try:
-                handles[0].detach()
+                self_joiner_handle.join()
             except Exception as e:
-                errors.append(e)
-            finally:
-                thread_detached.release()
+                error = e
+
+        joiner_started = make_lock()
+        def joiner():
+            joiner_started.release()
+            self_joiner_handle.join()
 
         with threading_helper.wait_threads_exit():
-            handle = thread.start_joinable_thread(task)
-            handles.append(handle)
-            start_joinable_thread_returned.release()
-            thread_detached.acquire()
-            with self.assertRaisesRegex(ValueError, "not joinable"):
-                handle.join()
+            self_joiner_handle = thread.start_joinable_thread(self_joiner)
+            # Wait for the self-joining thread to start
+            self_joiner_started.acquire()
 
-        assert len(errors) == 0
+            # Start the thread that joins the self-joiner
+            joiner_handle = thread.start_joinable_thread(joiner)
 
-    def test_detach_then_join(self):
-        lock = thread.allocate_lock()
-        lock.acquire()
+            # Wait for the joiner to start
+            joiner_started.acquire()
 
-        def task():
-            lock.acquire()
+            # Not great, but I don't think there's a deterministic way to make
+            # sure that the self-joining thread has been joined.
+            time.sleep(0.1)
 
-        with threading_helper.wait_threads_exit():
-            handle = thread.start_joinable_thread(task)
-            # detach() returns even though the thread is blocked on lock
-            handle.detach()
-            # join() then cannot be called anymore
-            with self.assertRaisesRegex(ValueError, "not joinable"):
-                handle.join()
-            lock.release()
-
-    def test_join_then_detach(self):
-        def task():
-            pass
+            # Unblock the self-joiner
+            self_joiner_barrier.release()
 
-        with threading_helper.wait_threads_exit():
-            handle = thread.start_joinable_thread(task)
-            handle.join()
-            with self.assertRaisesRegex(ValueError, "not joinable"):
-                handle.detach()
+            self_joiner_handle.join()
+            joiner_handle.join()
+
+            with self.assertRaisesRegex(RuntimeError, "Cannot join current thread"):
+                raise error
 
 
 class Barrier:
index b6ff00acadd58fe7a38dad25e5ec3beeae458c0f..ec89550d6b022ee08bc095201215098ce77fc787 100644 (file)
@@ -931,7 +931,6 @@ class Thread:
         if _HAVE_THREAD_NATIVE_ID:
             self._native_id = None
         self._tstate_lock = None
-        self._join_lock = None
         self._handle = None
         self._started = Event()
         self._is_stopped = False
@@ -956,14 +955,11 @@ class Thread:
             if self._tstate_lock is not None:
                 self._tstate_lock._at_fork_reinit()
                 self._tstate_lock.acquire()
-            if self._join_lock is not None:
-                self._join_lock._at_fork_reinit()
         else:
             # This thread isn't alive after fork: it doesn't have a tstate
             # anymore.
             self._is_stopped = True
             self._tstate_lock = None
-            self._join_lock = None
             self._handle = None
 
     def __repr__(self):
@@ -996,8 +992,6 @@ class Thread:
         if self._started.is_set():
             raise RuntimeError("threads can only be started once")
 
-        self._join_lock = _allocate_lock()
-
         with _active_limbo_lock:
             _limbo[self] = self
         try:
@@ -1167,17 +1161,9 @@ class Thread:
             self._join_os_thread()
 
     def _join_os_thread(self):
-        join_lock = self._join_lock
-        if join_lock is None:
-            return
-        with join_lock:
-            # Calling join() multiple times would raise an exception
-            # in one of the callers.
-            if self._handle is not None:
-                self._handle.join()
-                self._handle = None
-                # No need to keep this around
-                self._join_lock = None
+        # self._handle may be cleared post-fork
+        if self._handle is not None:
+            self._handle.join()
 
     def _wait_for_tstate_lock(self, block=True, timeout=-1):
         # Issue #18808: wait for the thread state to be gone.
@@ -1478,6 +1464,10 @@ class _MainThread(Thread):
         with _active_limbo_lock:
             _active[self._ident] = self
 
+    def _join_os_thread(self):
+        # No ThreadHandle for main thread
+        pass
+
 
 # Helper thread-local instance to detect when a _DummyThread
 # is collected. Not a part of the public API.
index 4c2185cc7ea1fd306b20dc714dfda00e4a510413..3a8f77d6dfbbc6a9c4410db6f780f4b7f1f2a64b 100644 (file)
@@ -1,9 +1,9 @@
-
 /* Thread module */
 /* Interface to Sjoerd's portable C thread library */
 
 #include "Python.h"
 #include "pycore_interp.h"        // _PyInterpreterState.threads.count
+#include "pycore_lock.h"
 #include "pycore_moduleobject.h"  // _PyModule_GetState()
 #include "pycore_modsupport.h"    // _PyArg_NoKeywords()
 #include "pycore_pylifecycle.h"
@@ -44,24 +44,76 @@ get_thread_state(PyObject *module)
 
 // _ThreadHandle type
 
+// Handles transition from RUNNING to one of JOINED, DETACHED, or INVALID (post
+// fork).
+typedef enum {
+    THREAD_HANDLE_RUNNING = 1,
+    THREAD_HANDLE_JOINED = 2,
+    THREAD_HANDLE_DETACHED = 3,
+    THREAD_HANDLE_INVALID = 4,
+} ThreadHandleState;
+
+// A handle around an OS thread.
+//
+// The OS thread is either joined or detached after the handle is destroyed.
+//
+// Joining the handle is idempotent; the underlying OS thread is joined or
+// detached only once. Concurrent join operations are serialized until it is
+// their turn to execute or an earlier operation completes successfully. Once a
+// join has completed successfully all future joins complete immediately.
 typedef struct {
     PyObject_HEAD
     struct llist_node node;  // linked list node (see _pythread_runtime_state)
+
+    // The `ident` and `handle` fields are immutable once the object is visible
+    // to threads other than its creator, thus they do not need to be accessed
+    // atomically.
     PyThread_ident_t ident;
     PyThread_handle_t handle;
-    char joinable;
+
+    // Holds a value from the `ThreadHandleState` enum.
+    int state;
+
+    // Set immediately before `thread_run` returns to indicate that the OS
+    // thread is about to exit. This is used to avoid false positives when
+    // detecting self-join attempts. See the comment in `ThreadHandle_join()`
+    // for a more detailed explanation.
+    _PyEventRc *thread_is_exiting;
+
+    // Serializes calls to `join`.
+    _PyOnceFlag once;
 } ThreadHandleObject;
 
+static inline int
+get_thread_handle_state(ThreadHandleObject *handle)
+{
+    return _Py_atomic_load_int(&handle->state);
+}
+
+static inline void
+set_thread_handle_state(ThreadHandleObject *handle, ThreadHandleState state)
+{
+    _Py_atomic_store_int(&handle->state, state);
+}
+
 static ThreadHandleObject*
 new_thread_handle(thread_module_state* state)
 {
+    _PyEventRc *event = _PyEventRc_New();
+    if (event == NULL) {
+        PyErr_NoMemory();
+        return NULL;
+    }
     ThreadHandleObject* self = PyObject_New(ThreadHandleObject, state->thread_handle_type);
     if (self == NULL) {
+        _PyEventRc_Decref(event);
         return NULL;
     }
     self->ident = 0;
     self->handle = 0;
-    self->joinable = 0;
+    self->thread_is_exiting = event;
+    self->once = (_PyOnceFlag){0};
+    self->state = THREAD_HANDLE_INVALID;
 
     HEAD_LOCK(&_PyRuntime);
     llist_insert_tail(&_PyRuntime.threads.handles, &self->node);
@@ -82,13 +134,21 @@ ThreadHandle_dealloc(ThreadHandleObject *self)
     }
     HEAD_UNLOCK(&_PyRuntime);
 
-    if (self->joinable) {
-        int ret = PyThread_detach_thread(self->handle);
-        if (ret) {
+    // It's safe to access state non-atomically:
+    //   1. This is the destructor; nothing else holds a reference.
+    //   2. The refcount going to zero is a "synchronizes-with" event;
+    //      all changes from other threads are visible.
+    if (self->state == THREAD_HANDLE_RUNNING) {
+        // This is typically short so no need to release the GIL
+        if (PyThread_detach_thread(self->handle)) {
             PyErr_SetString(ThreadError, "Failed detaching thread");
             PyErr_WriteUnraisable(tp);
         }
+        else {
+            self->state = THREAD_HANDLE_DETACHED;
+        }
     }
+    _PyEventRc_Decref(self->thread_is_exiting);
     PyObject_Free(self);
     Py_DECREF(tp);
 }
@@ -109,8 +169,9 @@ _PyThread_AfterFork(struct _pythread_runtime_state *state)
             continue;
         }
 
-        // Disallow calls to detach() and join() as they could crash.
-        hobj->joinable = 0;
+        // Disallow calls to join() as they could crash. We are the only
+        // thread; it's safe to set this without an atomic.
+        hobj->state = THREAD_HANDLE_INVALID;
         llist_remove(node);
     }
 }
@@ -128,48 +189,54 @@ ThreadHandle_get_ident(ThreadHandleObject *self, void *ignored)
     return PyLong_FromUnsignedLongLong(self->ident);
 }
 
-
-static PyObject *
-ThreadHandle_detach(ThreadHandleObject *self, void* ignored)
+static int
+join_thread(ThreadHandleObject *handle)
 {
-    if (!self->joinable) {
-        PyErr_SetString(PyExc_ValueError,
-                        "the thread is not joinable and thus cannot be detached");
-        return NULL;
-    }
-    self->joinable = 0;
-    // This is typically short so no need to release the GIL
-    int ret = PyThread_detach_thread(self->handle);
-    if (ret) {
-        PyErr_SetString(ThreadError, "Failed detaching thread");
-        return NULL;
+    assert(get_thread_handle_state(handle) == THREAD_HANDLE_RUNNING);
+
+    int err;
+    Py_BEGIN_ALLOW_THREADS
+    err = PyThread_join_thread(handle->handle);
+    Py_END_ALLOW_THREADS
+    if (err) {
+        PyErr_SetString(ThreadError, "Failed joining thread");
+        return -1;
     }
-    Py_RETURN_NONE;
+    set_thread_handle_state(handle, THREAD_HANDLE_JOINED);
+    return 0;
 }
 
 static PyObject *
 ThreadHandle_join(ThreadHandleObject *self, void* ignored)
 {
-    if (!self->joinable) {
-        PyErr_SetString(PyExc_ValueError, "the thread is not joinable");
+    if (get_thread_handle_state(self) == THREAD_HANDLE_INVALID) {
+        PyErr_SetString(PyExc_ValueError,
+                        "the handle is invalid and thus cannot be joined");
         return NULL;
     }
-    if (self->ident == PyThread_get_thread_ident_ex()) {
+
+    // We want to perform this check outside of the `_PyOnceFlag` to prevent
+    // deadlock in the scenario where another thread joins us and we then
+    // attempt to join ourselves. However, it's not safe to check thread
+    // identity once the handle's os thread has finished. We may end up reusing
+    // the identity stored in the handle and erroneously think we are
+    // attempting to join ourselves.
+    //
+    // To work around this, we set `thread_is_exiting` immediately before
+    // `thread_run` returns.  We can be sure that we are not attempting to join
+    // ourselves if the handle's thread is about to exit.
+    if (!_PyEvent_IsSet(&self->thread_is_exiting->event) &&
+        self->ident == PyThread_get_thread_ident_ex()) {
         // PyThread_join_thread() would deadlock or error out.
         PyErr_SetString(ThreadError, "Cannot join current thread");
         return NULL;
     }
-    // Before actually joining, we must first mark the thread as non-joinable,
-    // as joining several times simultaneously or sequentially is undefined behavior.
-    self->joinable = 0;
-    int ret;
-    Py_BEGIN_ALLOW_THREADS
-    ret = PyThread_join_thread(self->handle);
-    Py_END_ALLOW_THREADS
-    if (ret) {
-        PyErr_SetString(ThreadError, "Failed joining thread");
+
+    if (_PyOnceFlag_CallOnce(&self->once, (_Py_once_fn_t *)join_thread,
+                             self) == -1) {
         return NULL;
     }
+    assert(get_thread_handle_state(self) == THREAD_HANDLE_JOINED);
     Py_RETURN_NONE;
 }
 
@@ -180,7 +247,6 @@ static PyGetSetDef ThreadHandle_getsetlist[] = {
 
 static PyMethodDef ThreadHandle_methods[] =
 {
-    {"detach", (PyCFunction)ThreadHandle_detach, METH_NOARGS},
     {"join", (PyCFunction)ThreadHandle_join, METH_NOARGS},
     {0, 0}
 };
@@ -1210,11 +1276,15 @@ _localdummy_destroyed(PyObject *localweakref, PyObject *dummyweakref)
 
 /* Module functions */
 
+// bootstate is used to "bootstrap" new threads. Any arguments needed by
+// `thread_run()`, which can only take a single argument due to platform
+// limitations, are contained in bootstate.
 struct bootstate {
     PyThreadState *tstate;
     PyObject *func;
     PyObject *args;
     PyObject *kwargs;
+    _PyEventRc *thread_is_exiting;
 };
 
 
@@ -1226,6 +1296,9 @@ thread_bootstate_free(struct bootstate *boot, int decref)
         Py_DECREF(boot->args);
         Py_XDECREF(boot->kwargs);
     }
+    if (boot->thread_is_exiting != NULL) {
+        _PyEventRc_Decref(boot->thread_is_exiting);
+    }
     PyMem_RawFree(boot);
 }
 
@@ -1236,6 +1309,10 @@ thread_run(void *boot_raw)
     struct bootstate *boot = (struct bootstate *) boot_raw;
     PyThreadState *tstate = boot->tstate;
 
+    // `thread_is_exiting` needs to be set after bootstate has been freed
+    _PyEventRc *thread_is_exiting = boot->thread_is_exiting;
+    boot->thread_is_exiting = NULL;
+
     // gh-108987: If _thread.start_new_thread() is called before or while
     // Python is being finalized, thread_run() can called *after*.
     // _PyRuntimeState_SetFinalizing() is called. At this point, all Python
@@ -1280,6 +1357,11 @@ thread_run(void *boot_raw)
     _PyThreadState_DeleteCurrent(tstate);
 
 exit:
+    if (thread_is_exiting != NULL) {
+        _PyEvent_Notify(&thread_is_exiting->event);
+        _PyEventRc_Decref(thread_is_exiting);
+    }
+
     // bpo-44434: Don't call explicitly PyThread_exit_thread(). On Linux with
     // the glibc, pthread_exit() can abort the whole process if dlopen() fails
     // to open the libgcc_s.so library (ex: EMFILE error).
@@ -1308,7 +1390,8 @@ static int
 do_start_new_thread(thread_module_state* state,
                     PyObject *func, PyObject* args, PyObject* kwargs,
                     int joinable,
-                    PyThread_ident_t* ident, PyThread_handle_t* handle)
+                    PyThread_ident_t* ident, PyThread_handle_t* handle,
+                    _PyEventRc *thread_is_exiting)
 {
     PyInterpreterState *interp = _PyInterpreterState_GET();
     if (!_PyInterpreterState_HasFeature(interp, Py_RTFLAGS_THREADS)) {
@@ -1341,6 +1424,10 @@ do_start_new_thread(thread_module_state* state,
     boot->func = Py_NewRef(func);
     boot->args = Py_NewRef(args);
     boot->kwargs = Py_XNewRef(kwargs);
+    boot->thread_is_exiting = thread_is_exiting;
+    if (thread_is_exiting != NULL) {
+        _PyEventRc_Incref(thread_is_exiting);
+    }
 
     int err;
     if (joinable) {
@@ -1392,7 +1479,7 @@ thread_PyThread_start_new_thread(PyObject *module, PyObject *fargs)
     PyThread_ident_t ident = 0;
     PyThread_handle_t handle;
     if (do_start_new_thread(state, func, args, kwargs, /*joinable=*/ 0,
-                            &ident, &handle)) {
+                            &ident, &handle, NULL)) {
         return NULL;
     }
     return PyLong_FromUnsignedLongLong(ident);
@@ -1436,13 +1523,13 @@ thread_PyThread_start_joinable_thread(PyObject *module, PyObject *func)
         return NULL;
     }
     if (do_start_new_thread(state, func, args, /*kwargs=*/ NULL, /*joinable=*/ 1,
-                            &hobj->ident, &hobj->handle)) {
+                            &hobj->ident, &hobj->handle, hobj->thread_is_exiting)) {
         Py_DECREF(args);
         Py_DECREF(hobj);
         return NULL;
     }
+    set_thread_handle_state(hobj, THREAD_HANDLE_RUNNING);
     Py_DECREF(args);
-    hobj->joinable = 1;
     return (PyObject*) hobj;
 }
 
index 5fa8bf78da2380836f0250bf9308d7d119e60d23..de25adce38510504375516ddeb7b36cfebdeb90e 100644 (file)
@@ -249,6 +249,13 @@ _PyRawMutex_UnlockSlow(_PyRawMutex *m)
     }
 }
 
+int
+_PyEvent_IsSet(PyEvent *evt)
+{
+    uint8_t v = _Py_atomic_load_uint8(&evt->v);
+    return v == _Py_LOCKED;
+}
+
 void
 _PyEvent_Notify(PyEvent *evt)
 {
@@ -297,6 +304,30 @@ PyEvent_WaitTimed(PyEvent *evt, PyTime_t timeout_ns)
     }
 }
 
+_PyEventRc *
+_PyEventRc_New(void)
+{
+    _PyEventRc *erc = (_PyEventRc *)PyMem_RawCalloc(1, sizeof(_PyEventRc));
+    if (erc != NULL) {
+        erc->refcount = 1;
+    }
+    return erc;
+}
+
+void
+_PyEventRc_Incref(_PyEventRc *erc)
+{
+    _Py_atomic_add_ssize(&erc->refcount, 1);
+}
+
+void
+_PyEventRc_Decref(_PyEventRc *erc)
+{
+    if (_Py_atomic_add_ssize(&erc->refcount, -1) == 1) {
+        PyMem_RawFree(erc);
+    }
+}
+
 static int
 unlock_once(_PyOnceFlag *o, int res)
 {