]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-91048: Also clear and set ts->asyncio_running_task with eager tasks (#129197)
authorŁukasz Langa <lukasz@langa.pl>
Thu, 23 Jan 2025 18:26:36 +0000 (19:26 +0100)
committerGitHub <noreply@github.com>
Thu, 23 Jan 2025 18:26:36 +0000 (19:26 +0100)
This was missing from gh-124640. It's already covered by the new
test_asyncio/test_free_threading.py in combination with the runtime
assertion in set_ts_asyncio_running_task.

Co-authored-by: Kumar Aditya <kumaraditya@python.org>
Modules/_asynciomodule.c

index bba7416b398101e234580e33bbd78f4d19228575..c821860d9e4f70d08a6678bc5c1df53372605053 100644 (file)
@@ -2063,6 +2063,8 @@ static int task_call_step_soon(asyncio_state *state, TaskObj *, PyObject *);
 static PyObject * task_wakeup(TaskObj *, PyObject *);
 static PyObject * task_step(asyncio_state *, TaskObj *, PyObject *);
 static int task_eager_start(asyncio_state *state, TaskObj *task);
+static inline void clear_ts_asyncio_running_task(PyObject *loop);
+static inline void set_ts_asyncio_running_task(PyObject *loop, PyObject *task);
 
 /* ----- Task._step wrapper */
 
@@ -2236,47 +2238,7 @@ enter_task(asyncio_state *state, PyObject *loop, PyObject *task)
 
     assert(task == item);
     Py_CLEAR(item);
-
-    // This block is needed to enable `asyncio.capture_call_graph()` API.
-    // We want to be enable debuggers and profilers to be able to quickly
-    // introspect the asyncio running state from another process.
-    // When we do that, we need to essentially traverse the address space
-    // of a Python process and understand what every Python thread in it is
-    // currently doing, mainly:
-    //
-    //  * current frame
-    //  * current asyncio task
-    //
-    // A naive solution would be to require profilers and debuggers to
-    // find the current task in the "_asynciomodule" module state, but
-    // unfortunately that would require a lot of complicated remote
-    // memory reads and logic, as Python's dict is a notoriously complex
-    // and ever-changing data structure.
-    //
-    // So the easier solution is to put a strong reference to the currently
-    // running `asyncio.Task` on the interpreter thread state (we already
-    // have some asyncio state there.)
-    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
-    if (ts->asyncio_running_loop == loop) {
-        // Protect from a situation when someone calls this method
-        // from another thread. This shouldn't ever happen though,
-        // as `enter_task` and `leave_task` can either be called by:
-        //
-        //  - `asyncio.Task` itself, in `Task.__step()`. That method
-        //    can only be called by the event loop itself.
-        //
-        //  - third-party Task "from scratch" implementations, that
-        //    our `capture_call_graph` API doesn't support anyway.
-        //
-        // That said, we still want to make sure we don't end up in
-        // a broken state, so we check that we're in the correct thread
-        // by comparing the *loop* argument to the event loop running
-        // in the current thread. If they match we know we're in the
-        // right thread, as asyncio event loops don't change threads.
-        assert(ts->asyncio_running_task == NULL);
-        ts->asyncio_running_task = Py_NewRef(task);
-    }
-
+    set_ts_asyncio_running_task(loop, task);
     return 0;
 }
 
@@ -2308,14 +2270,7 @@ leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
         // task was not found
         return err_leave_task(Py_None, task);
     }
-
-    // See the comment in `enter_task` for the explanation of why
-    // the following is needed.
-    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
-    if (ts->asyncio_running_loop == NULL || ts->asyncio_running_loop == loop) {
-        Py_CLEAR(ts->asyncio_running_task);
-    }
-
+    clear_ts_asyncio_running_task(loop);
     return res;
 }
 
@@ -2342,6 +2297,7 @@ swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task)
 {
     PyObject *prev_task;
 
+    clear_ts_asyncio_running_task(loop);
     if (task == Py_None) {
         if (PyDict_Pop(state->current_tasks, loop, &prev_task) < 0) {
             return NULL;
@@ -2361,9 +2317,63 @@ swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task)
     Py_BEGIN_CRITICAL_SECTION(current_tasks);
     prev_task = swap_current_task_lock_held(current_tasks, loop, hash, task);
     Py_END_CRITICAL_SECTION();
+    set_ts_asyncio_running_task(loop, task);
     return prev_task;
 }
 
+static inline void
+set_ts_asyncio_running_task(PyObject *loop, PyObject *task)
+{
+    // We want to enable debuggers and profilers to be able to quickly
+    // introspect the asyncio running state from another process.
+    // When we do that, we need to essentially traverse the address space
+    // of a Python process and understand what every Python thread in it is
+    // currently doing, mainly:
+    //
+    //  * current frame
+    //  * current asyncio task
+    //
+    // A naive solution would be to require profilers and debuggers to
+    // find the current task in the "_asynciomodule" module state, but
+    // unfortunately that would require a lot of complicated remote
+    // memory reads and logic, as Python's dict is a notoriously complex
+    // and ever-changing data structure.
+    //
+    // So the easier solution is to put a strong reference to the currently
+    // running `asyncio.Task` on the current thread state (the current loop
+    // is also stored there.)
+    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
+    if (ts->asyncio_running_loop == loop) {
+        // Protect from a situation when someone calls this method
+        // from another thread. This shouldn't ever happen though,
+        // as `enter_task` and `leave_task` can either be called by:
+        //
+        //  - `asyncio.Task` itself, in `Task.__step()`. That method
+        //    can only be called by the event loop itself.
+        //
+        //  - third-party Task "from scratch" implementations, that
+        //    our `capture_call_graph` API doesn't support anyway.
+        //
+        // That said, we still want to make sure we don't end up in
+        // a broken state, so we check that we're in the correct thread
+        // by comparing the *loop* argument to the event loop running
+        // in the current thread. If they match we know we're in the
+        // right thread, as asyncio event loops don't change threads.
+        assert(ts->asyncio_running_task == NULL);
+        ts->asyncio_running_task = Py_NewRef(task);
+    }
+}
+
+static inline void
+clear_ts_asyncio_running_task(PyObject *loop)
+{
+    // See comment in set_ts_asyncio_running_task() for details.
+    _PyThreadStateImpl *ts = (_PyThreadStateImpl *)_PyThreadState_GET();
+    if (ts->asyncio_running_loop == NULL || ts->asyncio_running_loop == loop) {
+        Py_CLEAR(ts->asyncio_running_task);
+    }
+}
+
 /* ----- Task */
 
 /*[clinic input]