]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-117578: Introduce _PyType_GetModuleByDef2 private function (GH-117661)
authorneonene <53406459+neonene@users.noreply.github.com>
Thu, 25 Apr 2024 11:51:31 +0000 (20:51 +0900)
committerGitHub <noreply@github.com>
Thu, 25 Apr 2024 11:51:31 +0000 (13:51 +0200)
Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com>
Co-authored-by: Petr Viktorin <encukou@gmail.com>
Include/internal/pycore_typeobject.h
Modules/_decimal/_decimal.c
Objects/typeobject.c

index 09c4501c38c935cb00ccc6062588f443b115bb8a..7e533bd138469baf4ce6485793b860881b5faa93 100644 (file)
@@ -164,6 +164,7 @@ extern PyObject * _PyType_GetBases(PyTypeObject *type);
 extern PyObject * _PyType_GetMRO(PyTypeObject *type);
 extern PyObject* _PyType_GetSubclasses(PyTypeObject *);
 extern int _PyType_HasSubclasses(PyTypeObject *);
+PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef2(PyTypeObject *, PyTypeObject *, PyModuleDef *);
 
 // PyType_Ready() must be called if _PyType_IsReady() is false.
 // See also the Py_TPFLAGS_READY flag.
index 2481455ac0d1437f429fb2183dedf2bb357fd2cd..fa425f4f740d31284e067ecf9047ff6c7a00b618 100644 (file)
@@ -32,6 +32,7 @@
 #include <Python.h>
 #include "pycore_long.h"          // _PyLong_IsZero()
 #include "pycore_pystate.h"       // _PyThreadState_GET()
+#include "pycore_typeobject.h"
 #include "complexobject.h"
 #include "mpdecimal.h"
 
@@ -120,11 +121,8 @@ get_module_state_by_def(PyTypeObject *tp)
 static inline decimal_state *
 find_state_left_or_right(PyObject *left, PyObject *right)
 {
-    PyObject *mod = PyType_GetModuleByDef(Py_TYPE(left), &_decimal_module);
-    if (mod == NULL) {
-        PyErr_Clear();
-        mod = PyType_GetModuleByDef(Py_TYPE(right), &_decimal_module);
-    }
+    PyObject *mod = _PyType_GetModuleByDef2(Py_TYPE(left), Py_TYPE(right),
+                                            &_decimal_module);
     assert(mod != NULL);
     return get_module_state(mod);
 }
index 808e11fcbaf1ff923a48fbee59e5079680d8f071..07e0a5a02da87ff53028f56d85f2aa906055e0a2 100644 (file)
@@ -4825,11 +4825,24 @@ PyType_GetModuleState(PyTypeObject *type)
 /* Get the module of the first superclass where the module has the
  * given PyModuleDef.
  */
-PyObject *
-PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
+static inline PyObject *
+get_module_by_def(PyTypeObject *type, PyModuleDef *def)
 {
     assert(PyType_Check(type));
 
+    if (!_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE)) {
+        // type_ready_mro() ensures that no heap type is
+        // contained in a static type MRO.
+        return NULL;
+    }
+    else {
+        PyHeapTypeObject *ht = (PyHeapTypeObject*)type;
+        PyObject *module = ht->ht_module;
+        if (module && _PyModule_GetDef(module) == def) {
+            return module;
+        }
+    }
+
     PyObject *res = NULL;
     BEGIN_TYPE_LOCK()
 
@@ -4837,12 +4850,14 @@ PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
     // The type must be ready
     assert(mro != NULL);
     assert(PyTuple_Check(mro));
-    // mro_invoke() ensures that the type MRO cannot be empty, so we don't have
-    // to check i < PyTuple_GET_SIZE(mro) at the first loop iteration.
+    // mro_invoke() ensures that the type MRO cannot be empty.
     assert(PyTuple_GET_SIZE(mro) >= 1);
+    // Also, the first item in the MRO is the type itself, which
+    // we already checked above. We skip it in the loop.
+    assert(PyTuple_GET_ITEM(mro, 0) == (PyObject *)type);
 
     Py_ssize_t n = PyTuple_GET_SIZE(mro);
-    for (Py_ssize_t i = 0; i < n; i++) {
+    for (Py_ssize_t i = 1; i < n; i++) {
         PyObject *super = PyTuple_GET_ITEM(mro, i);
         if(!_PyType_HasFeature((PyTypeObject *)super, Py_TPFLAGS_HEAPTYPE)) {
             // Static types in the MRO need to be skipped
@@ -4857,14 +4872,37 @@ PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
         }
     }
     END_TYPE_LOCK()
+    return res;
+}
 
-    if (res == NULL) {
+PyObject *
+PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
+{
+    PyObject *module = get_module_by_def(type, def);
+    if (module == NULL) {
         PyErr_Format(
             PyExc_TypeError,
             "PyType_GetModuleByDef: No superclass of '%s' has the given module",
             type->tp_name);
     }
-    return res;
+    return module;
+}
+
+PyObject *
+_PyType_GetModuleByDef2(PyTypeObject *left, PyTypeObject *right,
+                        PyModuleDef *def)
+{
+    PyObject *module = get_module_by_def(left, def);
+    if (module == NULL) {
+        module = get_module_by_def(right, def);
+        if (module == NULL) {
+            PyErr_Format(
+                PyExc_TypeError,
+                "PyType_GetModuleByDef: No superclass of '%s' nor '%s' has "
+                "the given module", left->tp_name, right->tp_name);
+        }
+    }
+    return module;
 }
 
 void *