]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-116738: Make _abc module thread-safe (#117488)
authorBrett Simmers <swtaarrs@users.noreply.github.com>
Thu, 11 Apr 2024 22:13:25 +0000 (15:13 -0700)
committerGitHub <noreply@github.com>
Thu, 11 Apr 2024 22:13:25 +0000 (18:13 -0400)
A collection of small changes aimed at making the `_abc` module safe to
use in a free-threaded build.

Include/internal/pycore_typeobject.h
Modules/_abc.c
Objects/typeobject.c

index 8a25935f308178fb34c6240806986d4777a3e92d..1693119ffece03f451fa276aa36745b707d479e5 100644 (file)
@@ -152,6 +152,18 @@ PyAPI_FUNC(PyObject*) _PySuper_Lookup(PyTypeObject *su_type, PyObject *su_obj,
 
 extern PyObject* _PyType_GetFullyQualifiedName(PyTypeObject *type, char sep);
 
+// Perform the following operation, in a thread-safe way when required by the
+// build mode.
+//
+// self->tp_flags = (self->tp_flags & ~mask) | flags;
+extern void _PyType_SetFlags(PyTypeObject *self, unsigned long mask,
+                             unsigned long flags);
+
+// Like _PyType_SetFlags(), but apply the operation to self and any of its
+// subclasses without Py_TPFLAGS_IMMUTABLETYPE set.
+extern void _PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask,
+                                      unsigned long flags);
+
 
 #ifdef __cplusplus
 }
index 399ecbbd6a21728d6dfbc9571fb4c558c72f5eaa..ad28035843fd321f7140389f4d11c8d99f3465e3 100644 (file)
@@ -21,7 +21,7 @@ PyDoc_STRVAR(_abc__doc__,
 
 typedef struct {
     PyTypeObject *_abc_data_type;
-    unsigned long long abc_invalidation_counter;
+    uint64_t abc_invalidation_counter;
 } _abcmodule_state;
 
 static inline _abcmodule_state*
@@ -32,17 +32,61 @@ get_abc_state(PyObject *module)
     return (_abcmodule_state *)state;
 }
 
+static inline uint64_t
+get_invalidation_counter(_abcmodule_state *state)
+{
+#ifdef Py_GIL_DISABLED
+    return _Py_atomic_load_uint64(&state->abc_invalidation_counter);
+#else
+    return state->abc_invalidation_counter;
+#endif
+}
+
+static inline void
+increment_invalidation_counter(_abcmodule_state *state)
+{
+#ifdef Py_GIL_DISABLED
+    _Py_atomic_add_uint64(&state->abc_invalidation_counter, 1);
+#else
+    state->abc_invalidation_counter++;
+#endif
+}
+
 /* This object stores internal state for ABCs.
    Note that we can use normal sets for caches,
    since they are never iterated over. */
 typedef struct {
     PyObject_HEAD
+    /* These sets of weak references are lazily created. Once created, they
+       will point to the same sets until the ABCMeta object is destroyed or
+       cleared, both of which will only happen while the object is visible to a
+       single thread. */
     PyObject *_abc_registry;
-    PyObject *_abc_cache; /* Normal set of weak references. */
-    PyObject *_abc_negative_cache; /* Normal set of weak references. */
-    unsigned long long _abc_negative_cache_version;
+    PyObject *_abc_cache;
+    PyObject *_abc_negative_cache;
+    uint64_t _abc_negative_cache_version;
 } _abc_data;
 
+static inline uint64_t
+get_cache_version(_abc_data *impl)
+{
+#ifdef Py_GIL_DISABLED
+    return _Py_atomic_load_uint64(&impl->_abc_negative_cache_version);
+#else
+    return impl->_abc_negative_cache_version;
+#endif
+}
+
+static inline void
+set_cache_version(_abc_data *impl, uint64_t version)
+{
+#ifdef Py_GIL_DISABLED
+    _Py_atomic_store_uint64(&impl->_abc_negative_cache_version, version);
+#else
+    impl->_abc_negative_cache_version = version;
+#endif
+}
+
 static int
 abc_data_traverse(_abc_data *self, visitproc visit, void *arg)
 {
@@ -90,7 +134,7 @@ abc_data_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
     self->_abc_registry = NULL;
     self->_abc_cache = NULL;
     self->_abc_negative_cache = NULL;
-    self->_abc_negative_cache_version = state->abc_invalidation_counter;
+    self->_abc_negative_cache_version = get_invalidation_counter(state);
     return (PyObject *) self;
 }
 
@@ -130,8 +174,12 @@ _get_impl(PyObject *module, PyObject *self)
 }
 
 static int
-_in_weak_set(PyObject *set, PyObject *obj)
+_in_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
 {
+    PyObject *set;
+    Py_BEGIN_CRITICAL_SECTION(impl);
+    set = *pset;
+    Py_END_CRITICAL_SECTION();
     if (set == NULL || PySet_GET_SIZE(set) == 0) {
         return 0;
     }
@@ -168,16 +216,19 @@ static PyMethodDef _destroy_def = {
 };
 
 static int
-_add_to_weak_set(PyObject **pset, PyObject *obj)
+_add_to_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
 {
-    if (*pset == NULL) {
-        *pset = PySet_New(NULL);
-        if (*pset == NULL) {
-            return -1;
-        }
+    PyObject *set;
+    Py_BEGIN_CRITICAL_SECTION(impl);
+    set = *pset;
+    if (set == NULL) {
+        set = *pset = PySet_New(NULL);
+    }
+    Py_END_CRITICAL_SECTION();
+    if (set == NULL) {
+        return -1;
     }
 
-    PyObject *set = *pset;
     PyObject *ref, *wr;
     PyObject *destroy_cb;
     wr = PyWeakref_NewRef(set, NULL);
@@ -220,7 +271,11 @@ _abc__reset_registry(PyObject *module, PyObject *self)
     if (impl == NULL) {
         return NULL;
     }
-    if (impl->_abc_registry != NULL && PySet_Clear(impl->_abc_registry) < 0) {
+    PyObject *registry;
+    Py_BEGIN_CRITICAL_SECTION(impl);
+    registry = impl->_abc_registry;
+    Py_END_CRITICAL_SECTION();
+    if (registry != NULL && PySet_Clear(registry) < 0) {
         Py_DECREF(impl);
         return NULL;
     }
@@ -247,13 +302,17 @@ _abc__reset_caches(PyObject *module, PyObject *self)
     if (impl == NULL) {
         return NULL;
     }
-    if (impl->_abc_cache != NULL && PySet_Clear(impl->_abc_cache) < 0) {
+    PyObject *cache, *negative_cache;
+    Py_BEGIN_CRITICAL_SECTION(impl);
+    cache = impl->_abc_cache;
+    negative_cache = impl->_abc_negative_cache;
+    Py_END_CRITICAL_SECTION();
+    if (cache != NULL && PySet_Clear(cache) < 0) {
         Py_DECREF(impl);
         return NULL;
     }
     /* also the second cache */
-    if (impl->_abc_negative_cache != NULL &&
-            PySet_Clear(impl->_abc_negative_cache) < 0) {
+    if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) {
         Py_DECREF(impl);
         return NULL;
     }
@@ -282,11 +341,14 @@ _abc__get_dump(PyObject *module, PyObject *self)
     if (impl == NULL) {
         return NULL;
     }
-    PyObject *res = Py_BuildValue("NNNK",
-                                  PySet_New(impl->_abc_registry),
-                                  PySet_New(impl->_abc_cache),
-                                  PySet_New(impl->_abc_negative_cache),
-                                  impl->_abc_negative_cache_version);
+    PyObject *res;
+    Py_BEGIN_CRITICAL_SECTION(impl);
+    res = Py_BuildValue("NNNK",
+                        PySet_New(impl->_abc_registry),
+                        PySet_New(impl->_abc_cache),
+                        PySet_New(impl->_abc_negative_cache),
+                        get_cache_version(impl));
+    Py_END_CRITICAL_SECTION();
     Py_DECREF(impl);
     return res;
 }
@@ -453,56 +515,27 @@ _abc__abc_init(PyObject *module, PyObject *self)
     if (PyType_Check(self)) {
         PyTypeObject *cls = (PyTypeObject *)self;
         PyObject *dict = _PyType_GetDict(cls);
-        PyObject *flags = PyDict_GetItemWithError(dict,
-                                                  &_Py_ID(__abc_tpflags__));
-        if (flags == NULL) {
-            if (PyErr_Occurred()) {
-                return NULL;
-            }
+        PyObject *flags = NULL;
+        if (PyDict_Pop(dict, &_Py_ID(__abc_tpflags__), &flags) < 0) {
+            return NULL;
         }
-        else {
-            if (PyLong_CheckExact(flags)) {
-                long val = PyLong_AsLong(flags);
-                if (val == -1 && PyErr_Occurred()) {
-                    return NULL;
-                }
-                if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
-                    PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
-                    return NULL;
-                }
-                ((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
-            }
-            if (PyDict_DelItem(dict, &_Py_ID(__abc_tpflags__)) < 0) {
-                return NULL;
-            }
+        if (flags == NULL || !PyLong_CheckExact(flags)) {
+            Py_XDECREF(flags);
+            Py_RETURN_NONE;
         }
-    }
-    Py_RETURN_NONE;
-}
-
-static void
-set_collection_flag_recursive(PyTypeObject *child, unsigned long flag)
-{
-    assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE);
-    if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) ||
-        (child->tp_flags & COLLECTION_FLAGS) == flag)
-    {
-        return;
-    }
-
-    child->tp_flags &= ~COLLECTION_FLAGS;
-    child->tp_flags |= flag;
-
-    PyObject *grandchildren = _PyType_GetSubclasses(child);
-    if (grandchildren == NULL) {
-        return;
-    }
 
-    for (Py_ssize_t i = 0; i < PyList_GET_SIZE(grandchildren); i++) {
-        PyObject *grandchild = PyList_GET_ITEM(grandchildren, i);
-        set_collection_flag_recursive((PyTypeObject *)grandchild, flag);
+        long val = PyLong_AsLong(flags);
+        Py_DECREF(flags);
+        if (val == -1 && PyErr_Occurred()) {
+            return NULL;
+        }
+        if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
+            PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
+            return NULL;
+        }
+        _PyType_SetFlags((PyTypeObject *)self, 0, val & COLLECTION_FLAGS);
     }
-    Py_DECREF(grandchildren);
+    Py_RETURN_NONE;
 }
 
 /*[clinic input]
@@ -545,20 +578,23 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
     if (impl == NULL) {
         return NULL;
     }
-    if (_add_to_weak_set(&impl->_abc_registry, subclass) < 0) {
+    if (_add_to_weak_set(impl, &impl->_abc_registry, subclass) < 0) {
         Py_DECREF(impl);
         return NULL;
     }
     Py_DECREF(impl);
 
     /* Invalidate negative cache */
-    get_abc_state(module)->abc_invalidation_counter++;
+    increment_invalidation_counter(get_abc_state(module));
 
-    /* Set Py_TPFLAGS_SEQUENCE  or Py_TPFLAGS_MAPPING flag */
+    /* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
     if (PyType_Check(self)) {
-        unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS;
+        unsigned long collection_flag =
+            PyType_GetFlags((PyTypeObject *)self) & COLLECTION_FLAGS;
         if (collection_flag) {
-            set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag);
+            _PyType_SetFlagsRecursive((PyTypeObject *)subclass,
+                                      COLLECTION_FLAGS,
+                                      collection_flag);
         }
     }
     return Py_NewRef(subclass);
@@ -592,7 +628,7 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self,
         return NULL;
     }
     /* Inline the cache checking. */
-    int incache = _in_weak_set(impl->_abc_cache, subclass);
+    int incache = _in_weak_set(impl, &impl->_abc_cache, subclass);
     if (incache < 0) {
         goto end;
     }
@@ -602,8 +638,8 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self,
     }
     subtype = (PyObject *)Py_TYPE(instance);
     if (subtype == subclass) {
-        if (impl->_abc_negative_cache_version == get_abc_state(module)->abc_invalidation_counter) {
-            incache = _in_weak_set(impl->_abc_negative_cache, subclass);
+        if (get_cache_version(impl) == get_invalidation_counter(get_abc_state(module))) {
+            incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass);
             if (incache < 0) {
                 goto end;
             }
@@ -681,7 +717,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
     }
 
     /* 1. Check cache. */
-    incache = _in_weak_set(impl->_abc_cache, subclass);
+    incache = _in_weak_set(impl, &impl->_abc_cache, subclass);
     if (incache < 0) {
         goto end;
     }
@@ -692,17 +728,20 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
 
     state = get_abc_state(module);
     /* 2. Check negative cache; may have to invalidate. */
-    if (impl->_abc_negative_cache_version < state->abc_invalidation_counter) {
+    uint64_t invalidation_counter = get_invalidation_counter(state);
+    if (get_cache_version(impl) < invalidation_counter) {
         /* Invalidate the negative cache. */
-        if (impl->_abc_negative_cache != NULL &&
-                PySet_Clear(impl->_abc_negative_cache) < 0)
-        {
+        PyObject *negative_cache;
+        Py_BEGIN_CRITICAL_SECTION(impl);
+        negative_cache = impl->_abc_negative_cache;
+        Py_END_CRITICAL_SECTION();
+        if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) {
             goto end;
         }
-        impl->_abc_negative_cache_version = state->abc_invalidation_counter;
+        set_cache_version(impl, invalidation_counter);
     }
     else {
-        incache = _in_weak_set(impl->_abc_negative_cache, subclass);
+        incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass);
         if (incache < 0) {
             goto end;
         }
@@ -720,7 +759,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
     }
     if (ok == Py_True) {
         Py_DECREF(ok);
-        if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
+        if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
             goto end;
         }
         result = Py_True;
@@ -728,7 +767,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
     }
     if (ok == Py_False) {
         Py_DECREF(ok);
-        if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) {
+        if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) {
             goto end;
         }
         result = Py_False;
@@ -744,7 +783,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
 
     /* 4. Check if it's a direct subclass. */
     if (PyType_IsSubtype((PyTypeObject *)subclass, (PyTypeObject *)self)) {
-        if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
+        if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
             goto end;
         }
         result = Py_True;
@@ -767,12 +806,14 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
         goto end;
     }
     for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) {
-        PyObject *scls = PyList_GET_ITEM(subclasses, pos);
-        Py_INCREF(scls);
+        PyObject *scls = PyList_GetItemRef(subclasses, pos);
+        if (scls == NULL) {
+            goto end;
+        }
         int r = PyObject_IsSubclass(subclass, scls);
         Py_DECREF(scls);
         if (r > 0) {
-            if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
+            if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
                 goto end;
             }
             result = Py_True;
@@ -784,7 +825,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
     }
 
     /* No dice; update negative cache. */
-    if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) {
+    if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) {
         goto end;
     }
     result = Py_False;
@@ -801,7 +842,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
                              PyObject **result)
 {
     // Fast path: check subclass is in weakref directly.
-    int ret = _in_weak_set(impl->_abc_registry, subclass);
+    int ret = _in_weak_set(impl, &impl->_abc_registry, subclass);
     if (ret < 0) {
         *result = NULL;
         return -1;
@@ -811,33 +852,27 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
         return 1;
     }
 
-    if (impl->_abc_registry == NULL) {
+    PyObject *registry_shared;
+    Py_BEGIN_CRITICAL_SECTION(impl);
+    registry_shared = impl->_abc_registry;
+    Py_END_CRITICAL_SECTION();
+    if (registry_shared == NULL) {
         return 0;
     }
-    Py_ssize_t registry_size = PySet_Size(impl->_abc_registry);
-    if (registry_size == 0) {
-        return 0;
-    }
-    // Weakref callback may remove entry from set.
-    // So we take snapshot of registry first.
-    PyObject **copy = PyMem_Malloc(sizeof(PyObject*) * registry_size);
-    if (copy == NULL) {
-        PyErr_NoMemory();
+
+    // Make a local copy of the registry to protect against concurrent
+    // modifications of _abc_registry.
+    PyObject *registry = PySet_New(registry_shared);
+    if (registry == NULL) {
         return -1;
     }
     PyObject *key;
     Py_ssize_t pos = 0;
     Py_hash_t hash;
-    Py_ssize_t i = 0;
 
-    while (_PySet_NextEntry(impl->_abc_registry, &pos, &key, &hash)) {
-        copy[i++] = Py_NewRef(key);
-    }
-    assert(i == registry_size);
-
-    for (i = 0; i < registry_size; i++) {
+    while (_PySet_NextEntry(registry, &pos, &key, &hash)) {
         PyObject *rkey;
-        if (PyWeakref_GetRef(copy[i], &rkey) < 0) {
+        if (PyWeakref_GetRef(key, &rkey) < 0) {
             // Someone inject non-weakref type in the registry.
             ret = -1;
             break;
@@ -853,7 +888,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
             break;
         }
         if (r > 0) {
-            if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
+            if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
                 ret = -1;
                 break;
             }
@@ -863,10 +898,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
         }
     }
 
-    for (i = 0; i < registry_size; i++) {
-        Py_DECREF(copy[i]);
-    }
-    PyMem_Free(copy);
+    Py_DECREF(registry);
     return ret;
 }
 
@@ -885,7 +917,7 @@ _abc_get_cache_token_impl(PyObject *module)
 /*[clinic end generated code: output=c7d87841e033dacc input=70413d1c423ad9f9]*/
 {
     _abcmodule_state *state = get_abc_state(module);
-    return PyLong_FromUnsignedLongLong(state->abc_invalidation_counter);
+    return PyLong_FromUnsignedLongLong(get_invalidation_counter(state));
 }
 
 static struct PyMethodDef _abcmodule_methods[] = {
index e9f2d2577e9fabfc11cb3b3988306dadc2a866ab..3f38abfcfe5b11c04ea8ae1ceb23a31cf901553f 100644 (file)
@@ -5117,6 +5117,52 @@ _PyType_LookupId(PyTypeObject *type, _Py_Identifier *name)
     return _PyType_Lookup(type, oname);
 }
 
+static void
+set_flags(PyTypeObject *self, unsigned long mask, unsigned long flags)
+{
+    ASSERT_TYPE_LOCK_HELD();
+    self->tp_flags = (self->tp_flags & ~mask) | flags;
+}
+
+void
+_PyType_SetFlags(PyTypeObject *self, unsigned long mask, unsigned long flags)
+{
+    BEGIN_TYPE_LOCK();
+    set_flags(self, mask, flags);
+    END_TYPE_LOCK();
+}
+
+static void
+set_flags_recursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
+{
+    if (PyType_HasFeature(self, Py_TPFLAGS_IMMUTABLETYPE) ||
+        (self->tp_flags & mask) == flags)
+    {
+        return;
+    }
+
+    set_flags(self, mask, flags);
+
+    PyObject *children = _PyType_GetSubclasses(self);
+    if (children == NULL) {
+        return;
+    }
+
+    for (Py_ssize_t i = 0; i < PyList_GET_SIZE(children); i++) {
+        PyObject *child = PyList_GET_ITEM(children, i);
+        set_flags_recursive((PyTypeObject *)child, mask, flags);
+    }
+    Py_DECREF(children);
+}
+
+void
+_PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
+{
+    BEGIN_TYPE_LOCK();
+    set_flags_recursive(self, mask, flags);
+    END_TYPE_LOCK();
+}
+
 /* This is similar to PyObject_GenericGetAttr(),
    but uses _PyType_Lookup() instead of just looking in type->tp_dict.