]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-112075: _Py_dict_lookup needs to lock shared keys (#117528)
authorDino Viehland <dinoviehland@meta.com>
Thu, 25 Apr 2024 22:34:05 +0000 (15:34 -0700)
committerGitHub <noreply@github.com>
Thu, 25 Apr 2024 22:34:05 +0000 (15:34 -0700)
Lock shared keys in `Py_dict_lookup` and use thread-safe lookup in `insertdict`

Co-authored-by: Sam Gross <colesbury@gmail.com>
Objects/dictobject.c

index afcf535f8c0a78aaa8f5198591cc4e996a50e60c..43cb2350b7fc71b3814812709147afec63662c36 100644 (file)
@@ -162,6 +162,16 @@ ASSERT_DICT_LOCKED(PyObject *op)
     assert(_Py_IsOwnedByCurrentThread((PyObject *)mp) || IS_DICT_SHARED(mp));
 #define LOAD_KEYS_NENTRIES(d)
 
+#define LOCK_KEYS_IF_SPLIT(keys, kind) \
+        if (kind == DICT_KEYS_SPLIT) { \
+            LOCK_KEYS(dk);             \
+        }
+
+#define UNLOCK_KEYS_IF_SPLIT(keys, kind) \
+        if (kind == DICT_KEYS_SPLIT) {   \
+            UNLOCK_KEYS(dk);             \
+        }
+
 static inline Py_ssize_t
 load_keys_nentries(PyDictObject *mp)
 {
@@ -195,6 +205,9 @@ set_values(PyDictObject *mp, PyDictValues *values)
 #define DECREF_KEYS(dk)  _Py_atomic_add_ssize(&dk->dk_refcnt, -1)
 #define LOAD_KEYS_NENTIRES(keys) _Py_atomic_load_ssize_relaxed(&keys->dk_nentries)
 
+#define INCREF_KEYS_FT(dk) dictkeys_incref(dk)
+#define DECREF_KEYS_FT(dk, shared) dictkeys_decref(_PyInterpreterState_GET(), dk, shared)
+
 static inline void split_keys_entry_added(PyDictKeysObject *keys)
 {
     ASSERT_KEYS_LOCKED(keys);
@@ -216,6 +229,10 @@ static inline void split_keys_entry_added(PyDictKeysObject *keys)
 #define INCREF_KEYS(dk)  dk->dk_refcnt++
 #define DECREF_KEYS(dk)  dk->dk_refcnt--
 #define LOAD_KEYS_NENTIRES(keys) keys->dk_nentries
+#define INCREF_KEYS_FT(dk)
+#define DECREF_KEYS_FT(dk, shared)
+#define LOCK_KEYS_IF_SPLIT(keys, kind)
+#define UNLOCK_KEYS_IF_SPLIT(keys, kind)
 #define IS_DICT_SHARED(mp) (false)
 #define SET_DICT_SHARED(mp)
 #define LOAD_INDEX(keys, size, idx) ((const int##size##_t*)(keys->dk_indices))[idx]
@@ -1171,6 +1188,14 @@ _PyDictKeys_StringLookup(PyDictKeysObject* dk, PyObject *key)
     return unicodekeys_lookup_unicode(dk, key, hash);
 }
 
+#ifdef Py_GIL_DISABLED
+
+static Py_ssize_t
+unicodekeys_lookup_unicode_threadsafe(PyDictKeysObject* dk, PyObject *key,
+                                      Py_hash_t hash);
+
+#endif
+
 /*
 The basic lookup function used by all operations.
 This is based on Algorithm D from Knuth Vol. 3, Sec. 6.4.
@@ -1200,10 +1225,33 @@ start:
 
     if (kind != DICT_KEYS_GENERAL) {
         if (PyUnicode_CheckExact(key)) {
+#ifdef Py_GIL_DISABLED
+            if (kind == DICT_KEYS_SPLIT) {
+                // A split dictionaries keys can be mutated by other
+                // dictionaries but if we have a unicode key we can avoid
+                // locking the shared keys.
+                ix = unicodekeys_lookup_unicode_threadsafe(dk, key, hash);
+                if (ix == DKIX_KEY_CHANGED) {
+                    LOCK_KEYS(dk);
+                    ix = unicodekeys_lookup_unicode(dk, key, hash);
+                    UNLOCK_KEYS(dk);
+                }
+            }
+            else {
+                ix = unicodekeys_lookup_unicode(dk, key, hash);
+            }
+#else
             ix = unicodekeys_lookup_unicode(dk, key, hash);
+#endif
         }
         else {
+            INCREF_KEYS_FT(dk);
+            LOCK_KEYS_IF_SPLIT(dk, kind);
+
             ix = unicodekeys_lookup_generic(mp, dk, key, hash);
+
+            UNLOCK_KEYS_IF_SPLIT(dk, kind);
+            DECREF_KEYS_FT(dk, IS_DICT_SHARED(mp));
             if (ix == DKIX_KEY_CHANGED) {
                 goto start;
             }
@@ -1607,31 +1655,6 @@ insertion_resize(PyInterpreterState *interp, PyDictObject *mp, int unicode)
     return dictresize(interp, mp, calculate_log2_keysize(GROWTH_RATE(mp)), unicode);
 }
 
-static Py_ssize_t
-insert_into_splitdictkeys(PyDictKeysObject *keys, PyObject *name, Py_hash_t hash)
-{
-    assert(PyUnicode_CheckExact(name));
-    ASSERT_KEYS_LOCKED(keys);
-
-    Py_ssize_t ix = unicodekeys_lookup_unicode(keys, name, hash);
-    if (ix == DKIX_EMPTY) {
-        if (keys->dk_usable <= 0) {
-            return DKIX_EMPTY;
-        }
-        /* Insert into new slot. */
-        keys->dk_version = 0;
-        Py_ssize_t hashpos = find_empty_slot(keys, hash);
-        ix = keys->dk_nentries;
-        PyDictUnicodeEntry *ep = &DK_UNICODE_ENTRIES(keys)[ix];
-        dictkeys_set_index(keys, hashpos, ix);
-        assert(ep->me_key == NULL);
-        ep->me_key = Py_NewRef(name);
-        split_keys_entry_added(keys);
-    }
-    assert (ix < SHARED_KEYS_MAX_SIZE);
-    return ix;
-}
-
 static inline int
 insert_combined_dict(PyInterpreterState *interp, PyDictObject *mp,
                      Py_hash_t hash, PyObject *key, PyObject *value)
@@ -1665,39 +1688,57 @@ insert_combined_dict(PyInterpreterState *interp, PyDictObject *mp,
     return 0;
 }
 
-static int
-insert_split_dict(PyInterpreterState *interp, PyDictObject *mp,
-                  Py_hash_t hash, PyObject *key, PyObject *value)
+static Py_ssize_t
+insert_split_key(PyDictKeysObject *keys, PyObject *key, Py_hash_t hash)
 {
-    PyDictKeysObject *keys = mp->ma_keys;
-    LOCK_KEYS(keys);
-    if (keys->dk_usable <= 0) {
-        /* Need to resize. */
-        UNLOCK_KEYS(keys);
-        int ins = insertion_resize(interp, mp, 1);
-        if (ins < 0) {
-            return -1;
-        }
-        assert(!_PyDict_HasSplitTable(mp));
-        return insert_combined_dict(interp, mp, hash, key, value);
-    }
-
-    Py_ssize_t hashpos = find_empty_slot(keys, hash);
-    dictkeys_set_index(keys, hashpos, keys->dk_nentries);
+    assert(PyUnicode_CheckExact(key));
+    Py_ssize_t ix;
 
-    PyDictUnicodeEntry *ep;
-    ep = &DK_UNICODE_ENTRIES(keys)[keys->dk_nentries];
-    STORE_SHARED_KEY(ep->me_key, key);
 
-    Py_ssize_t index = keys->dk_nentries;
-    _PyDictValues_AddToInsertionOrder(mp->ma_values, index);
-    assert (mp->ma_values->values[index] == NULL);
-    STORE_SPLIT_VALUE(mp, index, value);
+#ifdef Py_GIL_DISABLED
+    ix = unicodekeys_lookup_unicode_threadsafe(keys, key, hash);
+    if (ix >= 0) {
+        return ix;
+    }
+#endif
 
-    split_keys_entry_added(keys);
-    assert(keys->dk_usable >= 0);
+    LOCK_KEYS(keys);
+    ix = unicodekeys_lookup_unicode(keys, key, hash);
+    if (ix == DKIX_EMPTY && keys->dk_usable > 0) {
+        // Insert into new slot
+        keys->dk_version = 0;
+        Py_ssize_t hashpos = find_empty_slot(keys, hash);
+        ix = keys->dk_nentries;
+        dictkeys_set_index(keys, hashpos, ix);
+        PyDictUnicodeEntry *ep = &DK_UNICODE_ENTRIES(keys)[ix];
+        STORE_SHARED_KEY(ep->me_key, Py_NewRef(key));
+        split_keys_entry_added(keys);
+    }
+    assert (ix < SHARED_KEYS_MAX_SIZE);
     UNLOCK_KEYS(keys);
-    return 0;
+    return ix;
+}
+
+static void
+insert_split_value(PyInterpreterState *interp, PyDictObject *mp, PyObject *key, PyObject *value, Py_ssize_t ix)
+{
+    assert(PyUnicode_CheckExact(key));
+    MAINTAIN_TRACKING(mp, key, value);
+    PyObject *old_value = mp->ma_values->values[ix];
+    if (old_value == NULL) {
+        uint64_t new_version = _PyDict_NotifyEvent(interp, PyDict_EVENT_ADDED, mp, key, value);
+        STORE_SPLIT_VALUE(mp, ix, Py_NewRef(value));
+        _PyDictValues_AddToInsertionOrder(mp->ma_values, ix);
+        mp->ma_used++;
+        mp->ma_version_tag = new_version;
+    }
+    else {
+        uint64_t new_version = _PyDict_NotifyEvent(interp, PyDict_EVENT_MODIFIED, mp, key, value);
+        STORE_SPLIT_VALUE(mp, ix, Py_NewRef(value));
+        mp->ma_version_tag = new_version;
+        Py_DECREF(old_value);
+    }
+    ASSERT_CONSISTENT(mp);
 }
 
 /*
@@ -1720,6 +1761,21 @@ insertdict(PyInterpreterState *interp, PyDictObject *mp,
         assert(mp->ma_keys->dk_kind == DICT_KEYS_GENERAL);
     }
 
+    if (_PyDict_HasSplitTable(mp)) {
+        Py_ssize_t ix = insert_split_key(mp->ma_keys, key, hash);
+        if (ix != DKIX_EMPTY) {
+            insert_split_value(interp, mp, key, value, ix);
+            Py_DECREF(key);
+            Py_DECREF(value);
+            return 0;
+        }
+
+        /* No space in shared keys. Resize and continue below. */
+        if (insertion_resize(interp, mp, 1) < 0) {
+            goto Fail;
+        }
+    }
+
     Py_ssize_t ix = _Py_dict_lookup(mp, key, hash, &old_value);
     if (ix == DKIX_ERROR)
         goto Fail;
@@ -1727,24 +1783,17 @@ insertdict(PyInterpreterState *interp, PyDictObject *mp,
     MAINTAIN_TRACKING(mp, key, value);
 
     if (ix == DKIX_EMPTY) {
+        assert(!_PyDict_HasSplitTable(mp));
         uint64_t new_version = _PyDict_NotifyEvent(
                 interp, PyDict_EVENT_ADDED, mp, key, value);
         /* Insert into new slot. */
         mp->ma_keys->dk_version = 0;
         assert(old_value == NULL);
-
-        if (!_PyDict_HasSplitTable(mp)) {
-            if (insert_combined_dict(interp, mp, hash, key, value) < 0) {
-                goto Fail;
-            }
-        }
-        else {
-            if (insert_split_dict(interp, mp, hash, key, value) < 0)
-                goto Fail;
+        if (insert_combined_dict(interp, mp, hash, key, value) < 0) {
+            goto Fail;
         }
-
-        mp->ma_used++;
         mp->ma_version_tag = new_version;
+        mp->ma_used++;
         ASSERT_CONSISTENT(mp);
         return 0;
     }
@@ -1752,23 +1801,15 @@ insertdict(PyInterpreterState *interp, PyDictObject *mp,
     if (old_value != value) {
         uint64_t new_version = _PyDict_NotifyEvent(
                 interp, PyDict_EVENT_MODIFIED, mp, key, value);
-        if (_PyDict_HasSplitTable(mp)) {
-            STORE_SPLIT_VALUE(mp, ix, value);
-            if (old_value == NULL) {
-                _PyDictValues_AddToInsertionOrder(mp->ma_values, ix);
-                mp->ma_used++;
-            }
+        assert(old_value != NULL);
+        assert(!_PyDict_HasSplitTable(mp));
+        if (DK_IS_UNICODE(mp->ma_keys)) {
+            PyDictUnicodeEntry *ep = &DK_UNICODE_ENTRIES(mp->ma_keys)[ix];
+            STORE_VALUE(ep, value);
         }
         else {
-            assert(old_value != NULL);
-            if (DK_IS_UNICODE(mp->ma_keys)) {
-                PyDictUnicodeEntry *ep = &DK_UNICODE_ENTRIES(mp->ma_keys)[ix];
-                STORE_VALUE(ep, value);
-            }
-            else {
-                PyDictKeyEntry *ep = &DK_ENTRIES(mp->ma_keys)[ix];
-                STORE_VALUE(ep, value);
-            }
+            PyDictKeyEntry *ep = &DK_ENTRIES(mp->ma_keys)[ix];
+            STORE_VALUE(ep, value);
         }
         mp->ma_version_tag = new_version;
     }
@@ -4177,6 +4218,29 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu
         }
     }
 
+    if (_PyDict_HasSplitTable(mp)) {
+        Py_ssize_t ix = insert_split_key(mp->ma_keys, key, hash);
+        if (ix != DKIX_EMPTY) {
+            PyObject *value = mp->ma_values->values[ix];
+            int already_present = value != NULL;
+            if (!already_present) {
+                insert_split_value(interp, mp, key, default_value, ix);
+                value = default_value;
+            }
+            if (result) {
+                *result = incref_result ? Py_NewRef(value) : value;
+            }
+            return already_present;
+        }
+
+        /* No space in shared keys. Resize and continue below. */
+        if (insertion_resize(interp, mp, 1) < 0) {
+            goto error;
+        }
+    }
+
+    assert(!_PyDict_HasSplitTable(mp));
+
     Py_ssize_t ix = _Py_dict_lookup(mp, key, hash, &value);
     if (ix == DKIX_ERROR) {
         if (result) {
@@ -4186,29 +4250,17 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu
     }
 
     if (ix == DKIX_EMPTY) {
+        assert(!_PyDict_HasSplitTable(mp));
         uint64_t new_version = _PyDict_NotifyEvent(
-                interp, PyDict_EVENT_ADDED, mp, key, default_value);
+            interp, PyDict_EVENT_ADDED, mp, key, default_value);
         mp->ma_keys->dk_version = 0;
         value = default_value;
 
-        if (!_PyDict_HasSplitTable(mp)) {
-            if (insert_combined_dict(interp, mp, hash, Py_NewRef(key), Py_NewRef(value)) < 0) {
-                Py_DECREF(key);
-                Py_DECREF(value);
-                if (result) {
-                    *result = NULL;
-                }
-                return -1;
-            }
-        }
-        else {
-            if (insert_split_dict(interp, mp, hash, Py_NewRef(key), Py_NewRef(value)) < 0) {
-                Py_DECREF(key);
-                Py_DECREF(value);
-                if (result) {
-                    *result = NULL;
-                }
-                return -1;
+        if (insert_combined_dict(interp, mp, hash, Py_NewRef(key), Py_NewRef(value)) < 0) {
+            Py_DECREF(key);
+            Py_DECREF(value);
+            if (result) {
+                *result = NULL;
             }
         }
 
@@ -4222,29 +4274,19 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu
         }
         return 0;
     }
-    else if (value == NULL) {
-        uint64_t new_version = _PyDict_NotifyEvent(
-                interp, PyDict_EVENT_ADDED, mp, key, default_value);
-        value = default_value;
-        assert(_PyDict_HasSplitTable(mp));
-        assert(mp->ma_values->values[ix] == NULL);
-        MAINTAIN_TRACKING(mp, key, value);
-        STORE_SPLIT_VALUE(mp, ix, Py_NewRef(value));
-        _PyDictValues_AddToInsertionOrder(mp->ma_values, ix);
-        mp->ma_used++;
-        mp->ma_version_tag = new_version;
-        ASSERT_CONSISTENT(mp);
-        if (result) {
-            *result = incref_result ? Py_NewRef(value) : value;
-        }
-        return 0;
-    }
 
+    assert(value != NULL);
     ASSERT_CONSISTENT(mp);
     if (result) {
         *result = incref_result ? Py_NewRef(value) : value;
     }
     return 1;
+
+error:
+    if (result) {
+        *result = NULL;
+    }
+    return -1;
 }
 
 int
@@ -6699,18 +6741,7 @@ store_instance_attr_lock_held(PyObject *obj, PyDictValues *values,
             assert(hash != -1);
         }
 
-#ifdef Py_GIL_DISABLED
-        // Try a thread-safe lookup to see if the index is already allocated
-        ix = unicodekeys_lookup_unicode_threadsafe(keys, name, hash);
-        if (ix == DKIX_EMPTY || ix == DKIX_KEY_CHANGED) {
-            // Lock keys and do insert
-            LOCK_KEYS(keys);
-            ix = insert_into_splitdictkeys(keys, name, hash);
-            UNLOCK_KEYS(keys);
-        }
-#else
-        ix = insert_into_splitdictkeys(keys, name, hash);
-#endif
+        ix = insert_split_key(keys, name, hash);
 
 #ifdef Py_STATS
         if (ix == DKIX_EMPTY) {