]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-141510: Fix frozendict.fromkeys() for subclasses (#144952)
authorVictor Stinner <vstinner@python.org>
Wed, 18 Feb 2026 15:56:09 +0000 (16:56 +0100)
committerGitHub <noreply@github.com>
Wed, 18 Feb 2026 15:56:09 +0000 (15:56 +0000)
Copy the frozendict if needed.

Lib/test/test_dict.py
Objects/dictobject.c

index 21f8bb11071c90cb3f10d8536b346e9e7b5f7f11..1a8ae1cd42356e834b14583c17deb743358c6bf0 100644 (file)
@@ -1787,6 +1787,34 @@ class FrozenDictTests(unittest.TestCase):
         with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
             hash(fd)
 
+    def test_fromkeys(self):
+        self.assertEqual(frozendict.fromkeys('abc'),
+                         frozendict(a=None, b=None, c=None))
+
+        # Subclass which overrides the constructor
+        created = frozendict(x=1)
+        class FrozenDictSubclass(frozendict):
+            def __new__(self):
+                return created
+
+        fd = FrozenDictSubclass.fromkeys("abc")
+        self.assertEqual(fd, frozendict(x=1, a=None, b=None, c=None))
+        self.assertEqual(type(fd), FrozenDictSubclass)
+        self.assertEqual(created, frozendict(x=1))
+
+        fd = FrozenDictSubclass.fromkeys(frozendict(y=2))
+        self.assertEqual(fd, frozendict(x=1, y=None))
+        self.assertEqual(type(fd), FrozenDictSubclass)
+        self.assertEqual(created, frozendict(x=1))
+
+        # Subclass which doesn't override the constructor
+        class FrozenDictSubclass2(frozendict):
+            pass
+
+        fd = FrozenDictSubclass2.fromkeys("abc")
+        self.assertEqual(fd, frozendict(a=None, b=None, c=None))
+        self.assertEqual(type(fd), FrozenDictSubclass2)
+
 
 if __name__ == "__main__":
     unittest.main()
index 68602caf61401ad9b5db11943c7134ff954eb493..8d3c34f87e2afe1108aa94cc0c099cbf90948fcb 100644 (file)
@@ -138,6 +138,7 @@ As a consequence of this, split keys have a maximum size of 16.
 // Forward declarations
 static PyObject* frozendict_new(PyTypeObject *type, PyObject *args,
                                 PyObject *kwds);
+static int dict_merge(PyObject *a, PyObject *b, int override);
 
 
 /*[clinic input]
@@ -294,6 +295,8 @@ can_modify_dict(PyDictObject *mp)
         return PyUnstable_Object_IsUniquelyReferenced(_PyObject_CAST(mp));
     }
     else {
+        // Locking is only required if the dictionary is not
+        // uniquely referenced.
         ASSERT_DICT_LOCKED(mp);
         return 1;
     }
@@ -3238,6 +3241,8 @@ _PyDict_Pop(PyObject *dict, PyObject *key, PyObject *default_value)
 static PyDictObject *
 dict_dict_fromkeys(PyDictObject *mp, PyObject *iterable, PyObject *value)
 {
+    assert(can_modify_dict(mp));
+
     PyObject *oldvalue;
     Py_ssize_t pos = 0;
     PyObject *key;
@@ -3263,6 +3268,8 @@ dict_dict_fromkeys(PyDictObject *mp, PyObject *iterable, PyObject *value)
 static PyDictObject *
 dict_set_fromkeys(PyDictObject *mp, PyObject *iterable, PyObject *value)
 {
+    assert(can_modify_dict(mp));
+
     Py_ssize_t pos = 0;
     PyObject *key;
     Py_hash_t hash;
@@ -3294,9 +3301,31 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
     int status;
 
     d = _PyObject_CallNoArgs(cls);
-    if (d == NULL)
+    if (d == NULL) {
         return NULL;
+    }
 
+    // If cls is a frozendict subclass with overridden constructor,
+    // copy the frozendict.
+    PyTypeObject *cls_type = _PyType_CAST(cls);
+    if (PyFrozenDict_Check(d)
+        && PyObject_IsSubclass(cls, (PyObject*)&PyFrozenDict_Type)
+        && cls_type->tp_new != frozendict_new)
+    {
+        // Subclass-friendly copy
+        PyObject *copy = frozendict_new(cls_type, NULL, NULL);
+        if (copy == NULL) {
+            Py_DECREF(d);
+            return NULL;
+        }
+        if (dict_merge(copy, d, 1) < 0) {
+            Py_DECREF(d);
+            Py_DECREF(copy);
+            return NULL;
+        }
+        Py_SETREF(d, copy);
+    }
+    assert(!PyFrozenDict_Check(d) || can_modify_dict((PyDictObject*)d));
 
     if (PyDict_CheckExact(d)) {
         if (PyDict_CheckExact(iterable)) {
@@ -3367,7 +3396,7 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
 dict_iter_exit:;
         Py_END_CRITICAL_SECTION();
     }
-    else if (PyFrozenDict_CheckExact(d)) {
+    else if (PyFrozenDict_Check(d)) {
         while ((key = PyIter_Next(it)) != NULL) {
             // setitem_take2_lock_held consumes a reference to key
             status = setitem_take2_lock_held((PyDictObject *)d,
@@ -8002,6 +8031,8 @@ frozendict_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
     if (d == NULL) {
         return NULL;
     }
+    assert(can_modify_dict(_PyAnyDict_CAST(d)));
+
     PyFrozenDictObject *self = _PyFrozenDictObject_CAST(d);
     self->ma_hash = -1;