]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-112066: Add `PyDict_SetDefaultRef` function. (#112123)
authorSam Gross <colesbury@gmail.com>
Tue, 6 Feb 2024 16:36:23 +0000 (11:36 -0500)
committerGitHub <noreply@github.com>
Tue, 6 Feb 2024 16:36:23 +0000 (11:36 -0500)
The `PyDict_SetDefaultRef` function is similar to `PyDict_SetDefault`,
but returns a strong reference through the optional `**result` pointer
instead of a borrowed reference.

Co-authored-by: Petr Viktorin <encukou@gmail.com>
Doc/c-api/dict.rst
Doc/whatsnew/3.13.rst
Include/cpython/dictobject.h
Lib/test/test_capi/test_dict.py
Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst [new file with mode: 0644]
Modules/_testcapi/dict.c
Objects/dictobject.c

index 8471c98d04487218b5646968145b96dd1a184ee2..03f3d28187bfe9aafb5a9b016a8a643313e67c7b 100644 (file)
@@ -174,6 +174,26 @@ Dictionary Objects
    .. versionadded:: 3.4
 
 
+.. c:function:: int PyDict_SetDefaultRef(PyObject *p, PyObject *key, PyObject *default_value, PyObject **result)
+
+   Inserts *default_value* into the dictionary *p* with a key of *key* if the
+   key is not already present in the dictionary. If *result* is not ``NULL``,
+   then *\*result* is set to a :term:`strong reference` to either
+   *default_value*, if the key was not present, or the existing value, if *key*
+   was already present in the dictionary.
+   Returns ``1`` if the key was present and *default_value* was not inserted,
+   or ``0`` if the key was not present and *default_value* was inserted.
+   On failure, returns ``-1``, sets an exception, and sets ``*result``
+   to ``NULL``.
+
+   For clarity: if you have a strong reference to *default_value* before
+   calling this function, then after it returns, you hold a strong reference
+   to both *default_value* and *\*result* (if it's not ``NULL``).
+   These may refer to the same object: in that case you hold two separate
+   references to it.
+   .. versionadded:: 3.13
+
+
 .. c:function:: int PyDict_Pop(PyObject *p, PyObject *key, PyObject **result)
 
    Remove *key* from dictionary *p* and optionally return the removed value.
index 372757759b986f938d8dc67131d17bf2791732cf..e034d34c5fb5abc8ae6e71473f0a5071fd2c8be4 100644 (file)
@@ -1440,6 +1440,12 @@ New Features
   not needed.
   (Contributed by Victor Stinner in :gh:`106004`.)
 
+* Added :c:func:`PyDict_SetDefaultRef`, which is similar to
+  :c:func:`PyDict_SetDefault` but returns a :term:`strong reference` instead of
+  a :term:`borrowed reference`. This function returns ``-1`` on error, ``0`` on
+  insertion, and ``1`` if the key was already present in the dictionary.
+  (Contributed by Sam Gross in :gh:`112066`.)
+
 * Add :c:func:`PyDict_ContainsString` function: same as
   :c:func:`PyDict_Contains`, but *key* is specified as a :c:expr:`const char*`
   UTF-8 encoded bytes string, rather than a :c:expr:`PyObject*`.
index 1720fe6f01ea37d11a09d28182bbfe521418bb07..35b6a822a0dfffd8c3bff01f3d2b6a6f38a3886b 100644 (file)
@@ -41,6 +41,16 @@ PyAPI_FUNC(PyObject *) _PyDict_GetItemStringWithError(PyObject *, const char *);
 PyAPI_FUNC(PyObject *) PyDict_SetDefault(
     PyObject *mp, PyObject *key, PyObject *defaultobj);
 
+// Inserts `key` with a value `default_value`, if `key` is not already present
+// in the dictionary.  If `result` is not NULL, then the value associated
+// with `key` is returned in `*result` (either the existing value, or the now
+// inserted `default_value`).
+// Returns:
+//   -1 on error
+//    0 if `key` was not present and `default_value` was inserted
+//    1 if `key` was present and `default_value` was not inserted
+PyAPI_FUNC(int) PyDict_SetDefaultRef(PyObject *mp, PyObject *key, PyObject *default_value, PyObject **result);
+
 /* Get the number of items of a dictionary. */
 static inline Py_ssize_t PyDict_GET_SIZE(PyObject *op) {
     PyDictObject *mp;
index 57a7238588eae09b64628638b8f79aad34efdc0c..cca6145bc90c047ebefb370da40a8aa76aec0828 100644 (file)
@@ -339,6 +339,28 @@ class CAPITest(unittest.TestCase):
         # CRASHES setdefault({}, 'a', NULL)
         # CRASHES setdefault(NULL, 'a', 5)
 
+    def test_dict_setdefaultref(self):
+        setdefault = _testcapi.dict_setdefaultref
+        dct = {}
+        self.assertEqual(setdefault(dct, 'a', 5), 5)
+        self.assertEqual(dct, {'a': 5})
+        self.assertEqual(setdefault(dct, 'a', 8), 5)
+        self.assertEqual(dct, {'a': 5})
+
+        dct2 = DictSubclass()
+        self.assertEqual(setdefault(dct2, 'a', 5), 5)
+        self.assertEqual(dct2, {'a': 5})
+        self.assertEqual(setdefault(dct2, 'a', 8), 5)
+        self.assertEqual(dct2, {'a': 5})
+
+        self.assertRaises(TypeError, setdefault, {}, [], 5)  # unhashable
+        self.assertRaises(SystemError, setdefault, UserDict(), 'a', 5)
+        self.assertRaises(SystemError, setdefault, [1], 0, 5)
+        self.assertRaises(SystemError, setdefault, 42, 'a', 5)
+        # CRASHES setdefault({}, NULL, 5)
+        # CRASHES setdefault({}, 'a', NULL)
+        # CRASHES setdefault(NULL, 'a', 5)
+
     def test_mapping_keys_valuesitems(self):
         class BadMapping(dict):
             def keys(self):
diff --git a/Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst b/Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst
new file mode 100644 (file)
index 0000000..ae2b8b2
--- /dev/null
@@ -0,0 +1,5 @@
+Add :c:func:`PyDict_SetDefaultRef`: insert a key and value into a dictionary
+if the key is not already present. This is similar to
+:meth:`dict.setdefault`, but returns an integer value indicating if the key
+was already present. It is also similar to :c:func:`PyDict_SetDefault`, but
+returns a strong reference instead of a borrowed reference.
index 42e056b7d07a31fccb709eaeb24a30f0e9d659b3..fe03c24f75e196b5800731812a47103bea2b5657 100644 (file)
@@ -225,6 +225,31 @@ dict_setdefault(PyObject *self, PyObject *args)
     return PyDict_SetDefault(mapping, key, defaultobj);
 }
 
+static PyObject *
+dict_setdefaultref(PyObject *self, PyObject *args)
+{
+    PyObject *obj, *key, *default_value, *result = UNINITIALIZED_PTR;
+    if (!PyArg_ParseTuple(args, "OOO", &obj, &key, &default_value)) {
+        return NULL;
+    }
+    NULLABLE(obj);
+    NULLABLE(key);
+    NULLABLE(default_value);
+    switch (PyDict_SetDefaultRef(obj, key, default_value, &result)) {
+        case -1:
+            assert(result == NULL);
+            return NULL;
+        case 0:
+            assert(result == default_value);
+            return result;
+        case 1:
+            return result;
+        default:
+            Py_FatalError("PyDict_SetDefaultRef() returned invalid code");
+            Py_UNREACHABLE();
+    }
+}
+
 static PyObject *
 dict_delitem(PyObject *self, PyObject *args)
 {
@@ -433,6 +458,7 @@ static PyMethodDef test_methods[] = {
     {"dict_delitem", dict_delitem, METH_VARARGS},
     {"dict_delitemstring", dict_delitemstring, METH_VARARGS},
     {"dict_setdefault", dict_setdefault, METH_VARARGS},
+    {"dict_setdefaultref", dict_setdefaultref, METH_VARARGS},
     {"dict_keys", dict_keys, METH_O},
     {"dict_values", dict_values, METH_O},
     {"dict_items", dict_items, METH_O},
index 4bb818b90a4a72d7df968a8649816ae68d468fc5..11b388d9f4adb0520b32b46d7fc870b2fdc77233 100644 (file)
@@ -3355,8 +3355,9 @@ dict_get_impl(PyDictObject *self, PyObject *key, PyObject *default_value)
     return Py_NewRef(val);
 }
 
-PyObject *
-PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
+static int
+dict_setdefault_ref(PyObject *d, PyObject *key, PyObject *default_value,
+                    PyObject **result, int incref_result)
 {
     PyDictObject *mp = (PyDictObject *)d;
     PyObject *value;
@@ -3365,41 +3366,64 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
 
     if (!PyDict_Check(d)) {
         PyErr_BadInternalCall();
-        return NULL;
+        if (result) {
+            *result = NULL;
+        }
+        return -1;
     }
 
     if (!PyUnicode_CheckExact(key) || (hash = unicode_get_hash(key)) == -1) {
         hash = PyObject_Hash(key);
-        if (hash == -1)
-            return NULL;
+        if (hash == -1) {
+            if (result) {
+                *result = NULL;
+            }
+            return -1;
+        }
     }
 
     if (mp->ma_keys == Py_EMPTY_KEYS) {
         if (insert_to_emptydict(interp, mp, Py_NewRef(key), hash,
-                                Py_NewRef(defaultobj)) < 0) {
-            return NULL;
+                                Py_NewRef(default_value)) < 0) {
+            if (result) {
+                *result = NULL;
+            }
+            return -1;
+        }
+        if (result) {
+            *result = incref_result ? Py_NewRef(default_value) : default_value;
         }
-        return defaultobj;
+        return 0;
     }
 
     if (!PyUnicode_CheckExact(key) && DK_IS_UNICODE(mp->ma_keys)) {
         if (insertion_resize(interp, mp, 0) < 0) {
-            return NULL;
+            if (result) {
+                *result = NULL;
+            }
+            return -1;
         }
     }
 
     Py_ssize_t ix = _Py_dict_lookup(mp, key, hash, &value);
-    if (ix == DKIX_ERROR)
-        return NULL;
+    if (ix == DKIX_ERROR) {
+        if (result) {
+            *result = NULL;
+        }
+        return -1;
+    }
 
     if (ix == DKIX_EMPTY) {
         uint64_t new_version = _PyDict_NotifyEvent(
-                interp, PyDict_EVENT_ADDED, mp, key, defaultobj);
+                interp, PyDict_EVENT_ADDED, mp, key, default_value);
         mp->ma_keys->dk_version = 0;
-        value = defaultobj;
+        value = default_value;
         if (mp->ma_keys->dk_usable <= 0) {
             if (insertion_resize(interp, mp, 1) < 0) {
-                return NULL;
+                if (result) {
+                    *result = NULL;
+                }
+                return -1;
             }
         }
         Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
@@ -3431,11 +3455,16 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
         mp->ma_keys->dk_usable--;
         mp->ma_keys->dk_nentries++;
         assert(mp->ma_keys->dk_usable >= 0);
+        ASSERT_CONSISTENT(mp);
+        if (result) {
+            *result = incref_result ? Py_NewRef(value) : value;
+        }
+        return 0;
     }
     else if (value == NULL) {
         uint64_t new_version = _PyDict_NotifyEvent(
-                interp, PyDict_EVENT_ADDED, mp, key, defaultobj);
-        value = defaultobj;
+                interp, PyDict_EVENT_ADDED, mp, key, default_value);
+        value = default_value;
         assert(_PyDict_HasSplitTable(mp));
         assert(mp->ma_values->values[ix] == NULL);
         MAINTAIN_TRACKING(mp, key, value);
@@ -3443,10 +3472,33 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
         _PyDictValues_AddToInsertionOrder(mp->ma_values, ix);
         mp->ma_used++;
         mp->ma_version_tag = new_version;
+        ASSERT_CONSISTENT(mp);
+        if (result) {
+            *result = incref_result ? Py_NewRef(value) : value;
+        }
+        return 0;
     }
 
     ASSERT_CONSISTENT(mp);
-    return value;
+    if (result) {
+        *result = incref_result ? Py_NewRef(value) : value;
+    }
+    return 1;
+}
+
+int
+PyDict_SetDefaultRef(PyObject *d, PyObject *key, PyObject *default_value,
+                     PyObject **result)
+{
+    return dict_setdefault_ref(d, key, default_value, result, 1);
+}
+
+PyObject *
+PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
+{
+    PyObject *result;
+    dict_setdefault_ref(d, key, defaultobj, &result, 0);
+    return result;
 }
 
 /*[clinic input]
@@ -3467,9 +3519,8 @@ dict_setdefault_impl(PyDictObject *self, PyObject *key,
 /*[clinic end generated code: output=f8c1101ebf69e220 input=0f063756e815fd9d]*/
 {
     PyObject *val;
-
-    val = PyDict_SetDefault((PyObject *)self, key, default_value);
-    return Py_XNewRef(val);
+    PyDict_SetDefaultRef((PyObject *)self, key, default_value, &val);
+    return val;
 }