git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-112075: Fix race in constructing dict for instance (#118499)
author Dino Viehland <dinoviehland@meta.com>
Mon, 6 May 2024 23:31:09 +0000 (16:31 -0700)
committer GitHub <noreply@github.com>
Mon, 6 May 2024 23:31:09 +0000 (23:31 +0000)
Include/internal/pycore_dict.h
Lib/test/test_free_threading/test_dict.py [new file with mode: 0644]
Objects/dictobject.c
Objects/object.c

diff --git a/Include/internal/pycore_dict.h b/Include/internal/pycore_dict.h
index 3ba8ee74b4df8709b3be45aecfc39281feed968c..cb7d4c3219a9afa519f8a83eeaa639be9e9c9cb4 100644 (file)
@@ -105,10 +105,10 @@ PyAPI_FUNC(PyObject *)_PyDict_LoadGlobal(PyDictObject *, PyDictObject *, PyObjec
 
 /* Consumes references to key and value */
 PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value);
-extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr, PyObject *name, PyObject *value);
 extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value);
 extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
 extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
+extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr, PyObject *name, PyObject *value);
 
 extern int _PyDict_Pop_KnownHash(
     PyDictObject *dict,
diff --git a/Lib/test/test_free_threading/test_dict.py b/Lib/test/test_free_threading/test_dict.py
new file mode 100644 (file)
index 0000000..6a909dd
--- /dev/null
@@ -0,0 +1,141 @@
+import gc
+import time
+import unittest
+import weakref
+
+from ast import Or
+from functools import partial
+from threading import Thread
+from unittest import TestCase
+
+from test.support import threading_helper
+
+
+@threading_helper.requires_working_threading()
+class TestDict(TestCase):
+    def test_racing_creation_shared_keys(self):
+        """Verify that creating dictionaries is thread safe when we
+        have a type with shared keys"""
+        class C(int):  # NOTE(review): presumably int-derived so the dict is not stored inline; confirm
+            pass
+
+        self.racing_creation(C)
+
+    def test_racing_creation_no_shared_keys(self):
+        """Verify that creating dictionaries is thread safe when we
+        have a type with an ordinary dict"""
+        self.racing_creation(Or)  # Or is the class imported from ast
+
+    def test_racing_creation_inline_values_invalid(self):
+        """Verify that re-creating a dict after we have invalid inline values
+        is thread safe"""
+        class C:
+            pass
+
+        def make_obj():
+            a = C()
+            # Make object, make inline values invalid, and then delete dict
+            a.__dict__ = {}
+            del a.__dict__
+            return a
+
+        self.racing_creation(make_obj)  # make_obj is the factory, called once per object
+
+    def test_racing_creation_nonmanaged_dict(self):
+        """Verify that explicit creation of an unmanaged dict is thread safe
+        outside of the normal attribute setting code path"""
+        def make_obj():
+            def f(): pass
+            return f  # a fresh function object on every call
+
+        def set(func, name, val):  # NOTE(review): shadows builtin set(); passed below as the setter
+            # Force creation of the dict via PyObject_GenericGetDict
+            func.__dict__[name] = val
+
+        self.racing_creation(make_obj, set)
+
+    def racing_creation(self, cls, set=setattr):  # helper: cls() builds objects, set(obj, name, val) writes
+        objects = []
+        processed = []  # names of the writers that have finished the current object
+
+        OBJECT_COUNT = 100
+        THREAD_COUNT = 10
+        CUR = 0  # index of the object the writers target; advanced by the loop below
+
+        for i in range(OBJECT_COUNT):
+            objects.append(cls())
+
+        def writer_func(name):
+            last = -1  # index of the last object this writer handled
+            while True:
+                if CUR == last:  # already wrote to this object; spin until CUR changes
+                    continue
+                elif CUR == OBJECT_COUNT:  # CUR is set to OBJECT_COUNT once every object is done
+                    break
+
+                obj = objects[CUR]
+                set(obj, name, name)
+                last = CUR
+                processed.append(name)  # report completion for this object
+
+        writers = []
+        for x in range(THREAD_COUNT):
+            writer = Thread(target=partial(writer_func, f"a{x:02}"))  # distinct attribute name per thread
+            writers.append(writer)
+            writer.start()
+
+        for i in range(OBJECT_COUNT):
+            CUR = i
+            while len(processed) != THREAD_COUNT:  # wait until every writer has written to object i
+                time.sleep(0.001)
+            processed.clear()
+
+        CUR = OBJECT_COUNT  # makes the writers break out of their loops
+
+        for writer in writers:
+            writer.join()
+
+        for obj_idx, obj in enumerate(objects):
+            assert (  # each writer sets one distinct name, so no write may have been lost
+                len(obj.__dict__) == THREAD_COUNT
+            ), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
+            for i in range(THREAD_COUNT):
+                assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
+
+    def test_racing_set_dict(self):
+        """Races assigning to __dict__ should be thread safe"""
+
+        def f(): pass
+        l = []  # NOTE(review): appears unused; writer_func's parameter l shadows it
+        THREAD_COUNT = 10
+        class MyDict(dict): pass  # dict subclass so weakref.ref(d) below works; plain dict is not weak-referenceable
+
+        def writer_func(l):
+            for i in range(1000):
+                d = MyDict()
+                l.append(weakref.ref(d))  # track the dict without keeping it alive
+                f.__dict__ = d  # every thread assigns to the same function's __dict__
+
+        lists = []
+        writers = []
+        for x in range(THREAD_COUNT):
+            thread_list = []
+            lists.append(thread_list)
+            writer = Thread(target=partial(writer_func, thread_list))
+            writers.append(writer)
+
+        for writer in writers:
+            writer.start()
+
+        for writer in writers:
+            writer.join()
+
+        f.__dict__ = {}  # replace whichever dict the writers installed last
+        gc.collect()
+
+        for thread_list in lists:
+            for ref in thread_list:
+                self.assertIsNone(ref())  # every MyDict created by the writers must be gone
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 3e662e09ea598e56b51f4bf1445c838f109b7c60..b0fce09d7940e0aebf86e320ef6c212f813dbc76 100644 (file)
@@ -924,16 +924,15 @@ new_dict(PyInterpreterState *interp,
     return (PyObject *)mp;
 }
 
-/* Consumes a reference to the keys object */
 static PyObject *
 new_dict_with_shared_keys(PyInterpreterState *interp, PyDictKeysObject *keys)
 {
     size_t size = shared_keys_usable_size(keys);
     PyDictValues *values = new_values(size);
     if (values == NULL) {
-        dictkeys_decref(interp, keys, false);
         return PyErr_NoMemory();
     }
+    dictkeys_incref(keys);
     for (size_t i = 0; i < size; i++) {
         values->values[i] = NULL;
     }
@@ -6693,8 +6692,6 @@ materialize_managed_dict_lock_held(PyObject *obj)
 {
     _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(obj);
 
-    OBJECT_STAT_INC(dict_materialized_on_request);
-
     PyDictValues *values = _PyObject_InlineValues(obj);
     PyInterpreterState *interp = _PyInterpreterState_GET();
     PyDictKeysObject *keys = CACHED_KEYS(Py_TYPE(obj));
@@ -7186,35 +7183,77 @@ _PyDict_DetachFromObject(PyDictObject *mp, PyObject *obj)
     return 0;
 }
 
-PyObject *
-PyObject_GenericGetDict(PyObject *obj, void *context)
+static inline PyObject *
+ensure_managed_dict(PyObject *obj)
 {
-    PyInterpreterState *interp = _PyInterpreterState_GET();
-    PyTypeObject *tp = Py_TYPE(obj);
-    PyDictObject *dict;
-    if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
-        dict = _PyObject_GetManagedDict(obj);
-        if (dict == NULL &&
-            (tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
+    PyDictObject *dict = _PyObject_GetManagedDict(obj);
+    if (dict == NULL) {
+        PyTypeObject *tp = Py_TYPE(obj);
+        if ((tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
             FT_ATOMIC_LOAD_UINT8(_PyObject_InlineValues(obj)->valid)) {
             dict = _PyObject_MaterializeManagedDict(obj);
         }
-        else if (dict == NULL) {
-            Py_BEGIN_CRITICAL_SECTION(obj);
-
+        else {
+#ifdef Py_GIL_DISABLED
             // Check again that we're not racing with someone else creating the dict
+            Py_BEGIN_CRITICAL_SECTION(obj);
             dict = _PyObject_GetManagedDict(obj);
-            if (dict == NULL) {
-                OBJECT_STAT_INC(dict_materialized_on_request);
-                dictkeys_incref(CACHED_KEYS(tp));
-                dict = (PyDictObject *)new_dict_with_shared_keys(interp, CACHED_KEYS(tp));
-                FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
-                                            (PyDictObject *)dict);
+            if (dict != NULL) {
+                goto done;  // a dict was stored in the meantime; use it
             }
+#endif
+            dict = (PyDictObject *)new_dict_with_shared_keys(_PyInterpreterState_GET(),
+                                                             CACHED_KEYS(tp));
+            FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
+                                        (PyDictObject *)dict);  // stored even if dict is NULL
 
+#ifdef Py_GIL_DISABLED
+done:
             Py_END_CRITICAL_SECTION();
+#endif
         }
-        return Py_XNewRef((PyObject *)dict);
+    }
+    return (PyObject *)dict;  // no new reference is taken here; the caller uses Py_XNewRef
+}
+
+static inline PyObject *
+ensure_nonmanaged_dict(PyObject *obj, PyObject **dictptr)
+{
+    PyDictKeysObject *cached;
+
+    PyObject *dict = FT_ATOMIC_LOAD_PTR_ACQUIRE(*dictptr);  // fast path: dict already exists
+    if (dict == NULL) {
+#ifdef Py_GIL_DISABLED
+        Py_BEGIN_CRITICAL_SECTION(obj);
+        dict = *dictptr;  // re-read inside the critical section on obj
+        if (dict != NULL) {
+            goto done;  // a dict was stored in the meantime; use it
+        }
+#endif
+        PyTypeObject *tp = Py_TYPE(obj);
+        if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
+            PyInterpreterState *interp = _PyInterpreterState_GET();
+            assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
+            dict = new_dict_with_shared_keys(interp, cached);  // heap type that has cached keys
+        }
+        else {
+            dict = PyDict_New();
+        }
+        FT_ATOMIC_STORE_PTR_RELEASE(*dictptr, dict);  // no NULL check: a failed creation stores NULL
+#ifdef Py_GIL_DISABLED
+done:
+        Py_END_CRITICAL_SECTION();
+#endif
+    }
+    return dict;  // no new reference is taken here; callers check for NULL or use Py_XNewRef
+}
+
+PyObject *
+PyObject_GenericGetDict(PyObject *obj, void *context)
+{
+    PyTypeObject *tp = Py_TYPE(obj);
+    if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
+        return Py_XNewRef(ensure_managed_dict(obj));
     }
     else {
         PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
@@ -7223,65 +7262,28 @@ PyObject_GenericGetDict(PyObject *obj, void *context)
                             "This object has no __dict__");
             return NULL;
         }
-        PyObject *dict = *dictptr;
-        if (dict == NULL) {
-            PyTypeObject *tp = Py_TYPE(obj);
-            if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && CACHED_KEYS(tp)) {
-                dictkeys_incref(CACHED_KEYS(tp));
-                *dictptr = dict = new_dict_with_shared_keys(
-                        interp, CACHED_KEYS(tp));
-            }
-            else {
-                *dictptr = dict = PyDict_New();
-            }
-        }
-        return Py_XNewRef(dict);
+
+        return Py_XNewRef(ensure_nonmanaged_dict(obj, dictptr));
     }
 }
 
 int
-_PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr,
+_PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
                       PyObject *key, PyObject *value)
 {
     PyObject *dict;
     int res;
-    PyDictKeysObject *cached;
-    PyInterpreterState *interp = _PyInterpreterState_GET();
 
     assert(dictptr != NULL);
-    if ((tp->tp_flags & Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
-        assert(dictptr != NULL);
-        dict = *dictptr;
-        if (dict == NULL) {
-            assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
-            dictkeys_incref(cached);
-            dict = new_dict_with_shared_keys(interp, cached);
-            if (dict == NULL)
-                return -1;
-            *dictptr = dict;
-        }
-        if (value == NULL) {
-            res = PyDict_DelItem(dict, key);
-        }
-        else {
-            res = PyDict_SetItem(dict, key, value);
-        }
-    } else {
-        dict = *dictptr;
-        if (dict == NULL) {
-            dict = PyDict_New();
-            if (dict == NULL)
-                return -1;
-            *dictptr = dict;
-        }
-        if (value == NULL) {
-            res = PyDict_DelItem(dict, key);
-        } else {
-            res = PyDict_SetItem(dict, key, value);
-        }
+    dict = ensure_nonmanaged_dict(obj, dictptr);  // helper derives the type from obj; tp is unused in this body
+    if (dict == NULL) {
+        return -1;
     }
 
+    Py_BEGIN_CRITICAL_SECTION(dict);  // critical section is on the dict, not on obj
+    res = _PyDict_SetItem_LockHeld((PyDictObject *)dict, key, value);  // NOTE(review): value == NULL meant delete in the old code; confirm callee handles it
     ASSERT_CONSISTENT(dict);
+    Py_END_CRITICAL_SECTION();
     return res;
 }
 
diff --git a/Objects/object.c b/Objects/object.c
index effbd51991eaa5a48c94324eea61e480b61864e3..8ad0389cbc7626c026bcb0039e4bcee460528080 100644 (file)
@@ -1731,7 +1731,7 @@ _PyObject_GenericSetAttrWithDict(PyObject *obj, PyObject *name,
             goto done;
         }
         else {
-            res = _PyObjectDict_SetItem(tp, dictptr, name, value);
+            res = _PyObjectDict_SetItem(tp, obj, dictptr, name, value);
         }
     }
     else {
@@ -1789,7 +1789,9 @@ PyObject_GenericSetDict(PyObject *obj, PyObject *value, void *context)
                      "not a '%.200s'", Py_TYPE(value)->tp_name);
         return -1;
     }
+    Py_BEGIN_CRITICAL_SECTION(obj);
     Py_XSETREF(*dictptr, Py_NewRef(value));
+    Py_END_CRITICAL_SECTION();
     return 0;
 }