git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-126907: make `atexit` thread safe in free-threading (#127935)
author Peter Bierma <zintensitydev@gmail.com>
Mon, 16 Dec 2024 19:31:44 +0000 (14:31 -0500)
committer GitHub <noreply@github.com>
Mon, 16 Dec 2024 19:31:44 +0000 (19:31 +0000)
Co-authored-by: Victor Stinner <vstinner@python.org>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
Include/internal/pycore_atexit.h
Lib/test/test_atexit.py
Misc/NEWS.d/next/Library/2024-12-13-22-20-54.gh-issue-126907.fWRL_R.rst [new file with mode: 0644]
Modules/atexitmodule.c

index cde5b530baf00c8a3452d7eab826d26684b952e7..db1e5568e09413601fc661ff6aa986357e1b7f79 100644 (file)
@@ -36,23 +36,34 @@ typedef struct atexit_callback {
     struct atexit_callback *next;
 } atexit_callback;
 
-typedef struct {
-    PyObject *func;
-    PyObject *args;
-    PyObject *kwargs;
-} atexit_py_callback;
-
 struct atexit_state {
+#ifdef Py_GIL_DISABLED
+    // Held while PyUnstable_AtExit() links a new entry into ll_callbacks.
+    PyMutex ll_callbacks_lock;
+#endif
     atexit_callback *ll_callbacks;
 
     // XXX The rest of the state could be moved to the atexit module state
     // and a low-level callback added for it during module exec.
     // For the moment we leave it here.
-    atexit_py_callback **callbacks;
-    int ncallbacks;
-    int callback_len;
+
+    // List containing tuples with callback information.
+    // e.g. [(func, args, kwargs), ...]
+    PyObject *callbacks;
 };
 
+// Lock/unlock state->ll_callbacks_lock; no-ops when the GIL is enabled.
+// Both forms expand to a single expression (no trailing semicolon, argument
+// parenthesized), so callers write `_PyAtExit_LockCallbacks(state);` and the
+// macros stay safe inside unbraced if/else bodies.
+#ifdef Py_GIL_DISABLED
+#  define _PyAtExit_LockCallbacks(state) PyMutex_Lock(&(state)->ll_callbacks_lock)
+#  define _PyAtExit_UnlockCallbacks(state) PyMutex_Unlock(&(state)->ll_callbacks_lock)
+#else
+#  define _PyAtExit_LockCallbacks(state) ((void)0)
+#  define _PyAtExit_UnlockCallbacks(state) ((void)0)
+#endif
+
 // Export for '_interpchannels' shared extension
 PyAPI_FUNC(int) _Py_AtExit(
     PyInterpreterState *interp,
index 913b7556be83aff8777599b40a64ccabc26d7575..eb01da6e88a8bcd3d2641fe1257333a24ae3d1c9 100644 (file)
@@ -4,7 +4,7 @@ import textwrap
 import unittest
 from test import support
 from test.support import script_helper
-
+from test.support import threading_helper
 
 class GeneralTest(unittest.TestCase):
     def test_general(self):
@@ -46,6 +46,43 @@ class FunctionalTest(unittest.TestCase):
         self.assertEqual(res.out.decode().splitlines(), ["atexit2", "atexit1"])
         self.assertFalse(res.err)
 
+    @threading_helper.requires_working_threading()
+    @support.requires_resource("cpu")
+    @unittest.skipUnless(support.Py_GIL_DISABLED, "only meaningful without the GIL")
+    def test_atexit_thread_safety(self):
+        # GH-126907: atexit was not thread safe on the free-threaded build
+        # The child script has to import atexit itself: without the import
+        # every thread dies with NameError, the interpreter still exits with
+        # status 0 and the test passes without ever calling into atexit.
+        source = """
+        import atexit
+        from threading import Thread
+
+        def dummy():
+            pass
+
+
+        def thready():
+            for _ in range(100):
+                atexit.register(dummy)
+                atexit._clear()
+                atexit.register(dummy)
+                atexit.unregister(dummy)
+                atexit._run_exitfuncs()
+
+
+        threads = [Thread(target=thready) for _ in range(10)]
+        for thread in threads:
+            thread.start()
+
+        for thread in threads:
+            thread.join()
+        """
+
+        # atexit._clear() has some evil side effects, and we don't
+        # want them to affect the rest of the tests.
+        script_helper.assert_python_ok("-c", textwrap.dedent(source))
+
 
 @support.cpython_only
 class SubinterpreterTest(unittest.TestCase):
diff --git a/Misc/NEWS.d/next/Library/2024-12-13-22-20-54.gh-issue-126907.fWRL_R.rst b/Misc/NEWS.d/next/Library/2024-12-13-22-20-54.gh-issue-126907.fWRL_R.rst
new file mode 100644 (file)
index 0000000..d33d2aa
--- /dev/null
@@ -0,0 +1,2 @@
+Fix crash when using :mod:`atexit` concurrently on the :term:`free-threaded
+<free threading>` build.
index c009235b7a36c20b62b79f5ef280210f0866784e..1b89b32ba907d7d926dd5b95b4d13ee7c0663c62 100644 (file)
@@ -41,6 +41,7 @@ PyUnstable_AtExit(PyInterpreterState *interp,
     callback->next = NULL;
 
     struct atexit_state *state = &interp->atexit;
+    _PyAtExit_LockCallbacks(state);
     atexit_callback *top = state->ll_callbacks;
     if (top == NULL) {
         state->ll_callbacks = callback;
@@ -49,36 +50,16 @@ PyUnstable_AtExit(PyInterpreterState *interp,
         callback->next = top;
         state->ll_callbacks = callback;
     }
+    _PyAtExit_UnlockCallbacks(state);
     return 0;
 }
 
 
-static void
-atexit_delete_cb(struct atexit_state *state, int i)
-{
-    atexit_py_callback *cb = state->callbacks[i];
-    state->callbacks[i] = NULL;
-
-    Py_DECREF(cb->func);
-    Py_DECREF(cb->args);
-    Py_XDECREF(cb->kwargs);
-    PyMem_Free(cb);
-}
-
-
 /* Clear all callbacks without calling them */
 static void
 atexit_cleanup(struct atexit_state *state)
 {
-    atexit_py_callback *cb;
-    for (int i = 0; i < state->ncallbacks; i++) {
-        cb = state->callbacks[i];
-        if (cb == NULL)
-            continue;
-
-        atexit_delete_cb(state, i);
-    }
-    state->ncallbacks = 0;
+    PyList_Clear(state->callbacks);
 }
 
 
@@ -89,23 +70,21 @@ _PyAtExit_Init(PyInterpreterState *interp)
     // _PyAtExit_Init() must only be called once
     assert(state->callbacks == NULL);
 
-    state->callback_len = 32;
-    state->ncallbacks = 0;
-    state->callbacks = PyMem_New(atexit_py_callback*, state->callback_len);
+    state->callbacks = PyList_New(0);
     if (state->callbacks == NULL) {
         return _PyStatus_NO_MEMORY();
     }
     return _PyStatus_OK();
 }
 
-
 void
 _PyAtExit_Fini(PyInterpreterState *interp)
 {
+    // In theory, there shouldn't be any threads left by now, so we
+    // won't lock this.
     struct atexit_state *state = &interp->atexit;
     atexit_cleanup(state);
-    PyMem_Free(state->callbacks);
-    state->callbacks = NULL;
+    Py_CLEAR(state->callbacks);
 
     atexit_callback *next = state->ll_callbacks;
     state->ll_callbacks = NULL;
@@ -120,35 +99,44 @@ _PyAtExit_Fini(PyInterpreterState *interp)
     }
 }
 
-
 static void
 atexit_callfuncs(struct atexit_state *state)
 {
     assert(!PyErr_Occurred());
+    assert(state->callbacks != NULL);
+    assert(PyList_CheckExact(state->callbacks));
 
-    if (state->ncallbacks == 0) {
+    // Create a copy of the list for thread safety
+    PyObject *copy = PyList_GetSlice(state->callbacks, 0, PyList_GET_SIZE(state->callbacks));
+    if (copy == NULL)
+    {
+        PyErr_WriteUnraisable(NULL);
         return;
     }
 
-    for (int i = state->ncallbacks - 1; i >= 0; i--) {
-        atexit_py_callback *cb = state->callbacks[i];
-        if (cb == NULL) {
-            continue;
-        }
+    for (Py_ssize_t i = 0; i < PyList_GET_SIZE(copy); ++i) {
+        // We don't have to worry about evil borrowed references, because
+        // no other threads can access this list.
+        PyObject *tuple = PyList_GET_ITEM(copy, i);
+        assert(PyTuple_CheckExact(tuple));
+
+        PyObject *func = PyTuple_GET_ITEM(tuple, 0);
+        PyObject *args = PyTuple_GET_ITEM(tuple, 1);
+        PyObject *kwargs = PyTuple_GET_ITEM(tuple, 2);
 
-        // bpo-46025: Increment the refcount of cb->func as the call itself may unregister it
-        PyObject* the_func = Py_NewRef(cb->func);
-        PyObject *res = PyObject_Call(cb->func, cb->args, cb->kwargs);
+        PyObject *res = PyObject_Call(func,
+                                      args,
+                                      kwargs == Py_None ? NULL : kwargs);
         if (res == NULL) {
             PyErr_FormatUnraisable(
-                "Exception ignored in atexit callback %R", the_func);
+                "Exception ignored in atexit callback %R", func);
         }
         else {
             Py_DECREF(res);
         }
-        Py_DECREF(the_func);
     }
 
+    Py_DECREF(copy);
     atexit_cleanup(state);
 
     assert(!PyErr_Occurred());
@@ -194,33 +182,34 @@ atexit_register(PyObject *module, PyObject *args, PyObject *kwargs)
                 "the first argument must be callable");
         return NULL;
     }
+    // New reference to the positional arguments that follow 'func'.
+    PyObject *func_args = PyTuple_GetSlice(args, 1, PyTuple_GET_SIZE(args));
+    if (func_args == NULL) {
+        return NULL;
+    }
+    PyObject *func_kwargs = kwargs;
 
-    struct atexit_state *state = get_atexit_state();
-    if (state->ncallbacks >= state->callback_len) {
-        atexit_py_callback **r;
-        state->callback_len += 16;
-        size_t size = sizeof(atexit_py_callback*) * (size_t)state->callback_len;
-        r = (atexit_py_callback**)PyMem_Realloc(state->callbacks, size);
-        if (r == NULL) {
-            return PyErr_NoMemory();
-        }
-        state->callbacks = r;
+    if (func_kwargs == NULL)
+    {
+        // None marks "no keyword arguments"; atexit_callfuncs() maps it back to NULL.
+        func_kwargs = Py_None;
     }
-
-    atexit_py_callback *callback = PyMem_Malloc(sizeof(atexit_py_callback));
-    if (callback == NULL) {
-        return PyErr_NoMemory();
+    PyObject *callback = PyTuple_Pack(3, func, func_args, func_kwargs);
+    // PyTuple_Pack() takes its own references, so release ours on every path.
+    Py_DECREF(func_args);
+    if (callback == NULL)
+    {
+        return NULL;
     }
 
-    callback->args = PyTuple_GetSlice(args, 1, PyTuple_GET_SIZE(args));
-    if (callback->args == NULL) {
-        PyMem_Free(callback);
+    struct atexit_state *state = get_atexit_state();
+    // atexit callbacks go in a LIFO order
+    if (PyList_Insert(state->callbacks, 0, callback) < 0)
+    {
+        Py_DECREF(callback);
         return NULL;
     }
-    callback->func = Py_NewRef(func);
-    callback->kwargs = Py_XNewRef(kwargs);
-
-    state->callbacks[state->ncallbacks++] = callback;
+    Py_DECREF(callback);
 
     return Py_NewRef(func);
 }
@@ -264,7 +246,54 @@ static PyObject *
 atexit_ncallbacks(PyObject *module, PyObject *unused)
 {
     struct atexit_state *state = get_atexit_state();
-    return PyLong_FromSsize_t(state->ncallbacks);
+    assert(state->callbacks != NULL);
+    assert(PyList_CheckExact(state->callbacks));
+    return PyLong_FromSsize_t(PyList_GET_SIZE(state->callbacks));
+}
+
+/* Remove every entry of 'callbacks' whose function compares equal to 'func'.
+ *
+ * 'callbacks' is the list of (func, args, kwargs) tuples.  The caller is
+ * expected to hold the critical section on it (see atexit_unregister()).
+ * Returns 0 on success, -1 with an exception set on failure.
+ */
+static int
+atexit_unregister_locked(PyObject *callbacks, PyObject *func)
+{
+    for (Py_ssize_t i = 0; i < PyList_GET_SIZE(callbacks); ++i) {
+        // Hold a strong reference: the comparison below can run arbitrary
+        // Python code (a user-defined __eq__) that may modify the list and
+        // would otherwise free the entry while it is still being used.
+        PyObject *tuple = Py_NewRef(PyList_GET_ITEM(callbacks, i));
+        assert(PyTuple_CheckExact(tuple));
+        PyObject *to_compare = PyTuple_GET_ITEM(tuple, 0);
+        int cmp = PyObject_RichCompareBool(func, to_compare, Py_EQ);
+        if (cmp < 0) {
+            Py_DECREF(tuple);
+            return -1;
+        }
+        if (cmp == 1) {
+            // We found a callback!  Its index may have changed if the list
+            // was modified during the comparison, so look the entry up by
+            // identity instead of trusting 'i'.
+            Py_ssize_t size = PyList_GET_SIZE(callbacks);
+            for (Py_ssize_t j = 0; j < size; ++j) {
+                if (PyList_GET_ITEM(callbacks, j) != tuple) {
+                    continue;
+                }
+                if (PyList_SetSlice(callbacks, j, j + 1, NULL) < 0) {
+                    Py_DECREF(tuple);
+                    return -1;
+                }
+                // Continue with the entry that moved into slot 'j'.
+                i = j - 1;
+                break;
+            }
+        }
+        Py_DECREF(tuple);
+    }
+
+    return 0;
 }
 
 PyDoc_STRVAR(atexit_unregister__doc__,
@@ -280,22 +288,16 @@ static PyObject *
 atexit_unregister(PyObject *module, PyObject *func)
 {
     struct atexit_state *state = get_atexit_state();
-    for (int i = 0; i < state->ncallbacks; i++)
-    {
-        atexit_py_callback *cb = state->callbacks[i];
-        if (cb == NULL) {
-            continue;
-        }
-
-        int eq = PyObject_RichCompareBool(cb->func, func, Py_EQ);
-        if (eq < 0) {
-            return NULL;
-        }
-        if (eq) {
-            atexit_delete_cb(state, i);
-        }
-    }
-    Py_RETURN_NONE;
+    int result;
+    // Enter the critical section on the callbacks list while it is scanned
+    // and modified by atexit_unregister_locked().
+    Py_BEGIN_CRITICAL_SECTION(state->callbacks);
+    result = atexit_unregister_locked(state->callbacks, func);
+    Py_END_CRITICAL_SECTION();
+    if (result < 0) {
+        return NULL;
+    }
+    Py_RETURN_NONE;
 }