]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-122311: Add more tests for pickle (GH-122376)
authorSerhiy Storchaka <storchaka@gmail.com>
Sun, 28 Jul 2024 08:33:17 +0000 (11:33 +0300)
committerGitHub <noreply@github.com>
Sun, 28 Jul 2024 08:33:17 +0000 (11:33 +0300)
Lib/test/pickletester.py
Lib/test/test_pickle.py

index 13663220fc77eaeb75ccd145c2ae54bec09a21e9..174f4ff6d021b20bd34d0b28401991d2adbc346e 100644 (file)
@@ -144,6 +144,14 @@ class E(C):
     def __getinitargs__(self):
         return ()
 
+import __main__
+__main__.C = C
+C.__module__ = "__main__"
+__main__.D = D
+D.__module__ = "__main__"
+__main__.E = E
+E.__module__ = "__main__"
+
 # Simple mutable object.
 class Object:
     pass
@@ -157,14 +165,6 @@ class K:
         # Shouldn't support the recursion itself
         return K, (self.value,)
 
-import __main__
-__main__.C = C
-C.__module__ = "__main__"
-__main__.D = D
-D.__module__ = "__main__"
-__main__.E = E
-E.__module__ = "__main__"
-
 class myint(int):
     def __init__(self, x):
         self.str = str(x)
@@ -1179,6 +1179,124 @@ class AbstractUnpickleTests:
             self.assertIs(type(unpickled), collections.UserDict)
             self.assertEqual(unpickled, collections.UserDict({1: 2}))
 
+    def test_load_global(self):
+        self.assertIs(self.loads(b'cbuiltins\nstr\n.'), str)
+        self.assertIs(self.loads(b'cmath\nlog\n.'), math.log)
+        self.assertIs(self.loads(b'cos.path\njoin\n.'), os.path.join)
+        self.assertIs(self.loads(b'\x80\x04cbuiltins\nstr.upper\n.'), str.upper)
+        with support.swap_item(sys.modules, 'mödule', types.SimpleNamespace(glöbal=42)):
+            self.assertEqual(self.loads(b'\x80\x04cm\xc3\xb6dule\ngl\xc3\xb6bal\n.'), 42)
+
+        self.assertRaises(UnicodeDecodeError, self.loads, b'c\xff\nlog\n.')
+        self.assertRaises(UnicodeDecodeError, self.loads, b'cmath\n\xff\n.')
+        self.assertRaises(self.truncated_errors, self.loads, b'c\nlog\n.')
+        self.assertRaises(self.truncated_errors, self.loads, b'cmath\n\n.')
+        self.assertRaises(self.truncated_errors, self.loads, b'\x80\x04cmath\n\n.')
+
+    def test_load_stack_global(self):
+        self.assertIs(self.loads(b'\x8c\x08builtins\x8c\x03str\x93.'), str)
+        self.assertIs(self.loads(b'\x8c\x04math\x8c\x03log\x93.'), math.log)
+        self.assertIs(self.loads(b'\x8c\x07os.path\x8c\x04join\x93.'),
+                      os.path.join)
+        self.assertIs(self.loads(b'\x80\x04\x8c\x08builtins\x8c\x09str.upper\x93.'),
+                      str.upper)
+        with support.swap_item(sys.modules, 'mödule', types.SimpleNamespace(glöbal=42)):
+            self.assertEqual(self.loads(b'\x80\x04\x8c\x07m\xc3\xb6dule\x8c\x07gl\xc3\xb6bal\x93.'), 42)
+
+        self.assertRaises(UnicodeDecodeError, self.loads, b'\x8c\x01\xff\x8c\x03log\x93.')
+        self.assertRaises(UnicodeDecodeError, self.loads, b'\x8c\x04math\x8c\x01\xff\x93.')
+        self.assertRaises(ValueError, self.loads, b'\x8c\x00\x8c\x03log\x93.')
+        self.assertRaises(AttributeError, self.loads, b'\x8c\x04math\x8c\x00\x93.')
+        self.assertRaises(AttributeError, self.loads, b'\x80\x04\x8c\x04math\x8c\x00\x93.')
+
+        self.assertRaises(pickle.UnpicklingError, self.loads, b'N\x8c\x03log\x93.')
+        self.assertRaises(pickle.UnpicklingError, self.loads, b'\x8c\x04mathN\x93.')
+        self.assertRaises(pickle.UnpicklingError, self.loads, b'\x80\x04\x8c\x04mathN\x93.')
+
+    def test_find_class(self):
+        unpickler = self.unpickler(io.BytesIO())
+        unpickler_nofix = self.unpickler(io.BytesIO(), fix_imports=False)
+        unpickler4 = self.unpickler(io.BytesIO(b'\x80\x04N.'))
+        unpickler4.load()
+
+        self.assertIs(unpickler.find_class('__builtin__', 'str'), str)
+        self.assertRaises(ModuleNotFoundError,
+                          unpickler_nofix.find_class, '__builtin__', 'str')
+        self.assertIs(unpickler.find_class('builtins', 'str'), str)
+        self.assertIs(unpickler_nofix.find_class('builtins', 'str'), str)
+        self.assertIs(unpickler.find_class('math', 'log'), math.log)
+        self.assertIs(unpickler.find_class('os.path', 'join'), os.path.join)
+        self.assertIs(unpickler.find_class('os.path', 'join'), os.path.join)
+
+        self.assertIs(unpickler4.find_class('builtins', 'str.upper'), str.upper)
+        with self.assertRaises(AttributeError):
+            unpickler.find_class('builtins', 'str.upper')
+
+        with self.assertRaises(AttributeError):
+            unpickler.find_class('math', 'spam')
+        with self.assertRaises(AttributeError):
+            unpickler4.find_class('math', 'spam')
+        with self.assertRaises(AttributeError):
+            unpickler.find_class('math', 'log.spam')
+        with self.assertRaises(AttributeError):
+            unpickler4.find_class('math', 'log.spam')
+        with self.assertRaises(AttributeError):
+            unpickler.find_class('math', 'log.<locals>.spam')
+        with self.assertRaises(AttributeError):
+            unpickler4.find_class('math', 'log.<locals>.spam')
+        with self.assertRaises(AttributeError):
+            unpickler.find_class('math', '')
+        with self.assertRaises(AttributeError):
+            unpickler4.find_class('math', '')
+        self.assertRaises(ModuleNotFoundError, unpickler.find_class, 'spam', 'log')
+        self.assertRaises(ValueError, unpickler.find_class, '', 'log')
+
+        self.assertRaises(TypeError, unpickler.find_class, None, 'log')
+        self.assertRaises(TypeError, unpickler.find_class, 'math', None)
+        self.assertRaises((TypeError, AttributeError), unpickler4.find_class, 'math', None)
+
+    def test_custom_find_class(self):
+        def loads(data):
+            class Unpickler(self.unpickler):
+                def find_class(self, module_name, global_name):
+                    return (module_name, global_name)
+            return Unpickler(io.BytesIO(data)).load()
+
+        self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
+        self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
+
+        def loads(data):
+            class Unpickler(self.unpickler):
+                @staticmethod
+                def find_class(module_name, global_name):
+                    return (module_name, global_name)
+            return Unpickler(io.BytesIO(data)).load()
+
+        self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
+        self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
+
+        def loads(data):
+            class Unpickler(self.unpickler):
+                @classmethod
+                def find_class(cls, module_name, global_name):
+                    return (module_name, global_name)
+            return Unpickler(io.BytesIO(data)).load()
+
+        self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
+        self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
+
+        def loads(data):
+            class Unpickler(self.unpickler):
+                pass
+            def find_class(module_name, global_name):
+                return (module_name, global_name)
+            unpickler = Unpickler(io.BytesIO(data))
+            unpickler.find_class = find_class
+            return unpickler.load()
+
+        self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
+        self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
+
     def test_bad_reduce(self):
         self.assertEqual(self.loads(b'cbuiltins\nint\n)R.'), 0)
         self.check_unpickling_error(TypeError, b'N)R.')
@@ -1443,6 +1561,474 @@ class AbstractUnpickleTests:
             [ToBeUnpickled] * 2)
 
 
+class AbstractPicklingErrorTests:
+    # Subclass must define self.dumps, self.pickler.
+
+    def test_bad_reduce_result(self):
+        obj = REX([print, ()])
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+        obj = REX((print,))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+        obj = REX((print, (), None, None, None, None, None))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_bad_reconstructor(self):
+        obj = REX((42, ()))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_reconstructor(self):
+        obj = REX((UnpickleableCallable(), ()))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_bad_reconstructor_args(self):
+        obj = REX((print, []))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_reconstructor_args(self):
+        obj = REX((print, (1, 2, UNPICKLEABLE)))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_bad_newobj_args(self):
+        obj = REX((copyreg.__newobj__, ()))
+        for proto in protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises((IndexError, pickle.PicklingError)) as cm:
+                    self.dumps(obj, proto)
+
+        obj = REX((copyreg.__newobj__, [REX]))
+        for proto in protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises((IndexError, pickle.PicklingError)):
+                    self.dumps(obj, proto)
+
+    def test_bad_newobj_class(self):
+        obj = REX((copyreg.__newobj__, (NoNew(),)))
+        for proto in protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_wrong_newobj_class(self):
+        obj = REX((copyreg.__newobj__, (str,)))
+        for proto in protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_newobj_class(self):
+        class LocalREX(REX): pass
+        obj = LocalREX((copyreg.__newobj__, (LocalREX,)))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises((pickle.PicklingError, AttributeError)):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_newobj_args(self):
+        obj = REX((copyreg.__newobj__, (REX, 1, 2, UNPICKLEABLE)))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_bad_newobj_ex_args(self):
+        obj = REX((copyreg.__newobj_ex__, ()))
+        for proto in protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises((ValueError, pickle.PicklingError)):
+                    self.dumps(obj, proto)
+
+        obj = REX((copyreg.__newobj_ex__, 42))
+        for proto in protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+        obj = REX((copyreg.__newobj_ex__, (REX, 42, {})))
+        is_py = self.pickler is pickle._Pickler
+        for proto in protocols[2:4] if is_py else protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises((TypeError, pickle.PicklingError)):
+                    self.dumps(obj, proto)
+
+        obj = REX((copyreg.__newobj_ex__, (REX, (), [])))
+        for proto in protocols[2:4] if is_py else protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises((TypeError, pickle.PicklingError)):
+                    self.dumps(obj, proto)
+
+    def test_bad_newobj_ex__class(self):
+        obj = REX((copyreg.__newobj_ex__, (NoNew(), (), {})))
+        for proto in protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_wrong_newobj_ex_class(self):
+        if self.pickler is not pickle._Pickler:
+            self.skipTest('only verified in the Python implementation')
+        obj = REX((copyreg.__newobj_ex__, (str, (), {})))
+        for proto in protocols[2:]:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_newobj_ex_class(self):
+        class LocalREX(REX): pass
+        obj = LocalREX((copyreg.__newobj_ex__, (LocalREX, (), {})))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises((pickle.PicklingError, AttributeError)):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_newobj_ex_args(self):
+        obj = REX((copyreg.__newobj_ex__, (REX, (1, 2, UNPICKLEABLE), {})))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_newobj_ex_kwargs(self):
+        obj = REX((copyreg.__newobj_ex__, (REX, (), {'a': UNPICKLEABLE})))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_state(self):
+        obj = REX_state(UNPICKLEABLE)
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_bad_state_setter(self):
+        if self.pickler is pickle._Pickler:
+            self.skipTest('only verified in the C implementation')
+        obj = REX((print, (), 'state', None, None, 42))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_state_setter(self):
+        obj = REX((print, (), 'state', None, None, UnpickleableCallable()))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_state_with_state_setter(self):
+        obj = REX((print, (), UNPICKLEABLE, None, None, print))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_bad_object_list_items(self):
+        # Issue4176: crash when 4th and 5th items of __reduce__()
+        # are not iterators
+        obj = REX((list, (), None, 42))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises((TypeError, pickle.PicklingError)):
+                    self.dumps(obj, proto)
+
+        if self.pickler is not pickle._Pickler:
+            # Python implementation is less strict and also accepts iterables.
+            obj = REX((list, (), None, []))
+            for proto in protocols:
+                with self.subTest(proto=proto):
+                    with self.assertRaises((TypeError, pickle.PicklingError)):
+                        self.dumps(obj, proto)
+
+    def test_unpickleable_object_list_items(self):
+        obj = REX_six([1, 2, UNPICKLEABLE])
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_bad_object_dict_items(self):
+        # Issue4176: crash when 4th and 5th items of __reduce__()
+        # are not iterators
+        obj = REX((dict, (), None, None, 42))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises((TypeError, pickle.PicklingError)):
+                    self.dumps(obj, proto)
+
+        for proto in protocols:
+            obj = REX((dict, (), None, None, iter([('a',)])))
+            with self.subTest(proto=proto):
+                with self.assertRaises((ValueError, TypeError)):
+                    self.dumps(obj, proto)
+
+        if self.pickler is not pickle._Pickler:
+            # Python implementation is less strict and also accepts iterables.
+            obj = REX((dict, (), None, None, []))
+            for proto in protocols:
+                with self.subTest(proto=proto):
+                    with self.assertRaises((TypeError, pickle.PicklingError)):
+                        self.dumps(obj, proto)
+
+    def test_unpickleable_object_dict_items(self):
+        obj = REX_seven({'a': UNPICKLEABLE})
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_list_items(self):
+        obj = [1, [2, 3, UNPICKLEABLE]]
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+        for n in [0, 1, 1000, 1005]:
+            obj = [*range(n), UNPICKLEABLE]
+            for proto in protocols:
+                with self.subTest(proto=proto):
+                    with self.assertRaises(CustomError):
+                        self.dumps(obj, proto)
+
+    def test_unpickleable_tuple_items(self):
+        obj = (1, (2, 3, UNPICKLEABLE))
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+        obj = (*range(10), UNPICKLEABLE)
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_dict_items(self):
+        obj = {'a': {'b': UNPICKLEABLE}}
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+        for n in [0, 1, 1000, 1005]:
+            obj = dict.fromkeys(range(n))
+            obj['a'] = UNPICKLEABLE
+            for proto in protocols:
+                with self.subTest(proto=proto, n=n):
+                    with self.assertRaises(CustomError):
+                        self.dumps(obj, proto)
+
+    def test_unpickleable_set_items(self):
+        obj = {UNPICKLEABLE}
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_unpickleable_frozenset_items(self):
+        obj = frozenset({frozenset({UNPICKLEABLE})})
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(CustomError):
+                    self.dumps(obj, proto)
+
+    def test_global_lookup_error(self):
+        # Global name does not exist
+        obj = REX('spam')
+        obj.__module__ = __name__
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+        obj.__module__ = 'nonexisting'
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+        obj.__module__ = ''
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises((ValueError, pickle.PicklingError)):
+                    self.dumps(obj, proto)
+
+        obj.__module__ = None
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_nonencodable_global_name_error(self):
+        for proto in protocols[:4]:
+            with self.subTest(proto=proto):
+                name = 'nonascii\xff' if proto < 3 else 'nonencodable\udbff'
+                obj = REX(name)
+                obj.__module__ = __name__
+                with support.swap_item(globals(), name, obj):
+                    with self.assertRaises((UnicodeEncodeError, pickle.PicklingError)):
+                        self.dumps(obj, proto)
+
+    def test_nonencodable_module_name_error(self):
+        for proto in protocols[:4]:
+            with self.subTest(proto=proto):
+                name = 'nonascii\xff' if proto < 3 else 'nonencodable\udbff'
+                obj = REX('test')
+                obj.__module__ = name
+                mod = types.SimpleNamespace(test=obj)
+                with support.swap_item(sys.modules, name, mod):
+                    with self.assertRaises((UnicodeEncodeError, pickle.PicklingError)):
+                        self.dumps(obj, proto)
+
+    def test_nested_lookup_error(self):
+        # Nested name does not exist
+        obj = REX('AbstractPickleTests.spam')
+        obj.__module__ = __name__
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+        obj.__module__ = None
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_wrong_object_lookup_error(self):
+        # Name is bound to different object
+        obj = REX('AbstractPickleTests')
+        obj.__module__ = __name__
+        AbstractPickleTests.ham = []
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+        obj.__module__ = None
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(obj, proto)
+
+    def test_local_lookup_error(self):
+        # Test that whichmodule() errors out cleanly when looking up
+        # an assumed globally-reachable object fails.
+        def f():
+            pass
+        # Since the function is local, lookup will fail
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises((AttributeError, pickle.PicklingError)):
+                    self.dumps(f, proto)
+        # Same without a __module__ attribute (exercises a different path
+        # in _pickle.c).
+        del f.__module__
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises((AttributeError, pickle.PicklingError)):
+                    self.dumps(f, proto)
+        # Yet a different path.
+        f.__name__ = f.__qualname__
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                with self.assertRaises((AttributeError, pickle.PicklingError)):
+                    self.dumps(f, proto)
+
+    def test_reduce_ex_None(self):
+        c = REX_None()
+        with self.assertRaises(TypeError):
+            self.dumps(c)
+
+    def test_reduce_None(self):
+        c = R_None()
+        with self.assertRaises(TypeError):
+            self.dumps(c)
+
+    @no_tracing
+    def test_bad_getattr(self):
+        # Issue #3514: crash when there is an infinite loop in __getattr__
+        x = BadGetattr()
+        for proto in range(2):
+            with support.infinite_recursion(25):
+                self.assertRaises(RuntimeError, self.dumps, x, proto)
+        for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
+            s = self.dumps(x, proto)
+
+    def test_picklebuffer_error(self):
+        # PickleBuffer forbidden with protocol < 5
+        pb = pickle.PickleBuffer(b"foobar")
+        for proto in range(0, 5):
+            with self.subTest(proto=proto):
+                with self.assertRaises(pickle.PickleError):
+                    self.dumps(pb, proto)
+
+    def test_non_continuous_buffer(self):
+        if self.pickler is pickle._Pickler:
+            self.skipTest('CRASHES (see gh-122306)')
+        for proto in protocols[5:]:
+            with self.subTest(proto=proto):
+                pb = pickle.PickleBuffer(memoryview(b"foobar")[::2])
+                with self.assertRaises(pickle.PicklingError):
+                    self.dumps(pb, proto)
+
+    def test_buffer_callback_error(self):
+        def buffer_callback(buffers):
+            raise CustomError
+        pb = pickle.PickleBuffer(b"foobar")
+        with self.assertRaises(CustomError):
+            self.dumps(pb, 5, buffer_callback=buffer_callback)
+
+    def test_evil_pickler_mutating_collection(self):
+        # https://github.com/python/cpython/issues/92930
+        global Clearer
+        class Clearer:
+            pass
+
+        def check(collection):
+            class EvilPickler(self.pickler):
+                def persistent_id(self, obj):
+                    if isinstance(obj, Clearer):
+                        collection.clear()
+                    return None
+            pickler = EvilPickler(io.BytesIO(), proto)
+            try:
+                pickler.dump(collection)
+            except RuntimeError as e:
+                expected = "changed size during iteration"
+                self.assertIn(expected, str(e))
+
+        for proto in protocols:
+            check([Clearer()])
+            check([Clearer(), Clearer()])
+            check({Clearer()})
+            check({Clearer(), Clearer()})
+            check({Clearer(): 1})
+            check({Clearer(): 1, Clearer(): 2})
+            check({1: Clearer(), 2: Clearer()})
+
 
 class AbstractPickleTests:
     # Subclass must define self.dumps, self.loads.
@@ -2453,55 +3039,12 @@ class AbstractPickleTests:
             y = self.loads(s)
             self.assertEqual(y._reduce_called, 1)
 
-    def test_reduce_ex_None(self):
-        c = REX_None()
-        with self.assertRaises(TypeError):
-            self.dumps(c)
-
-    def test_reduce_None(self):
-        c = R_None()
-        with self.assertRaises(TypeError):
-            self.dumps(c)
-
     def test_pickle_setstate_None(self):
         c = C_None_setstate()
         p = self.dumps(c)
         with self.assertRaises(TypeError):
             self.loads(p)
 
-    @no_tracing
-    def test_bad_getattr(self):
-        # Issue #3514: crash when there is an infinite loop in __getattr__
-        x = BadGetattr()
-        for proto in range(2):
-            with support.infinite_recursion(25):
-                self.assertRaises(RuntimeError, self.dumps, x, proto)
-        for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
-            s = self.dumps(x, proto)
-
-    def test_reduce_bad_iterator(self):
-        # Issue4176: crash when 4th and 5th items of __reduce__()
-        # are not iterators
-        class C(object):
-            def __reduce__(self):
-                # 4th item is not an iterator
-                return list, (), None, [], None
-        class D(object):
-            def __reduce__(self):
-                # 5th item is not an iterator
-                return dict, (), None, None, []
-
-        # Python implementation is less strict and also accepts iterables.
-        for proto in protocols:
-            try:
-                self.dumps(C(), proto)
-            except pickle.PicklingError:
-                pass
-            try:
-                self.dumps(D(), proto)
-            except pickle.PicklingError:
-                pass
-
     def test_many_puts_and_gets(self):
         # Test that internal data structures correctly deal with lots of
         # puts/gets.
@@ -2950,27 +3493,6 @@ class AbstractPickleTests:
                     self.assertIn(('c%s\n%s' % (mod, name)).encode(), pickled)
                     self.assertIs(type(self.loads(pickled)), type(val))
 
-    def test_local_lookup_error(self):
-        # Test that whichmodule() errors out cleanly when looking up
-        # an assumed globally-reachable object fails.
-        def f():
-            pass
-        # Since the function is local, lookup will fail
-        for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
-            with self.assertRaises((AttributeError, pickle.PicklingError)):
-                pickletools.dis(self.dumps(f, proto))
-        # Same without a __module__ attribute (exercises a different path
-        # in _pickle.c).
-        del f.__module__
-        for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
-            with self.assertRaises((AttributeError, pickle.PicklingError)):
-                pickletools.dis(self.dumps(f, proto))
-        # Yet a different path.
-        f.__name__ = f.__qualname__
-        for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
-            with self.assertRaises((AttributeError, pickle.PicklingError)):
-                pickletools.dis(self.dumps(f, proto))
-
     #
     # PEP 574 tests below
     #
@@ -3081,20 +3603,6 @@ class AbstractPickleTests:
             self.assertIs(type(new), type(obj))
             self.assertEqual(new, obj)
 
-    def test_picklebuffer_error(self):
-        # PickleBuffer forbidden with protocol < 5
-        pb = pickle.PickleBuffer(b"foobar")
-        for proto in range(0, 5):
-            with self.assertRaises(pickle.PickleError):
-                self.dumps(pb, proto)
-
-    def test_buffer_callback_error(self):
-        def buffer_callback(buffers):
-            1/0
-        pb = pickle.PickleBuffer(b"foobar")
-        with self.assertRaises(ZeroDivisionError):
-            self.dumps(pb, 5, buffer_callback=buffer_callback)
-
     def test_buffers_error(self):
         pb = pickle.PickleBuffer(b"foobar")
         for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
@@ -3186,37 +3694,6 @@ class AbstractPickleTests:
                     expected = "changed size during iteration"
                     self.assertIn(expected, str(e))
 
-    def test_evil_pickler_mutating_collection(self):
-        # https://github.com/python/cpython/issues/92930
-        if not hasattr(self, "pickler"):
-            raise self.skipTest(f"{type(self)} has no associated pickler type")
-
-        global Clearer
-        class Clearer:
-            pass
-
-        def check(collection):
-            class EvilPickler(self.pickler):
-                def persistent_id(self, obj):
-                    if isinstance(obj, Clearer):
-                        collection.clear()
-                    return None
-            pickler = EvilPickler(io.BytesIO(), proto)
-            try:
-                pickler.dump(collection)
-            except RuntimeError as e:
-                expected = "changed size during iteration"
-                self.assertIn(expected, str(e))
-
-        for proto in protocols:
-            check([Clearer()])
-            check([Clearer(), Clearer()])
-            check({Clearer()})
-            check({Clearer(), Clearer()})
-            check({Clearer(): 1})
-            check({Clearer(): 1, Clearer(): 2})
-            check({1: Clearer(), 2: Clearer()})
-
 
 class BigmemPickleTests:
 
@@ -3347,6 +3824,18 @@ class BigmemPickleTests:
 
 # Test classes for reduce_ex
 
+class R:
+    def __init__(self, reduce=None):
+        self.reduce = reduce
+    def __reduce__(self, proto):
+        return self.reduce
+
+class REX:
+    def __init__(self, reduce_ex=None):
+        self.reduce_ex = reduce_ex
+    def __reduce_ex__(self, proto):
+        return self.reduce_ex
+
 class REX_one(object):
     """No __reduce_ex__ here, but inheriting it from object"""
     _reduce_called = 0
@@ -3437,6 +3926,19 @@ class C_None_setstate:
 
     __setstate__ = None
 
+class CustomError(Exception):
+    pass
+
+class Unpickleable:
+    def __reduce__(self):
+        raise CustomError
+
+UNPICKLEABLE = Unpickleable()
+
+class UnpickleableCallable(Unpickleable):
+    def __call__(self, *args, **kwargs):
+        pass
+
 
 # Test classes for newobj
 
@@ -3505,6 +4007,12 @@ class BadGetattr:
     def __getattr__(self, key):
         self.foo
 
+class NoNew:
+    def __getattribute__(self, name):
+        if name == '__new__':
+            raise AttributeError
+        return super().__getattribute__(name)
+
 
 class AbstractPickleModuleTests:
 
@@ -3577,7 +4085,7 @@ class AbstractPickleModuleTests:
             raise OSError
         @property
         def bad_property(self):
-            1/0
+            raise CustomError
 
         # File without read and readline
         class F:
@@ -3598,23 +4106,23 @@ class AbstractPickleModuleTests:
         class F:
             read = bad_property
             readline = raises_oserror
-        self.assertRaises(ZeroDivisionError, self.Unpickler, F())
+        self.assertRaises(CustomError, self.Unpickler, F())
 
         # File with bad readline
         class F:
             readline = bad_property
             read = raises_oserror
-        self.assertRaises(ZeroDivisionError, self.Unpickler, F())
+        self.assertRaises(CustomError, self.Unpickler, F())
 
         # File with bad readline, no read
         class F:
             readline = bad_property
-        self.assertRaises(ZeroDivisionError, self.Unpickler, F())
+        self.assertRaises(CustomError, self.Unpickler, F())
 
         # File with bad read, no readline
         class F:
             read = bad_property
-        self.assertRaises((AttributeError, ZeroDivisionError), self.Unpickler, F())
+        self.assertRaises((AttributeError, CustomError), self.Unpickler, F())
 
         # File with bad peek
         class F:
@@ -3623,7 +4131,7 @@ class AbstractPickleModuleTests:
             readline = raises_oserror
         try:
             self.Unpickler(F())
-        except ZeroDivisionError:
+        except CustomError:
             pass
 
         # File with bad readinto
@@ -3633,7 +4141,7 @@ class AbstractPickleModuleTests:
             readline = raises_oserror
         try:
             self.Unpickler(F())
-        except ZeroDivisionError:
+        except CustomError:
             pass
 
     def test_pickler_bad_file(self):
@@ -3646,8 +4154,8 @@ class AbstractPickleModuleTests:
         class F:
             @property
             def write(self):
-                1/0
-        self.assertRaises(ZeroDivisionError, self.Pickler, F())
+                raise CustomError
+        self.assertRaises(CustomError, self.Pickler, F())
 
     def check_dumps_loads_oob_buffers(self, dumps, loads):
         # No need to do the full gamut of tests here, just enough to
@@ -3755,9 +4263,15 @@ class AbstractIdentityPersistentPicklerTests:
 
     def test_protocol0_is_ascii_only(self):
         non_ascii_str = "\N{EMPTY SET}"
-        self.assertRaises(pickle.PicklingError, self.dumps, non_ascii_str, 0)
+        with self.assertRaises(pickle.PicklingError) as cm:
+            self.dumps(non_ascii_str, 0)
+        self.assertEqual(str(cm.exception),
+                         'persistent IDs in protocol 0 must be ASCII strings')
         pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.'
-        self.assertRaises(pickle.UnpicklingError, self.loads, pickled)
+        with self.assertRaises(pickle.UnpicklingError) as cm:
+            self.loads(pickled)
+        self.assertEqual(str(cm.exception),
+                         'persistent IDs in protocol 0 must be ASCII strings')
 
 
 class AbstractPicklerUnpicklerObjectTests:
index 49aa4b386039ec60542a306b13a13e511a3eaf9a..c84e507cdf645f5872b347baadd604bc810f8ce1 100644 (file)
@@ -16,6 +16,7 @@ from test.support import import_helper
 
 from test.pickletester import AbstractHookTests
 from test.pickletester import AbstractUnpickleTests
+from test.pickletester import AbstractPicklingErrorTests
 from test.pickletester import AbstractPickleTests
 from test.pickletester import AbstractPickleModuleTests
 from test.pickletester import AbstractPersistentPicklerTests
@@ -55,6 +56,18 @@ class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
         return u.load()
 
 
+class PyPicklingErrorTests(AbstractPicklingErrorTests, unittest.TestCase):
+
+    pickler = pickle._Pickler
+
+    def dumps(self, arg, proto=None, **kwargs):
+        f = io.BytesIO()
+        p = self.pickler(f, proto, **kwargs)
+        p.dump(arg)
+        f.seek(0)
+        return bytes(f.read())
+
+
 class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
 
     pickler = pickle._Pickler
@@ -88,6 +101,8 @@ class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
         return pickle.loads(buf, **kwds)
 
     test_framed_write_sizes_with_delayed_writer = None
+    test_find_class = None
+    test_custom_find_class = None
 
 
 class PersistentPicklerUnpicklerMixin(object):
@@ -267,6 +282,9 @@ if has_c_implementation:
         bad_stack_errors = (pickle.UnpicklingError,)
         truncated_errors = (pickle.UnpicklingError,)
 
+    class CPicklingErrorTests(PyPicklingErrorTests):
+        pickler = _pickle.Pickler
+
     class CPicklerTests(PyPicklerTests):
         pickler = _pickle.Pickler
         unpickler = _pickle.Unpickler