]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-120974: Make asyncio `swap_current_task` safe in free-threaded build (#122317)
authorSam Gross <colesbury@gmail.com>
Fri, 2 Aug 2024 13:32:08 +0000 (09:32 -0400)
committerGitHub <noreply@github.com>
Fri, 2 Aug 2024 13:32:08 +0000 (19:02 +0530)
* gh-120974: Make asyncio `swap_current_task` safe in free-threaded build

Include/internal/pycore_dict.h
Modules/_asynciomodule.c
Objects/dictobject.c

index fc304aca7fea104c51b64bc4f581aefc31e635f1..a84246ee34efff4a897e811bbe52d311a026a218 100644 (file)
@@ -108,8 +108,13 @@ PyAPI_FUNC(PyObject *)_PyDict_LoadGlobal(PyDictObject *, PyDictObject *, PyObjec
 /* Consumes references to key and value */
 PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value);
 extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value);
-extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
+// Export for '_asyncio' shared extension
+PyAPI_FUNC(int) _PyDict_SetItem_KnownHash_LockHeld(PyDictObject *mp, PyObject *key,
+                                                   PyObject *value, Py_hash_t hash);
+// Export for '_asyncio' shared extension
+PyAPI_FUNC(int) _PyDict_GetItemRef_KnownHash_LockHeld(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
 extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
+extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
 extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr, PyObject *name, PyObject *value);
 
 extern int _PyDict_Pop_KnownHash(
index 873c17cd78709d47d4e2890f1d20965e5dd89aaf..c6eb43f044fdbd9fee1a9044b23dfa50fad302b1 100644 (file)
@@ -2026,6 +2026,24 @@ leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
     return res;
 }
 
+static PyObject *
+swap_current_task_lock_held(PyDictObject *current_tasks, PyObject *loop,
+                            Py_hash_t hash, PyObject *task)
+{
+    PyObject *prev_task;
+    if (_PyDict_GetItemRef_KnownHash_LockHeld(current_tasks, loop, hash, &prev_task) < 0) {
+        return NULL;
+    }
+    if (_PyDict_SetItem_KnownHash_LockHeld(current_tasks, loop, task, hash) < 0) {
+        Py_XDECREF(prev_task);
+        return NULL;
+    }
+    if (prev_task == NULL) {
+        Py_RETURN_NONE;
+    }
+    return prev_task;
+}
+
 static PyObject *
 swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task)
 {
@@ -2041,24 +2059,15 @@ swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task)
         return prev_task;
     }
 
-    Py_hash_t hash;
-    hash = PyObject_Hash(loop);
+    Py_hash_t hash = PyObject_Hash(loop);
     if (hash == -1) {
         return NULL;
     }
-    prev_task = _PyDict_GetItem_KnownHash(state->current_tasks, loop, hash);
-    if (prev_task == NULL) {
-        if (PyErr_Occurred()) {
-            return NULL;
-        }
-        prev_task = Py_None;
-    }
-    Py_INCREF(prev_task);
-    if (_PyDict_SetItem_KnownHash(state->current_tasks, loop, task, hash) == -1) {
-        Py_DECREF(prev_task);
-        return NULL;
-    }
 
+    PyDictObject *current_tasks = (PyDictObject *)state->current_tasks;
+    Py_BEGIN_CRITICAL_SECTION(current_tasks);
+    prev_task = swap_current_task_lock_held(current_tasks, loop, hash, task);
+    Py_END_CRITICAL_SECTION();
     return prev_task;
 }
 
index 6a16a04102a6c0e96c719a00161c31b2e0768456..3e9f982ae070a302450dbca55f3f5a275fbd00f4 100644 (file)
@@ -2216,6 +2216,29 @@ _PyDict_GetItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash)
     return value;  // borrowed reference
 }
 
+/* Gets an item and provides a new reference if the value is present.
+ * Returns 1 if the key is present, 0 if the key is missing, and -1 if an
+ * exception occurred.
+*/
+int
+_PyDict_GetItemRef_KnownHash_LockHeld(PyDictObject *op, PyObject *key,
+                                      Py_hash_t hash, PyObject **result)
+{
+    PyObject *value;
+    Py_ssize_t ix = _Py_dict_lookup(op, key, hash, &value);
+    assert(ix >= 0 || value == NULL);
+    if (ix == DKIX_ERROR) {
+        *result = NULL;
+        return -1;
+    }
+    if (value == NULL) {
+        *result = NULL;
+        return 0;  // missing key
+    }
+    *result = Py_NewRef(value);
+    return 1;  // key is present
+}
+
 /* Gets an item and provides a new reference if the value is present.
  * Returns 1 if the key is present, 0 if the key is missing, and -1 if an
  * exception occurred.
@@ -2460,11 +2483,21 @@ setitem_lock_held(PyDictObject *mp, PyObject *key, PyObject *value)
 
 
 int
-_PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value,
-                         Py_hash_t hash)
+_PyDict_SetItem_KnownHash_LockHeld(PyDictObject *mp, PyObject *key, PyObject *value,
+                                   Py_hash_t hash)
 {
-    PyDictObject *mp;
+    PyInterpreterState *interp = _PyInterpreterState_GET();
+    if (mp->ma_keys == Py_EMPTY_KEYS) {
+        return insert_to_emptydict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value));
+    }
+    /* insertdict() handles any resizing that might be necessary */
+    return insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value));
+}
 
+int
+_PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value,
+                          Py_hash_t hash)
+{
     if (!PyDict_Check(op)) {
         PyErr_BadInternalCall();
         return -1;
@@ -2472,21 +2505,10 @@ _PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value,
     assert(key);
     assert(value);
     assert(hash != -1);
-    mp = (PyDictObject *)op;
 
     int res;
-    PyInterpreterState *interp = _PyInterpreterState_GET();
-
-    Py_BEGIN_CRITICAL_SECTION(mp);
-
-    if (mp->ma_keys == Py_EMPTY_KEYS) {
-        res = insert_to_emptydict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value));
-    }
-    else {
-        /* insertdict() handles any resizing that might be necessary */
-        res = insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value));
-    }
-
+    Py_BEGIN_CRITICAL_SECTION(op);
+    res = _PyDict_SetItem_KnownHash_LockHeld((PyDictObject *)op, key, value, hash);
     Py_END_CRITICAL_SECTION();
     return res;
 }