]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-146059: Call fast_save_leave() in pickle save_frozenset() (#146173)
authorVictor Stinner <vstinner@python.org>
Thu, 26 Mar 2026 16:35:27 +0000 (17:35 +0100)
committerGitHub <noreply@github.com>
Thu, 26 Mar 2026 16:35:27 +0000 (17:35 +0100)
Add more pickle tests: test also nested structures.

Lib/test/pickletester.py
Modules/_pickle.c

index 6ac4b19da3ea9c98abfbdd969764c32dafff294e..881e672a76ff3ffc2c31965b49dd4bd49b718e08 100644 (file)
@@ -57,6 +57,8 @@ requires_32b = unittest.skipUnless(sys.maxsize < 2**32,
 # kind of outer loop.
 protocols = range(pickle.HIGHEST_PROTOCOL + 1)
 
+FAST_NESTING_LIMIT = 50
+
 
 # Return True if opcode code appears in the pickle, else False.
 def opcode_in_pickle(code, pickle):
@@ -4552,6 +4554,98 @@ class AbstractPickleTests:
                     expected = "changed size during iteration"
                     self.assertIn(expected, str(e))
 
+    def fast_save_enter(self, create_data, minprotocol=0):
+        # gh-146059: Check that fast_save() is called when
+        # fast_save_enter() is called.
+        if not hasattr(self, "pickler"):
+            self.skipTest("need Pickler class")
+
+        data = [create_data(i) for i in range(FAST_NESTING_LIMIT * 2)]
+        data = {"key": data}
+        protocols = range(minprotocol, pickle.HIGHEST_PROTOCOL + 1)
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                buf = io.BytesIO()
+                pickler = self.pickler(buf, protocol=proto)
+                # Enable fast mode (disables memo, enables cycle detection)
+                pickler.fast = 1
+                pickler.dump(data)
+
+                buf.seek(0)
+                data2 = self.unpickler(buf).load()
+                self.assertEqual(data2, data)
+
+    def test_fast_save_enter_tuple(self):
+        self.fast_save_enter(lambda i: (i,))
+
+    def test_fast_save_enter_list(self):
+        self.fast_save_enter(lambda i: [i])
+
+    def test_fast_save_enter_frozenset(self):
+        self.fast_save_enter(lambda i: frozenset([i]))
+
+    def test_fast_save_enter_set(self):
+        self.fast_save_enter(lambda i: set([i]))
+
+    def test_fast_save_enter_frozendict(self):
+        if self.py_version < (3, 15):
+            self.skipTest('need frozendict')
+        self.fast_save_enter(lambda i: frozendict(key=i), minprotocol=2)
+
+    def test_fast_save_enter_dict(self):
+        self.fast_save_enter(lambda i: {"key": i})
+
+    def deep_nested_struct(self, seed, create_nested,
+                           minprotocol=0, compare_equal=True,
+                           depth=FAST_NESTING_LIMIT * 2):
+        # gh-146059: Check that fast_save() is called when
+        # fast_save_enter() is called.
+        if not hasattr(self, "pickler"):
+            self.skipTest("need Pickler class")
+
+        data = seed
+        for i in range(depth):
+            data = create_nested(data)
+        data = {"key": data}
+        protocols = range(minprotocol, pickle.HIGHEST_PROTOCOL + 1)
+        for proto in protocols:
+            with self.subTest(proto=proto):
+                buf = io.BytesIO()
+                pickler = self.pickler(buf, protocol=proto)
+                # Enable fast mode (disables memo, enables cycle detection)
+                pickler.fast = 1
+                pickler.dump(data)
+
+                buf.seek(0)
+                data2 = self.unpickler(buf).load()
+                if compare_equal:
+                    self.assertEqual(data2, data)
+
+    def test_deep_nested_struct_tuple(self):
+        self.deep_nested_struct((1,), lambda data: (data,))
+
+    def test_deep_nested_struct_list(self):
+        self.deep_nested_struct([1], lambda data: [data])
+
+    def test_deep_nested_struct_frozenset(self):
+        self.deep_nested_struct(frozenset((1,)),
+                                lambda data: frozenset((1, data)))
+
+    def test_deep_nested_struct_set(self):
+        self.deep_nested_struct({1}, lambda data: {K(data)},
+                                depth=FAST_NESTING_LIMIT+1,
+                                compare_equal=False)
+
+    def test_deep_nested_struct_frozendict(self):
+        if self.py_version < (3, 15):
+            self.skipTest('need frozendict')
+        self.deep_nested_struct(frozendict(x=1),
+                                lambda data: frozendict(x=data),
+                                minprotocol=2)
+
+    def test_deep_nested_struct_dict(self):
+        self.deep_nested_struct({'x': 1}, lambda data: {'x': data})
+
 
 class BigmemPickleTests:
 
index a55e04290b8fddd41d45076fea1d9a45469cca33..a28e5feebc1ed8c52623c28c10a90920c3d2312b 100644 (file)
@@ -3671,16 +3671,13 @@ save_set(PickleState *state, PicklerObject *self, PyObject *obj)
 }
 
 static int
-save_frozenset(PickleState *state, PicklerObject *self, PyObject *obj)
+save_frozenset_impl(PickleState *state, PicklerObject *self, PyObject *obj)
 {
     PyObject *iter;
 
     const char mark_op = MARK;
     const char frozenset_op = FROZENSET;
 
-    if (self->fast && !fast_save_enter(self, obj))
-        return -1;
-
     if (self->proto < 4) {
         PyObject *items;
         PyObject *reduce_value;
@@ -3751,6 +3748,19 @@ save_frozenset(PickleState *state, PicklerObject *self, PyObject *obj)
     return 0;
 }
 
+static int
+save_frozenset(PickleState *state, PicklerObject *self, PyObject *obj)
+{
+    if (self->fast && !fast_save_enter(self, obj)) {
+        return -1;
+    }
+    int status = save_frozenset_impl(state, self, obj);
+    if (self->fast && !fast_save_leave(self, obj)) {
+        return -1;
+    }
+    return status;
+}
+
 static int
 fix_imports(PickleState *st, PyObject **module_name, PyObject **global_name)
 {