]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-116664: Ensure thread-safe dict access in _warnings (#116768)
authorErlend E. Aasland <erlend@python.org>
Mon, 18 Mar 2024 09:37:48 +0000 (10:37 +0100)
committerGitHub <noreply@github.com>
Mon, 18 Mar 2024 09:37:48 +0000 (09:37 +0000)
Replace _PyDict_GetItemWithError() with PyDict_GetItemRef().

Python/_warnings.c

index d4765032824e5666ddf11203d2c802ca2879f9da..dfa82c569e13838bb7d2f03c28cacb78a6400ae9 100644 (file)
@@ -1,5 +1,4 @@
 #include "Python.h"
-#include "pycore_dict.h"          // _PyDict_GetItemWithError()
 #include "pycore_interp.h"        // PyInterpreterState.warnings
 #include "pycore_long.h"          // _PyLong_GetZero()
 #include "pycore_pyerrors.h"      // _PyErr_Occurred()
@@ -8,6 +7,8 @@
 #include "pycore_sysmodule.h"     // _PySys_GetAttr()
 #include "pycore_traceback.h"     // _Py_DisplaySourceLine()
 
+#include <stdbool.h>
+
 #include "clinic/_warnings.c.h"
 
 #define MODULE_NAME "_warnings"
@@ -397,7 +398,7 @@ static int
 already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key,
                int should_set)
 {
-    PyObject *version_obj, *already_warned;
+    PyObject *already_warned;
 
     if (key == NULL)
         return -1;
@@ -406,14 +407,17 @@ already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key,
     if (st == NULL) {
         return -1;
     }
-    version_obj = _PyDict_GetItemWithError(registry, &_Py_ID(version));
-    if (version_obj == NULL
+    PyObject *version_obj;
+    if (PyDict_GetItemRef(registry, &_Py_ID(version), &version_obj) < 0) {
+        return -1;
+    }
+    bool should_update_version = (
+        version_obj == NULL
         || !PyLong_CheckExact(version_obj)
-        || PyLong_AsLong(version_obj) != st->filters_version)
-    {
-        if (PyErr_Occurred()) {
-            return -1;
-        }
+        || PyLong_AsLong(version_obj) != st->filters_version
+    );
+    Py_XDECREF(version_obj);
+    if (should_update_version) {
         PyDict_Clear(registry);
         version_obj = PyLong_FromLong(st->filters_version);
         if (version_obj == NULL)
@@ -911,13 +915,12 @@ setup_context(Py_ssize_t stack_level,
     /* Setup registry. */
     assert(globals != NULL);
     assert(PyDict_Check(globals));
-    *registry = _PyDict_GetItemWithError(globals, &_Py_ID(__warningregistry__));
+    int rc = PyDict_GetItemRef(globals, &_Py_ID(__warningregistry__),
+                               registry);
+    if (rc < 0) {
+        goto handle_error;
+    }
     if (*registry == NULL) {
-        int rc;
-
-        if (_PyErr_Occurred(tstate)) {
-            goto handle_error;
-        }
         *registry = PyDict_New();
         if (*registry == NULL)
             goto handle_error;
@@ -926,21 +929,21 @@ setup_context(Py_ssize_t stack_level,
          if (rc < 0)
             goto handle_error;
     }
-    else
-        Py_INCREF(*registry);
 
     /* Setup module. */
-    *module = _PyDict_GetItemWithError(globals, &_Py_ID(__name__));
-    if (*module == Py_None || (*module != NULL && PyUnicode_Check(*module))) {
-        Py_INCREF(*module);
-    }
-    else if (_PyErr_Occurred(tstate)) {
+    rc = PyDict_GetItemRef(globals, &_Py_ID(__name__), module);
+    if (rc < 0) {
         goto handle_error;
     }
-    else {
-        *module = PyUnicode_FromString("<string>");
-        if (*module == NULL)
-            goto handle_error;
+    if (rc > 0) {
+        if (Py_IsNone(*module) || PyUnicode_Check(*module)) {
+            return 1;
+        }
+        Py_DECREF(*module);
+    }
+    *module = PyUnicode_FromString("<string>");
+    if (*module == NULL) {
+        goto handle_error;
     }
 
     return 1;
@@ -1063,12 +1066,12 @@ get_source_line(PyInterpreterState *interp, PyObject *module_globals, int lineno
         return NULL;
     }
 
-    module_name = _PyDict_GetItemWithError(module_globals, &_Py_ID(__name__));
-    if (!module_name) {
+    int rc = PyDict_GetItemRef(module_globals, &_Py_ID(__name__),
+                               &module_name);
+    if (rc < 0 || rc == 0) {
         Py_DECREF(loader);
         return NULL;
     }
-    Py_INCREF(module_name);
 
     /* Make sure the loader implements the optional get_source() method. */
     (void)PyObject_GetOptionalAttr(loader, &_Py_ID(get_source), &get_source);