]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-141510: Check argument in PyDict_Contains() (#145083)
authorVictor Stinner <vstinner@python.org>
Sat, 21 Feb 2026 17:36:02 +0000 (18:36 +0100)
committerGitHub <noreply@github.com>
Sat, 21 Feb 2026 17:36:02 +0000 (18:36 +0100)
PyDict_Contains() and PyDict_ContainsString() now fail with
SystemError if the first argument is not a dict, frozendict, dict
subclass or frozendict subclass.

Lib/test/test_capi/test_dict.py
Objects/dictobject.c

index d3cc279cd3f95569cfffc6f290ad9700628f74f0..f69ccbdbd1117d3408d1d10944d5c936b324b25b 100644 (file)
@@ -223,6 +223,7 @@ class CAPITest(unittest.TestCase):
         # CRASHES getitem(NULL, 'a')
 
     def test_dict_contains(self):
+        # Test PyDict_Contains()
         contains = _testlimitedcapi.dict_contains
         dct = {'a': 1, '\U0001f40d': 2}
         self.assertTrue(contains(dct, 'a'))
@@ -235,11 +236,12 @@ class CAPITest(unittest.TestCase):
 
         self.assertRaises(TypeError, contains, {}, [])  # unhashable
         # CRASHES contains({}, NULL)
-        # CRASHES contains(UserDict(), 'a')
-        # CRASHES contains(42, 'a')
+        self.assertRaises(SystemError, contains, UserDict(), 'a')
+        self.assertRaises(SystemError, contains, 42, 'a')
         # CRASHES contains(NULL, 'a')
 
     def test_dict_contains_string(self):
+        # Test PyDict_ContainsString()
         contains_string = _testcapi.dict_containsstring
         dct = {'a': 1, '\U0001f40d': 2}
         self.assertTrue(contains_string(dct, b'a'))
@@ -251,6 +253,8 @@ class CAPITest(unittest.TestCase):
         self.assertTrue(contains_string(dct2, b'a'))
         self.assertFalse(contains_string(dct2, b'b'))
 
+        self.assertRaises(SystemError, contains_string, UserDict(), 'a')
+        self.assertRaises(SystemError, contains_string, 42, 'a')
         # CRASHES contains({}, NULL)
         # CRASHES contains(NULL, b'a')
 
index 276e1df21a80d8ec6b22ef2f1ca1bb3d411a195d..0a8ba74c2287c144aa85139aa95a057f72a74702 100644 (file)
@@ -140,6 +140,7 @@ static PyObject* frozendict_new(PyTypeObject *type, PyObject *args,
                                 PyObject *kwds);
 static PyObject* dict_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
 static int dict_merge(PyObject *a, PyObject *b, int override);
+static int dict_contains(PyObject *op, PyObject *key);
 static int dict_merge_from_seq2(PyObject *d, PyObject *seq2, int override);
 
 
@@ -4126,7 +4127,7 @@ dict_merge(PyObject *a, PyObject *b, int override)
 
         for (key = PyIter_Next(iter); key; key = PyIter_Next(iter)) {
             if (override != 1) {
-                status = PyDict_Contains(a, key);
+                status = dict_contains(a, key);
                 if (status != 0) {
                     if (status > 0) {
                         if (override == 0) {
@@ -4484,7 +4485,7 @@ static PyObject *
 dict___contains___impl(PyDictObject *self, PyObject *key)
 /*[clinic end generated code: output=1b314e6da7687dae input=fe1cb42ad831e820]*/
 {
-    int contains = PyDict_Contains((PyObject *)self, key);
+    int contains = dict_contains((PyObject *)self, key);
     if (contains < 0) {
         return NULL;
     }
@@ -4984,9 +4985,8 @@ static PyMethodDef mapp_methods[] = {
     {NULL,              NULL}   /* sentinel */
 };
 
-/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */
-int
-PyDict_Contains(PyObject *op, PyObject *key)
+static int
+dict_contains(PyObject *op, PyObject *key)
 {
     Py_hash_t hash = _PyObject_HashFast(key);
     if (hash == -1) {
@@ -4997,6 +4997,18 @@ PyDict_Contains(PyObject *op, PyObject *key)
     return _PyDict_Contains_KnownHash(op, key, hash);
 }
 
+/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */
+int
+PyDict_Contains(PyObject *op, PyObject *key)
+{
+    if (!PyAnyDict_Check(op)) {
+        PyErr_BadInternalCall();
+        return -1;
+    }
+
+    return dict_contains(op, key);
+}
+
 int
 PyDict_ContainsString(PyObject *op, const char *key)
 {
@@ -5013,7 +5025,7 @@ PyDict_ContainsString(PyObject *op, const char *key)
 int
 _PyDict_Contains_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash)
 {
-    PyDictObject *mp = (PyDictObject *)op;
+    PyDictObject *mp = _PyAnyDict_CAST(op);
     PyObject *value;
     Py_ssize_t ix;
 
@@ -5042,7 +5054,7 @@ static PySequenceMethods dict_as_sequence = {
     0,                          /* sq_slice */
     0,                          /* sq_ass_item */
     0,                          /* sq_ass_slice */
-    PyDict_Contains,            /* sq_contains */
+    dict_contains,              /* sq_contains */
     0,                          /* sq_inplace_concat */
     0,                          /* sq_inplace_repeat */
 };
@@ -6292,7 +6304,7 @@ dictkeys_contains(PyObject *self, PyObject *obj)
     _PyDictViewObject *dv = (_PyDictViewObject *)self;
     if (dv->dv_dict == NULL)
         return 0;
-    return PyDict_Contains((PyObject *)dv->dv_dict, obj);
+    return dict_contains((PyObject *)dv->dv_dict, obj);
 }
 
 static PySequenceMethods dictkeys_as_sequence = {