]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-122311: Improve and unify pickle errors (GH-122771)
authorSerhiy Storchaka <storchaka@gmail.com>
Mon, 9 Sep 2024 12:04:51 +0000 (15:04 +0300)
committerGitHub <noreply@github.com>
Mon, 9 Sep 2024 12:04:51 +0000 (15:04 +0300)
* Raise PicklingError instead of UnicodeEncodeError, ValueError
  and AttributeError in both implementations.
* Chain the original exception to the pickle-specific one as __context__.
* Include the error message of ImportError and some AttributeError in
  the PicklingError error message.
* Unify error messages between Python and C implementations.
* Refer to documented __reduce__ and __newobj__ callables instead of
  internal methods (e.g. save_reduce()) or pickle opcodes (e.g. NEWOBJ).
* Include more details in error messages (what expected, what got).
* Avoid including a potentially long repr of an arbitrary object in
  error messages.

Lib/pickle.py
Lib/test/pickletester.py
Misc/NEWS.d/next/Library/2024-08-07-11-57-41.gh-issue-122311.LDExnJ.rst [new file with mode: 0644]
Modules/_pickle.c

index ade0491ca740b601c77d429d35a13621c79be522..f40b8e3adbe333895920399ea3f6015d7e840f4e 100644 (file)
@@ -322,7 +322,9 @@ def whichmodule(obj, name):
     """Find the module an object belong to."""
     dotted_path = name.split('.')
     module_name = getattr(obj, '__module__', None)
-    if module_name is None and '<locals>' not in dotted_path:
+    if '<locals>' in dotted_path:
+        raise PicklingError(f"Can't pickle local object {obj!r}")
+    if module_name is None:
         # Protect the iteration by using a list copy of sys.modules against dynamic
         # modules that trigger imports of other modules upon calls to getattr.
         for module_name, module in sys.modules.copy().items():
@@ -336,22 +338,21 @@ def whichmodule(obj, name):
             except AttributeError:
                 pass
         module_name = '__main__'
-    elif module_name is None:
-        module_name = '__main__'
 
     try:
         __import__(module_name, level=0)
         module = sys.modules[module_name]
+    except (ImportError, ValueError, KeyError) as exc:
+        raise PicklingError(f"Can't pickle {obj!r}: {exc!s}")
+    try:
         if _getattribute(module, dotted_path) is obj:
             return module_name
-    except (ImportError, KeyError, AttributeError):
-        raise PicklingError(
-            "Can't pickle %r: it's not found as %s.%s" %
-            (obj, module_name, name)) from None
+    except AttributeError:
+        raise PicklingError(f"Can't pickle {obj!r}: "
+                            f"it's not found as {module_name}.{name}")
 
     raise PicklingError(
-        "Can't pickle %r: it's not the same object as %s.%s" %
-        (obj, module_name, name))
+        f"Can't pickle {obj!r}: it's not the same object as {module_name}.{name}")
 
 def encode_long(x):
     r"""Encode a long to a two's complement little-endian binary string.
@@ -403,6 +404,13 @@ def decode_long(data):
     """
     return int.from_bytes(data, byteorder='little', signed=True)
 
+def _T(obj):
+    cls = type(obj)
+    module = cls.__module__
+    if module in (None, 'builtins', '__main__'):
+        return cls.__qualname__
+    return f'{module}.{cls.__qualname__}'
+
 
 _NoValue = object()
 
@@ -585,8 +593,7 @@ class _Pickler:
                     if reduce is not _NoValue:
                         rv = reduce()
                     else:
-                        raise PicklingError("Can't pickle %r object: %r" %
-                                            (t.__name__, obj))
+                        raise PicklingError(f"Can't pickle {_T(t)} object")
 
         # Check for string returned by reduce(), meaning "save as global"
         if isinstance(rv, str):
@@ -595,13 +602,13 @@ class _Pickler:
 
         # Assert that reduce() returned a tuple
         if not isinstance(rv, tuple):
-            raise PicklingError("%s must return string or tuple" % reduce)
+            raise PicklingError(f'__reduce__ must return a string or tuple, not {_T(rv)}')
 
         # Assert that it returned an appropriately sized tuple
         l = len(rv)
         if not (2 <= l <= 6):
-            raise PicklingError("Tuple returned by %s must have "
-                                "two to six elements" % reduce)
+            raise PicklingError("tuple returned by __reduce__ "
+                                "must contain 2 through 6 elements")
 
         # Save the reduce() output and finally memoize the object
         self.save_reduce(obj=obj, *rv)
@@ -626,10 +633,12 @@ class _Pickler:
                     dictitems=None, state_setter=None, *, obj=None):
         # This API is called by some subclasses
 
-        if not isinstance(args, tuple):
-            raise PicklingError("args from save_reduce() must be a tuple")
         if not callable(func):
-            raise PicklingError("func from save_reduce() must be callable")
+            raise PicklingError(f"first item of the tuple returned by __reduce__ "
+                                f"must be callable, not {_T(func)}")
+        if not isinstance(args, tuple):
+            raise PicklingError(f"second item of the tuple returned by __reduce__ "
+                                f"must be a tuple, not {_T(args)}")
 
         save = self.save
         write = self.write
@@ -638,11 +647,10 @@ class _Pickler:
         if self.proto >= 2 and func_name == "__newobj_ex__":
             cls, args, kwargs = args
             if not hasattr(cls, "__new__"):
-                raise PicklingError("args[0] from {} args has no __new__"
-                                    .format(func_name))
+                raise PicklingError("first argument to __newobj_ex__() has no __new__")
             if obj is not None and cls is not obj.__class__:
-                raise PicklingError("args[0] from {} args has the wrong class"
-                                    .format(func_name))
+                raise PicklingError(f"first argument to __newobj_ex__() "
+                                    f"must be {obj.__class__!r}, not {cls!r}")
             if self.proto >= 4:
                 save(cls)
                 save(args)
@@ -682,11 +690,10 @@ class _Pickler:
             # Python 2.2).
             cls = args[0]
             if not hasattr(cls, "__new__"):
-                raise PicklingError(
-                    "args[0] from __newobj__ args has no __new__")
+                raise PicklingError("first argument to __newobj__() has no __new__")
             if obj is not None and cls is not obj.__class__:
-                raise PicklingError(
-                    "args[0] from __newobj__ args has the wrong class")
+                raise PicklingError(f"first argument to __newobj__() "
+                                    f"must be {obj.__class__!r}, not {cls!r}")
             args = args[1:]
             save(cls)
             save(args)
@@ -1133,8 +1140,7 @@ class _Pickler:
     def _save_toplevel_by_name(self, module_name, name):
         if self.proto >= 3:
             # Non-ASCII identifiers are supported only with protocols >= 3.
-            self.write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
-                       bytes(name, "utf-8") + b'\n')
+            encoding = "utf-8"
         else:
             if self.fix_imports:
                 r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
@@ -1143,13 +1149,19 @@ class _Pickler:
                     module_name, name = r_name_mapping[(module_name, name)]
                 elif module_name in r_import_mapping:
                     module_name = r_import_mapping[module_name]
-            try:
-                self.write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
-                           bytes(name, "ascii") + b'\n')
-            except UnicodeEncodeError:
-                raise PicklingError(
-                    "can't pickle global identifier '%s.%s' using "
-                    "pickle protocol %i" % (module_name, name, self.proto)) from None
+            encoding = "ascii"
+        try:
+            self.write(GLOBAL + bytes(module_name, encoding) + b'\n')
+        except UnicodeEncodeError:
+            raise PicklingError(
+                f"can't pickle module identifier {module_name!r} using "
+                f"pickle protocol {self.proto}")
+        try:
+            self.write(bytes(name, encoding) + b'\n')
+        except UnicodeEncodeError:
+            raise PicklingError(
+                f"can't pickle global identifier {name!r} using "
+                f"pickle protocol {self.proto}")
 
     def save_type(self, obj):
         if obj is type(None):
@@ -1609,17 +1621,13 @@ class _Unpickler:
             elif module in _compat_pickle.IMPORT_MAPPING:
                 module = _compat_pickle.IMPORT_MAPPING[module]
         __import__(module, level=0)
-        if self.proto >= 4:
-            module = sys.modules[module]
+        if self.proto >= 4 and '.' in name:
             dotted_path = name.split('.')
-            if '<locals>' in dotted_path:
-                raise AttributeError(
-                    f"Can't get local attribute {name!r} on {module!r}")
             try:
-                return _getattribute(module, dotted_path)
+                return _getattribute(sys.modules[module], dotted_path)
             except AttributeError:
                 raise AttributeError(
-                    f"Can't get attribute {name!r} on {module!r}") from None
+                    f"Can't resolve path {name!r} on module {module!r}")
         else:
             return getattr(sys.modules[module], name)
 
index 2e16b6b741b0b9b1c520c24a2ec6ec460fbbfd2f..e2297e5dd1a8e78e8ab14f09c86e7932317c28e3 100644 (file)
@@ -1230,37 +1230,36 @@ class AbstractUnpickleTests:
 
         self.assertIs(unpickler4.find_class('builtins', 'str.upper'), str.upper)
         with self.assertRaisesRegex(AttributeError,
-                r"module 'builtins' has no attribute 'str\.upper'|"
-                r"Can't get attribute 'str\.upper' on <module 'builtins'"):
+                r"module 'builtins' has no attribute 'str\.upper'"):
             unpickler.find_class('builtins', 'str.upper')
 
         with self.assertRaisesRegex(AttributeError,
-                "module 'math' has no attribute 'spam'|"
-                "Can't get attribute 'spam' on <module 'math'"):
+                "module 'math' has no attribute 'spam'"):
             unpickler.find_class('math', 'spam')
         with self.assertRaisesRegex(AttributeError,
-                "Can't get attribute 'spam' on <module 'math'"):
+                "module 'math' has no attribute 'spam'"):
             unpickler4.find_class('math', 'spam')
         with self.assertRaisesRegex(AttributeError,
-                r"module 'math' has no attribute 'log\.spam'|"
-                r"Can't get attribute 'log\.spam' on <module 'math'"):
+                r"module 'math' has no attribute 'log\.spam'"):
             unpickler.find_class('math', 'log.spam')
         with self.assertRaisesRegex(AttributeError,
-                r"Can't get attribute 'log\.spam' on <module 'math'"):
+                r"Can't resolve path 'log\.spam' on module 'math'") as cm:
             unpickler4.find_class('math', 'log.spam')
+        self.assertEqual(str(cm.exception.__context__),
+            "'builtin_function_or_method' object has no attribute 'spam'")
         with self.assertRaisesRegex(AttributeError,
-                r"module 'math' has no attribute 'log\.<locals>\.spam'|"
-                r"Can't get attribute 'log\.<locals>\.spam' on <module 'math'"):
+                r"module 'math' has no attribute 'log\.<locals>\.spam'"):
             unpickler.find_class('math', 'log.<locals>.spam')
         with self.assertRaisesRegex(AttributeError,
-                r"Can't get local attribute 'log\.<locals>\.spam' on <module 'math'"):
+                r"Can't resolve path 'log\.<locals>\.spam' on module 'math'") as cm:
             unpickler4.find_class('math', 'log.<locals>.spam')
+        self.assertEqual(str(cm.exception.__context__),
+            "'builtin_function_or_method' object has no attribute '<locals>'")
         with self.assertRaisesRegex(AttributeError,
-                "module 'math' has no attribute ''|"
-                "Can't get attribute '' on <module 'math'"):
+                "module 'math' has no attribute ''"):
             unpickler.find_class('math', '')
         with self.assertRaisesRegex(AttributeError,
-                "Can't get attribute '' on <module 'math'"):
+                "module 'math' has no attribute ''"):
             unpickler4.find_class('math', '')
         self.assertRaises(ModuleNotFoundError, unpickler.find_class, 'spam', 'log')
         self.assertRaises(ValueError, unpickler.find_class, '', 'log')
@@ -1613,27 +1612,24 @@ class AbstractPicklingErrorTests:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f'{obj.__reduce_ex__!r} must return string or tuple',
-                    '__reduce__ must return a string or tuple'})
+                self.assertEqual(str(cm.exception),
+                    '__reduce__ must return a string or tuple, not list')
 
         obj = REX((print,))
         for proto in protocols:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f'Tuple returned by {obj.__reduce_ex__!r} must have two to six elements',
-                    'tuple returned by __reduce__ must contain 2 through 6 elements'})
+                self.assertEqual(str(cm.exception),
+                    'tuple returned by __reduce__ must contain 2 through 6 elements')
 
         obj = REX((print, (), None, None, None, None, None))
         for proto in protocols:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f'Tuple returned by {obj.__reduce_ex__!r} must have two to six elements',
-                    'tuple returned by __reduce__ must contain 2 through 6 elements'})
+                self.assertEqual(str(cm.exception),
+                    'tuple returned by __reduce__ must contain 2 through 6 elements')
 
     def test_bad_reconstructor(self):
         obj = REX((42, ()))
@@ -1641,9 +1637,9 @@ class AbstractPicklingErrorTests:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    'func from save_reduce() must be callable',
-                    'first item of the tuple returned by __reduce__ must be callable'})
+                self.assertEqual(str(cm.exception),
+                    'first item of the tuple returned by __reduce__ '
+                    'must be callable, not int')
 
     def test_unpickleable_reconstructor(self):
         obj = REX((UnpickleableCallable(), ()))
@@ -1658,9 +1654,9 @@ class AbstractPicklingErrorTests:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    'args from save_reduce() must be a tuple',
-                    'second item of the tuple returned by __reduce__ must be a tuple'})
+                self.assertEqual(str(cm.exception),
+                    'second item of the tuple returned by __reduce__ '
+                    'must be a tuple, not list')
 
     def test_unpickleable_reconstructor_args(self):
         obj = REX((print, (1, 2, UNPICKLEABLE)))
@@ -1677,16 +1673,16 @@ class AbstractPicklingErrorTests:
                     self.dumps(obj, proto)
                 self.assertIn(str(cm.exception), {
                     'tuple index out of range',
-                    '__newobj__ arglist is empty'})
+                    '__newobj__ expected at least 1 argument, got 0'})
 
         obj = REX((copyreg.__newobj__, [REX]))
         for proto in protocols[2:]:
             with self.subTest(proto=proto):
-                with self.assertRaises((IndexError, pickle.PicklingError)) as cm:
+                with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    'args from save_reduce() must be a tuple',
-                    'second item of the tuple returned by __reduce__ must be a tuple'})
+                self.assertEqual(str(cm.exception),
+                    'second item of the tuple returned by __reduce__ '
+                    'must be a tuple, not list')
 
     def test_bad_newobj_class(self):
         obj = REX((copyreg.__newobj__, (NoNew(),)))
@@ -1695,8 +1691,8 @@ class AbstractPicklingErrorTests:
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
                 self.assertIn(str(cm.exception), {
-                    'args[0] from __newobj__ args has no __new__',
-                    'args[0] from __newobj__ args is not a type'})
+                    'first argument to __newobj__() has no __new__',
+                    f'first argument to __newobj__() must be a class, not {__name__}.NoNew'})
 
     def test_wrong_newobj_class(self):
         obj = REX((copyreg.__newobj__, (str,)))
@@ -1705,14 +1701,14 @@ class AbstractPicklingErrorTests:
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
                 self.assertEqual(str(cm.exception),
-                    'args[0] from __newobj__ args has the wrong class')
+                    f'first argument to __newobj__() must be {REX!r}, not {str!r}')
 
     def test_unpickleable_newobj_class(self):
         class LocalREX(REX): pass
         obj = LocalREX((copyreg.__newobj__, (LocalREX,)))
         for proto in protocols:
             with self.subTest(proto=proto):
-                with self.assertRaises((pickle.PicklingError, AttributeError)):
+                with self.assertRaises(pickle.PicklingError):
                     self.dumps(obj, proto)
 
     def test_unpickleable_newobj_args(self):
@@ -1730,16 +1726,16 @@ class AbstractPicklingErrorTests:
                     self.dumps(obj, proto)
                 self.assertIn(str(cm.exception), {
                     'not enough values to unpack (expected 3, got 0)',
-                    'length of the NEWOBJ_EX argument tuple must be exactly 3, not 0'})
+                    '__newobj_ex__ expected 3 arguments, got 0'})
 
         obj = REX((copyreg.__newobj_ex__, 42))
         for proto in protocols[2:]:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    'args from save_reduce() must be a tuple',
-                    'second item of the tuple returned by __reduce__ must be a tuple'})
+                self.assertEqual(str(cm.exception),
+                    'second item of the tuple returned by __reduce__ '
+                    'must be a tuple, not int')
 
         obj = REX((copyreg.__newobj_ex__, (REX, 42, {})))
         if self.pickler is pickle._Pickler:
@@ -1755,7 +1751,7 @@ class AbstractPicklingErrorTests:
                     with self.assertRaises(pickle.PicklingError) as cm:
                         self.dumps(obj, proto)
                     self.assertEqual(str(cm.exception),
-                        'second item from NEWOBJ_EX argument tuple must be a tuple, not int')
+                        'second argument to __newobj_ex__() must be a tuple, not int')
 
         obj = REX((copyreg.__newobj_ex__, (REX, (), [])))
         if self.pickler is pickle._Pickler:
@@ -1771,7 +1767,7 @@ class AbstractPicklingErrorTests:
                     with self.assertRaises(pickle.PicklingError) as cm:
                         self.dumps(obj, proto)
                     self.assertEqual(str(cm.exception),
-                        'third item from NEWOBJ_EX argument tuple must be a dict, not list')
+                        'third argument to __newobj_ex__() must be a dict, not list')
 
     def test_bad_newobj_ex__class(self):
         obj = REX((copyreg.__newobj_ex__, (NoNew(), (), {})))
@@ -1780,8 +1776,8 @@ class AbstractPicklingErrorTests:
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
                 self.assertIn(str(cm.exception), {
-                    'args[0] from __newobj_ex__ args has no __new__',
-                    'first item from NEWOBJ_EX argument tuple must be a class, not NoNew'})
+                    'first argument to __newobj_ex__() has no __new__',
+                    f'first argument to __newobj_ex__() must be a class, not {__name__}.NoNew'})
 
     def test_wrong_newobj_ex_class(self):
         if self.pickler is not pickle._Pickler:
@@ -1792,14 +1788,14 @@ class AbstractPicklingErrorTests:
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
                 self.assertEqual(str(cm.exception),
-                    'args[0] from __newobj_ex__ args has the wrong class')
+                    f'first argument to __newobj_ex__() must be {REX}, not {str}')
 
     def test_unpickleable_newobj_ex_class(self):
         class LocalREX(REX): pass
         obj = LocalREX((copyreg.__newobj_ex__, (LocalREX, (), {})))
         for proto in protocols:
             with self.subTest(proto=proto):
-                with self.assertRaises((pickle.PicklingError, AttributeError)):
+                with self.assertRaises(pickle.PicklingError):
                     self.dumps(obj, proto)
 
     def test_unpickleable_newobj_ex_args(self):
@@ -1832,7 +1828,8 @@ class AbstractPicklingErrorTests:
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
                 self.assertEqual(str(cm.exception),
-                    'sixth element of the tuple returned by __reduce__ must be a function, not int')
+                    'sixth item of the tuple returned by __reduce__ '
+                    'must be callable, not int')
 
     def test_unpickleable_state_setter(self):
         obj = REX((print, (), 'state', None, None, UnpickleableCallable()))
@@ -1858,18 +1855,19 @@ class AbstractPicklingErrorTests:
                     self.dumps(obj, proto)
                 self.assertIn(str(cm.exception), {
                     "'int' object is not iterable",
-                    'fourth element of the tuple returned by __reduce__ must be an iterator, not int'})
+                    'fourth item of the tuple returned by __reduce__ '
+                    'must be an iterator, not int'})
 
         if self.pickler is not pickle._Pickler:
             # Python implementation is less strict and also accepts iterables.
             obj = REX((list, (), None, []))
             for proto in protocols:
                 with self.subTest(proto=proto):
-                    with self.assertRaises((TypeError, pickle.PicklingError)):
+                    with self.assertRaises(pickle.PicklingError):
                         self.dumps(obj, proto)
-                    self.assertIn(str(cm.exception), {
-                        "'int' object is not iterable",
-                        'fourth element of the tuple returned by __reduce__ must be an iterator, not int'})
+                    self.assertEqual(str(cm.exception),
+                        'fourth item of the tuple returned by __reduce__ '
+                        'must be an iterator, not int')
 
     def test_unpickleable_object_list_items(self):
         obj = REX_six([1, 2, UNPICKLEABLE])
@@ -1888,7 +1886,8 @@ class AbstractPicklingErrorTests:
                     self.dumps(obj, proto)
                 self.assertIn(str(cm.exception), {
                     "'int' object is not iterable",
-                    'fifth element of the tuple returned by __reduce__ must be an iterator, not int'})
+                    'fifth item of the tuple returned by __reduce__ '
+                    'must be an iterator, not int'})
 
         for proto in protocols:
             obj = REX((dict, (), None, None, iter([('a',)])))
@@ -1904,7 +1903,7 @@ class AbstractPicklingErrorTests:
             obj = REX((dict, (), None, None, []))
             for proto in protocols:
                 with self.subTest(proto=proto):
-                    with self.assertRaises((TypeError, pickle.PicklingError)):
+                    with self.assertRaises(pickle.PicklingError):
                         self.dumps(obj, proto)
                     self.assertEqual(str(cm.exception),
                         'dict items iterator must return 2-tuples')
@@ -1977,36 +1976,40 @@ class AbstractPicklingErrorTests:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {obj!r}: it's not found as {__name__}.spam",
-                    f"Can't pickle {obj!r}: attribute lookup spam on {__name__} failed"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle {obj!r}: it's not found as {__name__}.spam")
+                self.assertEqual(str(cm.exception.__context__),
+                    f"module '{__name__}' has no attribute 'spam'")
 
         obj.__module__ = 'nonexisting'
         for proto in protocols:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {obj!r}: it's not found as nonexisting.spam",
-                    f"Can't pickle {obj!r}: import of module 'nonexisting' failed"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle {obj!r}: No module named 'nonexisting'")
+                self.assertEqual(str(cm.exception.__context__),
+                    "No module named 'nonexisting'")
 
         obj.__module__ = ''
         for proto in protocols:
             with self.subTest(proto=proto):
-                with self.assertRaises((ValueError, pickle.PicklingError)) as cm:
+                with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    'Empty module name',
-                    f"Can't pickle {obj!r}: import of module '' failed"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle {obj!r}: Empty module name")
+                self.assertEqual(str(cm.exception.__context__),
+                    "Empty module name")
 
         obj.__module__ = None
         for proto in protocols:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {obj!r}: it's not found as __main__.spam",
-                    f"Can't pickle {obj!r}: attribute lookup spam on __main__ failed"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle {obj!r}: it's not found as __main__.spam")
+                self.assertEqual(str(cm.exception.__context__),
+                    "module '__main__' has no attribute 'spam'")
 
     def test_nonencodable_global_name_error(self):
         for proto in protocols[:4]:
@@ -2015,15 +2018,11 @@ class AbstractPicklingErrorTests:
                 obj = REX(name)
                 obj.__module__ = __name__
                 with support.swap_item(globals(), name, obj):
-                    if proto == 3 and self.pickler is pickle._Pickler:
-                        with self.assertRaises(UnicodeEncodeError):
-                            self.dumps(obj, proto)
-                    else:
-                        with self.assertRaises(pickle.PicklingError) as cm:
-                            self.dumps(obj, proto)
-                        self.assertIn(str(cm.exception), {
-                            f"can't pickle global identifier '{__name__}.{name}' using pickle protocol {proto}",
-                            f"can't pickle global identifier '{name}' using pickle protocol {proto}"})
+                    with self.assertRaises(pickle.PicklingError) as cm:
+                        self.dumps(obj, proto)
+                    self.assertEqual(str(cm.exception),
+                        f"can't pickle global identifier {name!r} using pickle protocol {proto}")
+                    self.assertIsInstance(cm.exception.__context__, UnicodeEncodeError)
 
     def test_nonencodable_module_name_error(self):
         for proto in protocols[:4]:
@@ -2033,15 +2032,11 @@ class AbstractPicklingErrorTests:
                 obj.__module__ = name
                 mod = types.SimpleNamespace(test=obj)
                 with support.swap_item(sys.modules, name, mod):
-                    if proto == 3 and self.pickler is pickle._Pickler:
-                        with self.assertRaises(UnicodeEncodeError):
-                            self.dumps(obj, proto)
-                    else:
-                        with self.assertRaises(pickle.PicklingError) as cm:
-                            self.dumps(obj, proto)
-                        self.assertIn(str(cm.exception), {
-                            f"can't pickle global identifier '{name}.test' using pickle protocol {proto}",
-                            f"can't pickle module identifier '{name}' using pickle protocol {proto}"})
+                    with self.assertRaises(pickle.PicklingError) as cm:
+                        self.dumps(obj, proto)
+                    self.assertEqual(str(cm.exception),
+                        f"can't pickle module identifier {name!r} using pickle protocol {proto}")
+                    self.assertIsInstance(cm.exception.__context__, UnicodeEncodeError)
 
     def test_nested_lookup_error(self):
         # Nested name does not exist
@@ -2051,18 +2046,21 @@ class AbstractPicklingErrorTests:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {obj!r}: it's not found as {__name__}.AbstractPickleTests.spam",
-                    f"Can't pickle {obj!r}: attribute lookup AbstractPickleTests.spam on {__name__} failed"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle {obj!r}: "
+                    f"it's not found as {__name__}.AbstractPickleTests.spam")
+                self.assertEqual(str(cm.exception.__context__),
+                    "type object 'AbstractPickleTests' has no attribute 'spam'")
 
         obj.__module__ = None
         for proto in protocols:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {obj!r}: it's not found as __main__.AbstractPickleTests.spam",
-                    f"Can't pickle {obj!r}: attribute lookup AbstractPickleTests.spam on __main__ failed"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle {obj!r}: it's not found as __main__.AbstractPickleTests.spam")
+                self.assertEqual(str(cm.exception.__context__),
+                    "module '__main__' has no attribute 'AbstractPickleTests'")
 
     def test_wrong_object_lookup_error(self):
         # Name is bound to different object
@@ -2075,15 +2073,17 @@ class AbstractPicklingErrorTests:
                     self.dumps(obj, proto)
                 self.assertEqual(str(cm.exception),
                     f"Can't pickle {obj!r}: it's not the same object as {__name__}.AbstractPickleTests")
+                self.assertIsNone(cm.exception.__context__)
 
         obj.__module__ = None
         for proto in protocols:
             with self.subTest(proto=proto):
                 with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(obj, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {obj!r}: it's not found as __main__.AbstractPickleTests",
-                    f"Can't pickle {obj!r}: attribute lookup AbstractPickleTests on __main__ failed"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle {obj!r}: it's not found as __main__.AbstractPickleTests")
+                self.assertEqual(str(cm.exception.__context__),
+                    "module '__main__' has no attribute 'AbstractPickleTests'")
 
     def test_local_lookup_error(self):
         # Test that whichmodule() errors out cleanly when looking up
@@ -2093,30 +2093,27 @@ class AbstractPicklingErrorTests:
         # Since the function is local, lookup will fail
         for proto in protocols:
             with self.subTest(proto=proto):
-                with self.assertRaises((AttributeError, pickle.PicklingError)) as cm:
+                with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(f, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {f!r}: it's not found as {__name__}.{f.__qualname__}",
-                    f"Can't get local attribute {f.__qualname__!r} on {sys.modules[__name__]}"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle local object {f!r}")
         # Same without a __module__ attribute (exercises a different path
         # in _pickle.c).
         del f.__module__
         for proto in protocols:
             with self.subTest(proto=proto):
-                with self.assertRaises((AttributeError, pickle.PicklingError)) as cm:
+                with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(f, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {f!r}: it's not found as __main__.{f.__qualname__}",
-                    f"Can't get local object {f.__qualname__!r}"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle local object {f!r}")
         # Yet a different path.
         f.__name__ = f.__qualname__
         for proto in protocols:
             with self.subTest(proto=proto):
-                with self.assertRaises((AttributeError, pickle.PicklingError)) as cm:
+                with self.assertRaises(pickle.PicklingError) as cm:
                     self.dumps(f, proto)
-                self.assertIn(str(cm.exception), {
-                    f"Can't pickle {f!r}: it's not found as __main__.{f.__qualname__}",
-                    f"Can't get local object {f.__qualname__!r}"})
+                self.assertEqual(str(cm.exception),
+                    f"Can't pickle local object {f!r}")
 
     def test_reduce_ex_None(self):
         c = REX_None()
diff --git a/Misc/NEWS.d/next/Library/2024-08-07-11-57-41.gh-issue-122311.LDExnJ.rst b/Misc/NEWS.d/next/Library/2024-08-07-11-57-41.gh-issue-122311.LDExnJ.rst
new file mode 100644 (file)
index 0000000..07ade20
--- /dev/null
@@ -0,0 +1,5 @@
+Improve errors in the :mod:`pickle` module. :exc:`~pickle.PicklingError` is
+now raised more often instead of :exc:`UnicodeEncodeError`,
+:exc:`ValueError` and :exc:`AttributeError`, and the original exception is
+chained to it. Improve and unify error messages in Python and C
+implementations.
index e0eda6080e42b3fce0c59b5d42da3049378170f4..9863baf0a29660932143ebaf1f94539e8f6fd230 100644 (file)
@@ -1809,7 +1809,7 @@ get_dotted_path(PyObject *name)
 }
 
 static int
-check_dotted_path(PyObject *obj, PyObject *name, PyObject *dotted_path)
+check_dotted_path(PickleState *st, PyObject *obj, PyObject *dotted_path)
 {
     Py_ssize_t i, n;
     n = PyList_GET_SIZE(dotted_path);
@@ -1817,12 +1817,8 @@ check_dotted_path(PyObject *obj, PyObject *name, PyObject *dotted_path)
     for (i = 0; i < n; i++) {
         PyObject *subpath = PyList_GET_ITEM(dotted_path, i);
         if (_PyUnicode_EqualToASCIIString(subpath, "<locals>")) {
-            if (obj == NULL)
-                PyErr_Format(PyExc_AttributeError,
-                             "Can't get local object %R", name);
-            else
-                PyErr_Format(PyExc_AttributeError,
-                             "Can't get local attribute %R on %R", name, obj);
+            PyErr_Format(st->PicklingError,
+                         "Can't pickle local object %R", obj);
             return -1;
         }
     }
@@ -1830,7 +1826,7 @@ check_dotted_path(PyObject *obj, PyObject *name, PyObject *dotted_path)
 }
 
 static PyObject *
-getattribute(PyObject *obj, PyObject *names)
+getattribute(PyObject *obj, PyObject *names, int raises)
 {
     Py_ssize_t i, n;
 
@@ -1840,7 +1836,12 @@ getattribute(PyObject *obj, PyObject *names)
     for (i = 0; i < n; i++) {
         PyObject *name = PyList_GET_ITEM(names, i);
         PyObject *parent = obj;
-        (void)PyObject_GetOptionalAttr(parent, name, &obj);
+        if (raises) {
+            obj = PyObject_GetAttr(parent, name);
+        }
+        else {
+            (void)PyObject_GetOptionalAttr(parent, name, &obj);
+        }
         Py_DECREF(parent);
         if (obj == NULL) {
             return NULL;
@@ -1849,7 +1850,6 @@ getattribute(PyObject *obj, PyObject *names)
     return obj;
 }
 
-
 static int
 _checkmodule(PyObject *module_name, PyObject *module,
              PyObject *global, PyObject *dotted_path)
@@ -1862,7 +1862,7 @@ _checkmodule(PyObject *module_name, PyObject *module,
         return -1;
     }
 
-    PyObject *candidate = getattribute(module, dotted_path);
+    PyObject *candidate = getattribute(module, dotted_path, 0);
     if (candidate == NULL) {
         return -1;
     }
@@ -1882,6 +1882,9 @@ whichmodule(PickleState *st, PyObject *global, PyObject *global_name, PyObject *
     Py_ssize_t i;
     PyObject *modules;
 
+    if (check_dotted_path(st, global, dotted_path) < 0) {
+        return NULL;
+    }
     if (PyObject_GetOptionalAttr(global, &_Py_ID(__module__), &module_name) < 0) {
         return NULL;
     }
@@ -1890,9 +1893,6 @@ whichmodule(PickleState *st, PyObject *global, PyObject *global_name, PyObject *
            __module__ can be None. If it is so, then search sys.modules for
            the module of global. */
         Py_CLEAR(module_name);
-        if (check_dotted_path(NULL, global_name, dotted_path) < 0) {
-            return NULL;
-        }
         PyThreadState *tstate = _PyThreadState_GET();
         modules = _PySys_GetAttr(tstate, &_Py_ID(modules));
         if (modules == NULL) {
@@ -1959,23 +1959,28 @@ whichmodule(PickleState *st, PyObject *global, PyObject *global_name, PyObject *
        extra parameters of __import__ to fix that. */
     module = PyImport_Import(module_name);
     if (module == NULL) {
-        PyErr_Format(st->PicklingError,
-                     "Can't pickle %R: import of module %R failed",
-                     global, module_name);
-        Py_DECREF(module_name);
-        return NULL;
-    }
-    if (check_dotted_path(module, global_name, dotted_path) < 0) {
+        if (PyErr_ExceptionMatches(PyExc_ImportError) ||
+            PyErr_ExceptionMatches(PyExc_ValueError))
+        {
+            PyObject *exc = PyErr_GetRaisedException();
+            PyErr_Format(st->PicklingError,
+                         "Can't pickle %R: %S", global, exc);
+            _PyErr_ChainExceptions1(exc);
+        }
         Py_DECREF(module_name);
-        Py_DECREF(module);
         return NULL;
     }
-    PyObject *actual = getattribute(module, dotted_path);
+    PyObject *actual = getattribute(module, dotted_path, 1);
     Py_DECREF(module);
     if (actual == NULL) {
-        PyErr_Format(st->PicklingError,
-                     "Can't pickle %R: attribute lookup %S on %S failed",
-                     global, global_name, module_name);
+        assert(PyErr_Occurred());
+        if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
+            PyObject *exc = PyErr_GetRaisedException();
+            PyErr_Format(st->PicklingError,
+                         "Can't pickle %R: it's not found as %S.%S",
+                         global, module_name, global_name);
+            _PyErr_ChainExceptions1(exc);
+        }
         Py_DECREF(module_name);
         return NULL;
     }
@@ -3759,11 +3764,14 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
             }
             encoded = unicode_encoder(module_name);
             if (encoded == NULL) {
-                if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError))
+                if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) {
+                    PyObject *exc = PyErr_GetRaisedException();
                     PyErr_Format(st->PicklingError,
-                                 "can't pickle module identifier '%S' using "
+                                 "can't pickle module identifier %R using "
                                  "pickle protocol %i",
                                  module_name, self->proto);
+                    _PyErr_ChainExceptions1(exc);
+                }
                 goto error;
             }
             if (_Pickler_Write(self, PyBytes_AS_STRING(encoded),
@@ -3778,11 +3786,14 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
             /* Save the name of the module. */
             encoded = unicode_encoder(global_name);
             if (encoded == NULL) {
-                if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError))
+                if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) {
+                    PyObject *exc = PyErr_GetRaisedException();
                     PyErr_Format(st->PicklingError,
-                                 "can't pickle global identifier '%S' using "
+                                 "can't pickle global identifier %R using "
                                  "pickle protocol %i",
                                  global_name, self->proto);
+                    _PyErr_ChainExceptions1(exc);
+                }
                 goto error;
             }
             if (_Pickler_Write(self, PyBytes_AS_STRING(encoded),
@@ -3943,8 +3954,9 @@ save_reduce(PickleState *st, PicklerObject *self, PyObject *args,
 
     size = PyTuple_Size(args);
     if (size < 2 || size > 6) {
-        PyErr_SetString(st->PicklingError, "tuple returned by "
-                        "__reduce__ must contain 2 through 6 elements");
+        PyErr_SetString(st->PicklingError,
+                        "tuple returned by __reduce__ "
+                        "must contain 2 through 6 elements");
         return -1;
     }
 
@@ -3954,13 +3966,15 @@ save_reduce(PickleState *st, PicklerObject *self, PyObject *args,
         return -1;
 
     if (!PyCallable_Check(callable)) {
-        PyErr_SetString(st->PicklingError, "first item of the tuple "
-                        "returned by __reduce__ must be callable");
+        PyErr_Format(st->PicklingError,
+                     "first item of the tuple returned by __reduce__ "
+                     "must be callable, not %T", callable);
         return -1;
     }
     if (!PyTuple_Check(argtup)) {
-        PyErr_SetString(st->PicklingError, "second item of the tuple "
-                        "returned by __reduce__ must be a tuple");
+        PyErr_Format(st->PicklingError,
+                     "second item of the tuple returned by __reduce__ "
+                     "must be a tuple, not %T", argtup);
         return -1;
     }
 
@@ -3970,27 +3984,27 @@ save_reduce(PickleState *st, PicklerObject *self, PyObject *args,
     if (listitems == Py_None)
         listitems = NULL;
     else if (!PyIter_Check(listitems)) {
-        PyErr_Format(st->PicklingError, "fourth element of the tuple "
-                     "returned by __reduce__ must be an iterator, not %s",
-                     Py_TYPE(listitems)->tp_name);
+        PyErr_Format(st->PicklingError,
+                     "fourth item of the tuple returned by __reduce__ "
+                     "must be an iterator, not %T", listitems);
         return -1;
     }
 
     if (dictitems == Py_None)
         dictitems = NULL;
     else if (!PyIter_Check(dictitems)) {
-        PyErr_Format(st->PicklingError, "fifth element of the tuple "
-                     "returned by __reduce__ must be an iterator, not %s",
-                     Py_TYPE(dictitems)->tp_name);
+        PyErr_Format(st->PicklingError,
+                     "fifth item of the tuple returned by __reduce__ "
+                     "must be an iterator, not %T", dictitems);
         return -1;
     }
 
     if (state_setter == Py_None)
         state_setter = NULL;
     else if (!PyCallable_Check(state_setter)) {
-        PyErr_Format(st->PicklingError, "sixth element of the tuple "
-                     "returned by __reduce__ must be a function, not %s",
-                     Py_TYPE(state_setter)->tp_name);
+        PyErr_Format(st->PicklingError,
+                     "sixth item of the tuple returned by __reduce__ "
+                     "must be callable, not %T", state_setter);
         return -1;
     }
 
@@ -4016,30 +4030,30 @@ save_reduce(PickleState *st, PicklerObject *self, PyObject *args,
 
         if (PyTuple_GET_SIZE(argtup) != 3) {
             PyErr_Format(st->PicklingError,
-                         "length of the NEWOBJ_EX argument tuple must be "
-                         "exactly 3, not %zd", PyTuple_GET_SIZE(argtup));
+                         "__newobj_ex__ expected 3 arguments, got %zd",
+                         PyTuple_GET_SIZE(argtup));
             return -1;
         }
 
         cls = PyTuple_GET_ITEM(argtup, 0);
         if (!PyType_Check(cls)) {
             PyErr_Format(st->PicklingError,
-                         "first item from NEWOBJ_EX argument tuple must "
-                         "be a class, not %.200s", Py_TYPE(cls)->tp_name);
+                         "first argument to __newobj_ex__() "
+                         "must be a class, not %T", cls);
             return -1;
         }
         args = PyTuple_GET_ITEM(argtup, 1);
         if (!PyTuple_Check(args)) {
             PyErr_Format(st->PicklingError,
-                         "second item from NEWOBJ_EX argument tuple must "
-                         "be a tuple, not %.200s", Py_TYPE(args)->tp_name);
+                         "second argument to __newobj_ex__() "
+                         "must be a tuple, not %T", args);
             return -1;
         }
         kwargs = PyTuple_GET_ITEM(argtup, 2);
         if (!PyDict_Check(kwargs)) {
             PyErr_Format(st->PicklingError,
-                         "third item from NEWOBJ_EX argument tuple must "
-                         "be a dict, not %.200s", Py_TYPE(kwargs)->tp_name);
+                         "third argument to __newobj_ex__() "
+                         "must be a dict, not %T", kwargs);
             return -1;
         }
 
@@ -4102,14 +4116,17 @@ save_reduce(PickleState *st, PicklerObject *self, PyObject *args,
 
         /* Sanity checks. */
         if (PyTuple_GET_SIZE(argtup) < 1) {
-            PyErr_SetString(st->PicklingError, "__newobj__ arglist is empty");
+            PyErr_Format(st->PicklingError,
+                         "__newobj__ expected at least 1 argument, got %zd",
+                         PyTuple_GET_SIZE(argtup));
             return -1;
         }
 
         cls = PyTuple_GET_ITEM(argtup, 0);
         if (!PyType_Check(cls)) {
-            PyErr_SetString(st->PicklingError, "args[0] from "
-                            "__newobj__ args is not a type");
+            PyErr_Format(st->PicklingError,
+                         "first argument to __newobj__() "
+                         "must be a class, not %T", cls);
             return -1;
         }
 
@@ -4118,13 +4135,14 @@ save_reduce(PickleState *st, PicklerObject *self, PyObject *args,
             if (obj_class == NULL) {
                 return -1;
             }
-            p = obj_class != cls;
-            Py_DECREF(obj_class);
-            if (p) {
-                PyErr_SetString(st->PicklingError, "args[0] from "
-                                "__newobj__ args has the wrong class");
+            if (obj_class != cls) {
+                PyErr_Format(st->PicklingError,
+                             "first argument to __newobj__() "
+                             "must be %R, not %R", obj_class, cls);
+                Py_DECREF(obj_class);
                 return -1;
             }
+            Py_DECREF(obj_class);
         }
         /* XXX: These calls save() are prone to infinite recursion. Imagine
            what happen if the value returned by the __reduce__() method of
@@ -4417,8 +4435,7 @@ save(PickleState *st, PicklerObject *self, PyObject *obj, int pers_save)
             }
             else {
                 PyErr_Format(st->PicklingError,
-                             "can't pickle '%.200s' object: %R",
-                             type->tp_name, obj);
+                             "Can't pickle %T object", obj);
                 goto error;
             }
         }
@@ -4434,8 +4451,8 @@ save(PickleState *st, PicklerObject *self, PyObject *obj, int pers_save)
     }
 
     if (!PyTuple_Check(reduce_value)) {
-        PyErr_SetString(st->PicklingError,
-                        "__reduce__ must return a string or tuple");
+        PyErr_Format(st->PicklingError,
+                     "__reduce__ must return a string or tuple, not %T", reduce_value);
         goto error;
     }
 
@@ -7038,17 +7055,16 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls,
             Py_DECREF(module);
             return NULL;
         }
-        if (check_dotted_path(module, global_name, dotted_path) < 0) {
-            Py_DECREF(dotted_path);
-            Py_DECREF(module);
-            return NULL;
-        }
-        global = getattribute(module, dotted_path);
-        Py_DECREF(dotted_path);
-        if (global == NULL && !PyErr_Occurred()) {
+        global = getattribute(module, dotted_path, 1);
+        assert(global != NULL || PyErr_Occurred());
+        if (global == NULL && PyList_GET_SIZE(dotted_path) > 1) {
+            PyObject *exc = PyErr_GetRaisedException();
             PyErr_Format(PyExc_AttributeError,
-                         "Can't get attribute %R on %R", global_name, module);
+                         "Can't resolve path %R on module %R",
+                         global_name, module_name);
+            _PyErr_ChainExceptions1(exc);
         }
+        Py_DECREF(dotted_path);
     }
     else {
         global = PyObject_GetAttr(module, global_name);