]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-122459: Optimize pickling by name objects without __module__ (GH-122460)
authorSerhiy Storchaka <storchaka@gmail.com>
Mon, 5 Aug 2024 13:21:32 +0000 (16:21 +0300)
committerGitHub <noreply@github.com>
Mon, 5 Aug 2024 13:21:32 +0000 (16:21 +0300)
Lib/pickle.py
Lib/test/pickletester.py
Misc/NEWS.d/next/Library/2024-07-30-15-57-07.gh-issue-122459.AYIoeN.rst [new file with mode: 0644]
Modules/_pickle.c

index 299c9e0e5e5641c16a8bc162cde20375128241a3..b8e114a79f22029b24747d4b029ee19be9e3a81e 100644 (file)
@@ -313,38 +313,45 @@ class _Unframer:
 
 # Tools used for pickling.
 
-def _getattribute(obj, name):
-    top = obj
-    for subpath in name.split('.'):
-        if subpath == '<locals>':
-            raise AttributeError("Can't get local attribute {!r} on {!r}"
-                                 .format(name, top))
-        try:
-            parent = obj
-            obj = getattr(obj, subpath)
-        except AttributeError:
-            raise AttributeError("Can't get attribute {!r} on {!r}"
-                                 .format(name, top)) from None
-    return obj, parent
+def _getattribute(obj, dotted_path):
+    for subpath in dotted_path:
+        obj = getattr(obj, subpath)
+    return obj
 
 def whichmodule(obj, name):
     """Find the module an object belong to."""
+    dotted_path = name.split('.')
     module_name = getattr(obj, '__module__', None)
-    if module_name is not None:
-        return module_name
-    # Protect the iteration by using a list copy of sys.modules against dynamic
-    # modules that trigger imports of other modules upon calls to getattr.
-    for module_name, module in sys.modules.copy().items():
-        if (module_name == '__main__'
-            or module_name == '__mp_main__'  # bpo-42406
-            or module is None):
-            continue
-        try:
-            if _getattribute(module, name)[0] is obj:
-                return module_name
-        except AttributeError:
-            pass
-    return '__main__'
+    if module_name is None and '<locals>' not in dotted_path:
+        # Protect the iteration by using a list copy of sys.modules against dynamic
+        # modules that trigger imports of other modules upon calls to getattr.
+        for module_name, module in sys.modules.copy().items():
+            if (module_name == '__main__'
+                or module_name == '__mp_main__'  # bpo-42406
+                or module is None):
+                continue
+            try:
+                if _getattribute(module, dotted_path) is obj:
+                    return module_name
+            except AttributeError:
+                pass
+        module_name = '__main__'
+    elif module_name is None:
+        module_name = '__main__'
+
+    try:
+        __import__(module_name, level=0)
+        module = sys.modules[module_name]
+        if _getattribute(module, dotted_path) is obj:
+            return module_name
+    except (ImportError, KeyError, AttributeError):
+        raise PicklingError(
+            "Can't pickle %r: it's not found as %s.%s" %
+            (obj, module_name, name)) from None
+
+    raise PicklingError(
+        "Can't pickle %r: it's not the same object as %s.%s" %
+        (obj, module_name, name))
 
 def encode_long(x):
     r"""Encode a long to a two's complement little-endian binary string.
@@ -1074,24 +1081,10 @@ class _Pickler:
 
         if name is None:
             name = getattr(obj, '__qualname__', None)
-        if name is None:
-            name = obj.__name__
+            if name is None:
+                name = obj.__name__
 
         module_name = whichmodule(obj, name)
-        try:
-            __import__(module_name, level=0)
-            module = sys.modules[module_name]
-            obj2, parent = _getattribute(module, name)
-        except (ImportError, KeyError, AttributeError):
-            raise PicklingError(
-                "Can't pickle %r: it's not found as %s.%s" %
-                (obj, module_name, name)) from None
-        else:
-            if obj2 is not obj:
-                raise PicklingError(
-                    "Can't pickle %r: it's not the same object as %s.%s" %
-                    (obj, module_name, name))
-
         if self.proto >= 2:
             code = _extension_registry.get((module_name, name))
             if code:
@@ -1103,10 +1096,7 @@ class _Pickler:
                 else:
                     write(EXT4 + pack("<i", code))
                 return
-        lastname = name.rpartition('.')[2]
-        if parent is module:
-            name = lastname
-        # Non-ASCII identifiers are supported only with protocols >= 3.
+
         if self.proto >= 4:
             self.save(module_name)
             self.save(name)
@@ -1616,7 +1606,16 @@ class _Unpickler:
                 module = _compat_pickle.IMPORT_MAPPING[module]
         __import__(module, level=0)
         if self.proto >= 4:
-            return _getattribute(sys.modules[module], name)[0]
+            module = sys.modules[module]
+            dotted_path = name.split('.')
+            if '<locals>' in dotted_path:
+                raise AttributeError(
+                    f"Can't get local attribute {name!r} on {module!r}")
+            try:
+                return _getattribute(module, dotted_path)
+            except AttributeError:
+                raise AttributeError(
+                    f"Can't get attribute {name!r} on {module!r}") from None
         else:
             return getattr(sys.modules[module], name)
 
index 3c936b3bc4029e81b58562ea474faec7cc64c01b..db42f13b0b98abf4fc93511e746862c736f5136e 100644 (file)
@@ -2068,7 +2068,7 @@ class AbstractPicklingErrorTests:
                     self.dumps(f, proto)
                 self.assertIn(str(cm.exception), {
                     f"Can't pickle {f!r}: it's not found as {__name__}.{f.__qualname__}",
-                    f"Can't get local object {f.__qualname__!r}"})
+                    f"Can't get local attribute {f.__qualname__!r} on {sys.modules[__name__]}"})
         # Same without a __module__ attribute (exercises a different path
         # in _pickle.c).
         del f.__module__
diff --git a/Misc/NEWS.d/next/Library/2024-07-30-15-57-07.gh-issue-122459.AYIoeN.rst b/Misc/NEWS.d/next/Library/2024-07-30-15-57-07.gh-issue-122459.AYIoeN.rst
new file mode 100644 (file)
index 0000000..5955040
--- /dev/null
@@ -0,0 +1,2 @@
+Optimize :mod:`pickling <pickle>` by name objects without the ``__module__``
+attribute.
index 50c73dca0db281e8283e94d8be655c7cb0263512..5d9ee8cb6c679d7d861acb986054963b1d7d6281 100644 (file)
@@ -1803,13 +1803,15 @@ memo_put(PickleState *st, PicklerObject *self, PyObject *obj)
 }
 
 static PyObject *
-get_dotted_path(PyObject *obj, PyObject *name)
+get_dotted_path(PyObject *name)
+{
+    return PyUnicode_Split(name, _Py_LATIN1_CHR('.'), -1);
+}
+
+static int
+check_dotted_path(PyObject *obj, PyObject *name, PyObject *dotted_path)
 {
-    PyObject *dotted_path;
     Py_ssize_t i, n;
-    dotted_path = PyUnicode_Split(name, _Py_LATIN1_CHR('.'), -1);
-    if (dotted_path == NULL)
-        return NULL;
     n = PyList_GET_SIZE(dotted_path);
     assert(n >= 1);
     for (i = 0; i < n; i++) {
@@ -1821,61 +1823,33 @@ get_dotted_path(PyObject *obj, PyObject *name)
             else
                 PyErr_Format(PyExc_AttributeError,
                              "Can't get local attribute %R on %R", name, obj);
-            Py_DECREF(dotted_path);
-            return NULL;
+            return -1;
         }
     }
-    return dotted_path;
+    return 0;
 }
 
 static PyObject *
-get_deep_attribute(PyObject *obj, PyObject *names, PyObject **pparent)
+getattribute(PyObject *obj, PyObject *names)
 {
     Py_ssize_t i, n;
-    PyObject *parent = NULL;
 
     assert(PyList_CheckExact(names));
     Py_INCREF(obj);
     n = PyList_GET_SIZE(names);
     for (i = 0; i < n; i++) {
         PyObject *name = PyList_GET_ITEM(names, i);
-        Py_XSETREF(parent, obj);
+        PyObject *parent = obj;
         (void)PyObject_GetOptionalAttr(parent, name, &obj);
+        Py_DECREF(parent);
         if (obj == NULL) {
-            Py_DECREF(parent);
             return NULL;
         }
     }
-    if (pparent != NULL)
-        *pparent = parent;
-    else
-        Py_XDECREF(parent);
     return obj;
 }
 
 
-static PyObject *
-getattribute(PyObject *obj, PyObject *name, int allow_qualname)
-{
-    PyObject *dotted_path, *attr;
-
-    if (allow_qualname) {
-        dotted_path = get_dotted_path(obj, name);
-        if (dotted_path == NULL)
-            return NULL;
-        attr = get_deep_attribute(obj, dotted_path, NULL);
-        Py_DECREF(dotted_path);
-    }
-    else {
-        (void)PyObject_GetOptionalAttr(obj, name, &attr);
-    }
-    if (attr == NULL && !PyErr_Occurred()) {
-        PyErr_Format(PyExc_AttributeError,
-                     "Can't get attribute %R on %R", name, obj);
-    }
-    return attr;
-}
-
 static int
 _checkmodule(PyObject *module_name, PyObject *module,
              PyObject *global, PyObject *dotted_path)
@@ -1888,7 +1862,7 @@ _checkmodule(PyObject *module_name, PyObject *module,
         return -1;
     }
 
-    PyObject *candidate = get_deep_attribute(module, dotted_path, NULL);
+    PyObject *candidate = getattribute(module, dotted_path);
     if (candidate == NULL) {
         return -1;
     }
@@ -1901,7 +1875,7 @@ _checkmodule(PyObject *module_name, PyObject *module,
 }
 
 static PyObject *
-whichmodule(PyObject *global, PyObject *dotted_path)
+whichmodule(PickleState *st, PyObject *global, PyObject *global_name, PyObject *dotted_path)
 {
     PyObject *module_name;
     PyObject *module = NULL;
@@ -1911,63 +1885,106 @@ whichmodule(PyObject *global, PyObject *dotted_path)
     if (PyObject_GetOptionalAttr(global, &_Py_ID(__module__), &module_name) < 0) {
         return NULL;
     }
-    if (module_name) {
+    if (module_name == NULL || module_name == Py_None) {
         /* In some rare cases (e.g., bound methods of extension types),
            __module__ can be None. If it is so, then search sys.modules for
            the module of global. */
-        if (module_name != Py_None)
-            return module_name;
         Py_CLEAR(module_name);
-    }
-    assert(module_name == NULL);
-
-    /* Fallback on walking sys.modules */
-    PyThreadState *tstate = _PyThreadState_GET();
-    modules = _PySys_GetAttr(tstate, &_Py_ID(modules));
-    if (modules == NULL) {
-        PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules");
-        return NULL;
-    }
-    if (PyDict_CheckExact(modules)) {
-        i = 0;
-        while (PyDict_Next(modules, &i, &module_name, &module)) {
-            if (_checkmodule(module_name, module, global, dotted_path) == 0) {
-                return Py_NewRef(module_name);
-            }
-            if (PyErr_Occurred()) {
-                return NULL;
-            }
+        if (check_dotted_path(NULL, global_name, dotted_path) < 0) {
+            return NULL;
         }
-    }
-    else {
-        PyObject *iterator = PyObject_GetIter(modules);
-        if (iterator == NULL) {
+        PyThreadState *tstate = _PyThreadState_GET();
+        modules = _PySys_GetAttr(tstate, &_Py_ID(modules));
+        if (modules == NULL) {
+            PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules");
             return NULL;
         }
-        while ((module_name = PyIter_Next(iterator))) {
-            module = PyObject_GetItem(modules, module_name);
-            if (module == NULL) {
+        if (PyDict_CheckExact(modules)) {
+            i = 0;
+            while (PyDict_Next(modules, &i, &module_name, &module)) {
+                Py_INCREF(module_name);
+                Py_INCREF(module);
+                if (_checkmodule(module_name, module, global, dotted_path) == 0) {
+                    Py_DECREF(module);
+                    return module_name;
+                }
+                Py_DECREF(module);
                 Py_DECREF(module_name);
-                Py_DECREF(iterator);
+                if (PyErr_Occurred()) {
+                    return NULL;
+                }
+            }
+        }
+        else {
+            PyObject *iterator = PyObject_GetIter(modules);
+            if (iterator == NULL) {
                 return NULL;
             }
-            if (_checkmodule(module_name, module, global, dotted_path) == 0) {
+            while ((module_name = PyIter_Next(iterator))) {
+                module = PyObject_GetItem(modules, module_name);
+                if (module == NULL) {
+                    Py_DECREF(module_name);
+                    Py_DECREF(iterator);
+                    return NULL;
+                }
+                if (_checkmodule(module_name, module, global, dotted_path) == 0) {
+                    Py_DECREF(module);
+                    Py_DECREF(iterator);
+                    return module_name;
+                }
                 Py_DECREF(module);
-                Py_DECREF(iterator);
-                return module_name;
-            }
-            Py_DECREF(module);
-            Py_DECREF(module_name);
-            if (PyErr_Occurred()) {
-                Py_DECREF(iterator);
-                return NULL;
+                Py_DECREF(module_name);
+                if (PyErr_Occurred()) {
+                    Py_DECREF(iterator);
+                    return NULL;
+                }
             }
+            Py_DECREF(iterator);
+        }
+        if (PyErr_Occurred()) {
+            return NULL;
         }
-        Py_DECREF(iterator);
+
+        /* If no module is found, use __main__. */
+        module_name = Py_NewRef(&_Py_ID(__main__));
     }
 
-    /* If no module is found, use __main__. */
-    return &_Py_ID(__main__);
+    /* XXX: Change to use the import C API directly with level=0 to disallow
+       relative imports.
+
+       XXX: PyImport_ImportModuleLevel could be used. However, this bypasses
+       builtins.__import__. Therefore, _pickle, unlike pickle.py, will ignore
+       custom import functions (IMHO, this would be a nice security
+       feature). The import C API would need to be extended to support the
+       extra parameters of __import__ to fix that. */
+    module = PyImport_Import(module_name);
+    if (module == NULL) {
+        PyErr_Format(st->PicklingError,
+                     "Can't pickle %R: import of module %R failed",
+                     global, module_name);
+        return NULL;
+    }
+    if (check_dotted_path(module, global_name, dotted_path) < 0) {
+        Py_DECREF(module);
+        return NULL;
+    }
+    PyObject *actual = getattribute(module, dotted_path);
+    Py_DECREF(module);
+    if (actual == NULL) {
+        PyErr_Format(st->PicklingError,
+                     "Can't pickle %R: attribute lookup %S on %S failed",
+                     global, global_name, module_name);
+        return NULL;
+    }
+    if (actual != global) {
+        Py_DECREF(actual);
+        PyErr_Format(st->PicklingError,
+                     "Can't pickle %R: it's not the same object as %S.%S",
+                     global, module_name, global_name);
+        return NULL;
+    }
+    Py_DECREF(actual);
+    return module_name;
 }
 
 /* fast_save_enter() and fast_save_leave() are guards against recursive
@@ -3590,10 +3607,7 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
 {
     PyObject *global_name = NULL;
     PyObject *module_name = NULL;
-    PyObject *module = NULL;
-    PyObject *parent = NULL;
     PyObject *dotted_path = NULL;
-    PyObject *cls;
     int status = 0;
 
     const char global_op = GLOBAL;
@@ -3611,44 +3625,13 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
         }
     }
 
-    dotted_path = get_dotted_path(module, global_name);
+    dotted_path = get_dotted_path(global_name);
     if (dotted_path == NULL)
         goto error;
-    module_name = whichmodule(obj, dotted_path);
+    module_name = whichmodule(st, obj, global_name, dotted_path);
     if (module_name == NULL)
         goto error;
 
-    /* XXX: Change to use the import C API directly with level=0 to disallow
-       relative imports.
-
-       XXX: PyImport_ImportModuleLevel could be used. However, this bypasses
-       builtins.__import__. Therefore, _pickle, unlike pickle.py, will ignore
-       custom import functions (IMHO, this would be a nice security
-       feature). The import C API would need to be extended to support the
-       extra parameters of __import__ to fix that. */
-    module = PyImport_Import(module_name);
-    if (module == NULL) {
-        PyErr_Format(st->PicklingError,
-                     "Can't pickle %R: import of module %R failed",
-                     obj, module_name);
-        goto error;
-    }
-    cls = get_deep_attribute(module, dotted_path, &parent);
-    if (cls == NULL) {
-        PyErr_Format(st->PicklingError,
-                     "Can't pickle %R: attribute lookup %S on %S failed",
-                     obj, global_name, module_name);
-        goto error;
-    }
-    if (cls != obj) {
-        Py_DECREF(cls);
-        PyErr_Format(st->PicklingError,
-                     "Can't pickle %R: it's not the same object as %S.%S",
-                     obj, module_name, global_name);
-        goto error;
-    }
-    Py_DECREF(cls);
-
     if (self->proto >= 2) {
         /* See whether this is in the extension registry, and if
          * so generate an EXT opcode.
@@ -3720,12 +3703,6 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
     }
     else {
   gen_global:
-        if (parent == module) {
-            Py_SETREF(global_name,
-                Py_NewRef(PyList_GET_ITEM(dotted_path,
-                                          PyList_GET_SIZE(dotted_path) - 1)));
-            Py_CLEAR(dotted_path);
-        }
         if (self->proto >= 4) {
             const char stack_global_op = STACK_GLOBAL;
 
@@ -3845,8 +3822,6 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
     }
     Py_XDECREF(module_name);
     Py_XDECREF(global_name);
-    Py_XDECREF(module);
-    Py_XDECREF(parent);
     Py_XDECREF(dotted_path);
 
     return status;
@@ -7063,7 +7038,27 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls,
     if (module == NULL) {
         return NULL;
     }
-    global = getattribute(module, global_name, self->proto >= 4);
+    if (self->proto >= 4) {
+        PyObject *dotted_path = get_dotted_path(global_name);
+        if (dotted_path == NULL) {
+            Py_DECREF(module);
+            return NULL;
+        }
+        if (check_dotted_path(module, global_name, dotted_path) < 0) {
+            Py_DECREF(dotted_path);
+            Py_DECREF(module);
+            return NULL;
+        }
+        global = getattribute(module, dotted_path);
+        Py_DECREF(dotted_path);
+        if (global == NULL && !PyErr_Occurred()) {
+            PyErr_Format(PyExc_AttributeError,
+                         "Can't get attribute %R on %R", global_name, module);
+        }
+    }
+    else {
+        global = PyObject_GetAttr(module, global_name);
+    }
     Py_DECREF(module);
     return global;
 }