]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-129898: per-thread current task implementation in asyncio (#129899)
authorKumar Aditya <kumaraditya@python.org>
Wed, 19 Feb 2025 16:34:49 +0000 (22:04 +0530)
committerGitHub <noreply@github.com>
Wed, 19 Feb 2025 16:34:49 +0000 (16:34 +0000)
Store the current running task on the thread state, it makes it thread safe for the free-threading build and while improving performance as there is no lock contention, this effectively makes it lock free.
When accessing the current task of the current running loop in current thread, no locking is required and can be acessed without locking.
In the rare case of accessing current task of a loop running in a different thread, the stop the world pauses is used in free-threading builds to stop all other running threads and find the task for the specified loop.

This also makes it easier for external introspection to find the current task, and now it will be always correct.

Lib/asyncio/tasks.py
Modules/_asynciomodule.c

index 2d931040e57d156993640e4256b0fc0951ad3523..7f7ee81403a7599f3c4504c0bc84c5d308ccc819 100644 (file)
@@ -1110,7 +1110,7 @@ try:
     from _asyncio import (_register_task, _register_eager_task,
                           _unregister_task, _unregister_eager_task,
                           _enter_task, _leave_task, _swap_current_task,
-                          _scheduled_tasks, _eager_tasks, _current_tasks,
+                          _scheduled_tasks, _eager_tasks,
                           current_task, all_tasks)
 except ImportError:
     pass
index ff24fd76a617333b34d52b674fb15278410e4fbb..761c53a5e45fdd84556f1232c41eca469db30c81 100644 (file)
@@ -139,10 +139,6 @@ typedef struct {
     PyObject *asyncio_mod;
     PyObject *context_kwname;
 
-    /* Dictionary containing tasks that are currently active in
-       all running event loops.  {EventLoop: Task} */
-    PyObject *current_tasks;
-
     /* WeakSet containing scheduled 3rd party tasks which don't
        inherit from native asyncio.Task */
     PyObject *non_asyncio_tasks;
@@ -2061,8 +2057,6 @@ static int task_call_step_soon(asyncio_state *state, TaskObj *, PyObject *);
 static PyObject * task_wakeup(TaskObj *, PyObject *);
 static PyObject * task_step(asyncio_state *, TaskObj *, PyObject *);
 static int task_eager_start(asyncio_state *state, TaskObj *task);
-static inline void clear_ts_asyncio_running_task(PyObject *loop);
-static inline void set_ts_asyncio_running_task(PyObject *loop, PyObject *task);
 
 /* ----- Task._step wrapper */
 
@@ -2235,159 +2229,71 @@ unregister_eager_task(asyncio_state *state, PyObject *task)
 }
 
 static int
-enter_task(asyncio_state *state, PyObject *loop, PyObject *task)
+enter_task(PyObject *loop, PyObject *task)
 {
-    PyObject *item;
-    int res = PyDict_SetDefaultRef(state->current_tasks, loop, task, &item);
-    if (res < 0) {
+    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
+
+    if (ts->asyncio_running_loop != loop) {
+        PyErr_Format(PyExc_RuntimeError, "loop %R is not the running loop", loop);
         return -1;
     }
-    else if (res == 1) {
+
+    if (ts->asyncio_running_task != NULL) {
         PyErr_Format(
             PyExc_RuntimeError,
             "Cannot enter into task %R while another " \
             "task %R is being executed.",
-            task, item, NULL);
-        Py_DECREF(item);
+            task, ts->asyncio_running_task, NULL);
         return -1;
     }
 
-    assert(task == item);
-    Py_CLEAR(item);
-    set_ts_asyncio_running_task(loop, task);
+    ts->asyncio_running_task = Py_NewRef(task);
     return 0;
 }
 
 static int
-err_leave_task(PyObject *item, PyObject *task)
+leave_task(PyObject *loop, PyObject *task)
 {
-    PyErr_Format(
-        PyExc_RuntimeError,
-        "Leaving task %R does not match the current task %R.",
-        task, item);
-    return -1;
-}
-
-static int
-leave_task_predicate(PyObject *item, void *task)
-{
-    if (item != task) {
-        return err_leave_task(item, (PyObject *)task);
-    }
-    return 1;
-}
+    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
 
-static int
-leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
-{
-    int res = _PyDict_DelItemIf(state->current_tasks, loop,
-                                leave_task_predicate, task);
-    if (res == 0) {
-        // task was not found
-        return err_leave_task(Py_None, task);
+    if (ts->asyncio_running_loop != loop) {
+        PyErr_Format(PyExc_RuntimeError, "loop %R is not the running loop", loop);
+        return -1;
     }
-    clear_ts_asyncio_running_task(loop);
-    return res;
-}
 
-static PyObject *
-swap_current_task_lock_held(PyDictObject *current_tasks, PyObject *loop,
-                            Py_hash_t hash, PyObject *task)
-{
-    PyObject *prev_task;
-    if (_PyDict_GetItemRef_KnownHash_LockHeld(current_tasks, loop, hash, &prev_task) < 0) {
-        return NULL;
-    }
-    if (_PyDict_SetItem_KnownHash_LockHeld(current_tasks, loop, task, hash) < 0) {
-        Py_XDECREF(prev_task);
-        return NULL;
-    }
-    if (prev_task == NULL) {
-        Py_RETURN_NONE;
+    if (ts->asyncio_running_task != task) {
+        PyErr_Format(
+            PyExc_RuntimeError,
+            "Invalid attempt to leave task %R while " \
+            "task %R is entered.",
+            task, ts->asyncio_running_task ? ts->asyncio_running_task : Py_None, NULL);
+        return -1;
     }
-    return prev_task;
+    Py_CLEAR(ts->asyncio_running_task);
+    return 0;
 }
 
 static PyObject *
-swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task)
+swap_current_task(PyObject *loop, PyObject *task)
 {
-    PyObject *prev_task;
-
-    clear_ts_asyncio_running_task(loop);
-    if (task == Py_None) {
-        if (PyDict_Pop(state->current_tasks, loop, &prev_task) < 0) {
-            return NULL;
-        }
-        if (prev_task == NULL) {
-            Py_RETURN_NONE;
-        }
-        return prev_task;
-    }
+    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
 
-    Py_hash_t hash = PyObject_Hash(loop);
-    if (hash == -1) {
+    if (ts->asyncio_running_loop != loop) {
+        PyErr_Format(PyExc_RuntimeError, "loop %R is not the running loop", loop);
         return NULL;
     }
 
-    PyDictObject *current_tasks = (PyDictObject *)state->current_tasks;
-    Py_BEGIN_CRITICAL_SECTION(current_tasks);
-    prev_task = swap_current_task_lock_held(current_tasks, loop, hash, task);
-    Py_END_CRITICAL_SECTION();
-    set_ts_asyncio_running_task(loop, task);
-    return prev_task;
-}
-
-static inline void
-set_ts_asyncio_running_task(PyObject *loop, PyObject *task)
-{
-    // We want to enable debuggers and profilers to be able to quickly
-    // introspect the asyncio running state from another process.
-    // When we do that, we need to essentially traverse the address space
-    // of a Python process and understand what every Python thread in it is
-    // currently doing, mainly:
-    //
-    //  * current frame
-    //  * current asyncio task
-    //
-    // A naive solution would be to require profilers and debuggers to
-    // find the current task in the "_asynciomodule" module state, but
-    // unfortunately that would require a lot of complicated remote
-    // memory reads and logic, as Python's dict is a notoriously complex
-    // and ever-changing data structure.
-    //
-    // So the easier solution is to put a strong reference to the currently
-    // running `asyncio.Task` on the current thread state (the current loop
-    // is also stored there.)
-    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
-    if (ts->asyncio_running_loop == loop) {
-        // Protect from a situation when someone calls this method
-        // from another thread. This shouldn't ever happen though,
-        // as `enter_task` and `leave_task` can either be called by:
-        //
-        //  - `asyncio.Task` itself, in `Task.__step()`. That method
-        //    can only be called by the event loop itself.
-        //
-        //  - third-party Task "from scratch" implementations, that
-        //    our `capture_call_graph` API doesn't support anyway.
-        //
-        // That said, we still want to make sure we don't end up in
-        // a broken state, so we check that we're in the correct thread
-        // by comparing the *loop* argument to the event loop running
-        // in the current thread. If they match we know we're in the
-        // right thread, as asyncio event loops don't change threads.
-        assert(ts->asyncio_running_task == NULL);
+    /* transfer ownership to avoid redundant ref counting */
+    PyObject *prev_task = ts->asyncio_running_task;
+    if (task != Py_None) {
         ts->asyncio_running_task = Py_NewRef(task);
+    } else {
+        ts->asyncio_running_task = NULL;
     }
-}
-
-static inline void
-clear_ts_asyncio_running_task(PyObject *loop)
-{
-    // See comment in set_ts_asyncio_running_task() for details.
-    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
-    if (ts->asyncio_running_loop == NULL || ts->asyncio_running_loop == loop) {
-        Py_CLEAR(ts->asyncio_running_task);
+    if (prev_task == NULL) {
+        Py_RETURN_NONE;
     }
+    return prev_task;
 }
 
 /* ----- Task */
@@ -3539,7 +3445,7 @@ task_step(asyncio_state *state, TaskObj *task, PyObject *exc)
 {
     PyObject *res;
 
-    if (enter_task(state, task->task_loop, (PyObject*)task) < 0) {
+    if (enter_task(task->task_loop, (PyObject*)task) < 0) {
         return NULL;
     }
 
@@ -3547,12 +3453,12 @@ task_step(asyncio_state *state, TaskObj *task, PyObject *exc)
 
     if (res == NULL) {
         PyObject *exc = PyErr_GetRaisedException();
-        leave_task(state, task->task_loop, (PyObject*)task);
+        leave_task(task->task_loop, (PyObject*)task);
         _PyErr_ChainExceptions1(exc);
         return NULL;
     }
     else {
-        if (leave_task(state, task->task_loop, (PyObject*)task) < 0) {
+        if (leave_task(task->task_loop, (PyObject*)task) < 0) {
             Py_DECREF(res);
             return NULL;
         }
@@ -3566,7 +3472,7 @@ static int
 task_eager_start(asyncio_state *state, TaskObj *task)
 {
     assert(task != NULL);
-    PyObject *prevtask = swap_current_task(state, task->task_loop, (PyObject *)task);
+    PyObject *prevtask = swap_current_task(task->task_loop, (PyObject *)task);
     if (prevtask == NULL) {
         return -1;
     }
@@ -3595,7 +3501,7 @@ task_eager_start(asyncio_state *state, TaskObj *task)
         Py_DECREF(stepres);
     }
 
-    PyObject *curtask = swap_current_task(state, task->task_loop, prevtask);
+    PyObject *curtask = swap_current_task(task->task_loop, prevtask);
     Py_DECREF(prevtask);
     if (curtask == NULL) {
         retval = -1;
@@ -3907,8 +3813,7 @@ static PyObject *
 _asyncio__enter_task_impl(PyObject *module, PyObject *loop, PyObject *task)
 /*[clinic end generated code: output=a22611c858035b73 input=de1b06dca70d8737]*/
 {
-    asyncio_state *state = get_asyncio_state(module);
-    if (enter_task(state, loop, task) < 0) {
+    if (enter_task(loop, task) < 0) {
         return NULL;
     }
     Py_RETURN_NONE;
@@ -3932,8 +3837,7 @@ static PyObject *
 _asyncio__leave_task_impl(PyObject *module, PyObject *loop, PyObject *task)
 /*[clinic end generated code: output=0ebf6db4b858fb41 input=51296a46313d1ad8]*/
 {
-    asyncio_state *state = get_asyncio_state(module);
-    if (leave_task(state, loop, task) < 0) {
+    if (leave_task(loop, task) < 0) {
         return NULL;
     }
     Py_RETURN_NONE;
@@ -3957,7 +3861,7 @@ _asyncio__swap_current_task_impl(PyObject *module, PyObject *loop,
                                  PyObject *task)
 /*[clinic end generated code: output=9f88de958df74c7e input=c9c72208d3d38b6c]*/
 {
-    return swap_current_task(get_asyncio_state(module), loop, task);
+    return swap_current_task(loop, task);
 }
 
 
@@ -3974,9 +3878,6 @@ static PyObject *
 _asyncio_current_task_impl(PyObject *module, PyObject *loop)
 /*[clinic end generated code: output=fe15ac331a7f981a input=58910f61a5627112]*/
 {
-    PyObject *ret;
-    asyncio_state *state = get_asyncio_state(module);
-
     if (loop == Py_None) {
         loop = _asyncio_get_running_loop_impl(module);
         if (loop == NULL) {
@@ -3986,11 +3887,36 @@ _asyncio_current_task_impl(PyObject *module, PyObject *loop)
         Py_INCREF(loop);
     }
 
-    int rc = PyDict_GetItemRef(state->current_tasks, loop, &ret);
-    Py_DECREF(loop);
-    if (rc == 0) {
+    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
+    // Fast path for the current running loop of current thread
+    // no locking or stop the world pause is required
+    if (ts->asyncio_running_loop == loop) {
+        if (ts->asyncio_running_task != NULL) {
+            Py_DECREF(loop);
+            return Py_NewRef(ts->asyncio_running_task);
+        }
+        Py_DECREF(loop);
         Py_RETURN_NONE;
     }
+
+    PyObject *ret = Py_None;
+    // Stop the world and traverse the per-thread current tasks
+    // and return the task if the loop matches
+    PyInterpreterState *interp = ts->base.interp;
+    _PyEval_StopTheWorld(interp);
+    _Py_FOR_EACH_TSTATE_BEGIN(interp, p) {
+        ts = (_PyThreadStateImpl *)p;
+        if (ts->asyncio_running_loop == loop) {
+            if (ts->asyncio_running_task != NULL) {
+                ret = Py_NewRef(ts->asyncio_running_task);
+            }
+            goto exit;
+        }
+    }
+exit:
+    _Py_FOR_EACH_TSTATE_END(interp);
+    _PyEval_StartTheWorld(interp);
+    Py_DECREF(loop);
     return ret;
 }
 
@@ -4258,7 +4184,6 @@ module_traverse(PyObject *mod, visitproc visit, void *arg)
 
     Py_VISIT(state->non_asyncio_tasks);
     Py_VISIT(state->eager_tasks);
-    Py_VISIT(state->current_tasks);
     Py_VISIT(state->iscoroutine_typecache);
 
     Py_VISIT(state->context_kwname);
@@ -4289,7 +4214,6 @@ module_clear(PyObject *mod)
 
     Py_CLEAR(state->non_asyncio_tasks);
     Py_CLEAR(state->eager_tasks);
-    Py_CLEAR(state->current_tasks);
     Py_CLEAR(state->iscoroutine_typecache);
 
     Py_CLEAR(state->context_kwname);
@@ -4319,11 +4243,6 @@ module_init(asyncio_state *state)
         goto fail;
     }
 
-    state->current_tasks = PyDict_New();
-    if (state->current_tasks == NULL) {
-        goto fail;
-    }
-
     state->iscoroutine_typecache = PySet_New(NULL);
     if (state->iscoroutine_typecache == NULL) {
         goto fail;
@@ -4456,11 +4375,6 @@ module_exec(PyObject *mod)
         return -1;
     }
 
-    if (PyModule_AddObjectRef(mod, "_current_tasks", state->current_tasks) < 0) {
-        return -1;
-    }
-
-
     return 0;
 }