]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Issue #892902: Fixed pickling recursive objects.
authorSerhiy Storchaka <storchaka@gmail.com>
Sat, 7 Nov 2015 09:15:32 +0000 (11:15 +0200)
committerSerhiy Storchaka <storchaka@gmail.com>
Sat, 7 Nov 2015 09:15:32 +0000 (11:15 +0200)
Lib/pickle.py
Lib/test/pickletester.py
Lib/test/test_cpickle.py
Misc/NEWS
Modules/cPickle.c

index 299de16f519b56a7cc964fa1af3fdab7b9e6b030..1b3196ff753f67fbbab10c0f8bda5f57197e3333 100644 (file)
@@ -402,7 +402,13 @@ class Pickler:
             write(REDUCE)
 
         if obj is not None:
-            self.memoize(obj)
+            # If the object is already in the memo, this means it is
+            # recursive. In this case, throw away everything we put on the
+            # stack, and fetch the object back from the memo.
+            if id(obj) in self.memo:
+                write(POP + self.get(self.memo[id(obj)][0]))
+            else:
+                self.memoize(obj)
 
         # More new special cases (that work with older protocols as
         # well): when __reduce__ returns a tuple with 4 or 5 items,
index f7b9225cb298c555b4af9e9cd584bf79b93ff31a..d8346ea757b6f9bf1131f1ff2cdb57a13e57115b 100644 (file)
@@ -117,6 +117,18 @@ class E(C):
     def __getinitargs__(self):
         return ()
 
+class H(object):
+    pass
+
+# Hashable mutable key
+class K(object):
+    def __init__(self, value):
+        self.value = value
+
+    def __reduce__(self):
+        # Shouldn't support the recursion itself
+        return K, (self.value,)
+
 import __main__
 __main__.C = C
 C.__module__ = "__main__"
@@ -124,6 +136,10 @@ __main__.D = D
 D.__module__ = "__main__"
 __main__.E = E
 E.__module__ = "__main__"
+__main__.H = H
+H.__module__ = "__main__"
+__main__.K = K
+K.__module__ = "__main__"
 
 class myint(int):
     def __init__(self, x):
@@ -676,18 +692,21 @@ class AbstractPickleTests(unittest.TestCase):
         for proto in protocols:
             s = self.dumps(l, proto)
             x = self.loads(s)
+            self.assertIsInstance(x, list)
             self.assertEqual(len(x), 1)
-            self.assertTrue(x is x[0])
+            self.assertIs(x[0], x)
 
-    def test_recursive_tuple(self):
+    def test_recursive_tuple_and_list(self):
         t = ([],)
         t[0].append(t)
         for proto in protocols:
             s = self.dumps(t, proto)
             x = self.loads(s)
+            self.assertIsInstance(x, tuple)
             self.assertEqual(len(x), 1)
+            self.assertIsInstance(x[0], list)
             self.assertEqual(len(x[0]), 1)
-            self.assertTrue(x is x[0][0])
+            self.assertIs(x[0][0], x)
 
     def test_recursive_dict(self):
         d = {}
@@ -695,8 +714,50 @@ class AbstractPickleTests(unittest.TestCase):
         for proto in protocols:
             s = self.dumps(d, proto)
             x = self.loads(s)
+            self.assertIsInstance(x, dict)
             self.assertEqual(x.keys(), [1])
-            self.assertTrue(x[1] is x)
+            self.assertIs(x[1], x)
+
+    def test_recursive_dict_key(self):
+        d = {}
+        k = K(d)
+        d[k] = 1
+        for proto in protocols:
+            s = self.dumps(d, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, dict)
+            self.assertEqual(len(x.keys()), 1)
+            self.assertIsInstance(x.keys()[0], K)
+            self.assertIs(x.keys()[0].value, x)
+
+    def test_recursive_list_subclass(self):
+        y = MyList()
+        y.append(y)
+        s = self.dumps(y, 2)
+        x = self.loads(s)
+        self.assertIsInstance(x, MyList)
+        self.assertEqual(len(x), 1)
+        self.assertIs(x[0], x)
+
+    def test_recursive_dict_subclass(self):
+        d = MyDict()
+        d[1] = d
+        s = self.dumps(d, 2)
+        x = self.loads(s)
+        self.assertIsInstance(x, MyDict)
+        self.assertEqual(x.keys(), [1])
+        self.assertIs(x[1], x)
+
+    def test_recursive_dict_subclass_key(self):
+        d = MyDict()
+        k = K(d)
+        d[k] = 1
+        s = self.dumps(d, 2)
+        x = self.loads(s)
+        self.assertIsInstance(x, MyDict)
+        self.assertEqual(len(x.keys()), 1)
+        self.assertIsInstance(x.keys()[0], K)
+        self.assertIs(x.keys()[0].value, x)
 
     def test_recursive_inst(self):
         i = C()
@@ -721,6 +782,42 @@ class AbstractPickleTests(unittest.TestCase):
             self.assertEqual(x[0].attr.keys(), [1])
             self.assertTrue(x[0].attr[1] is x)
 
+    def check_recursive_collection_and_inst(self, factory):
+        h = H()
+        y = factory([h])
+        h.attr = y
+        for proto in protocols:
+            s = self.dumps(y, proto)
+            x = self.loads(s)
+            self.assertIsInstance(x, type(y))
+            self.assertEqual(len(x), 1)
+            self.assertIsInstance(list(x)[0], H)
+            self.assertIs(list(x)[0].attr, x)
+
+    def test_recursive_list_and_inst(self):
+        self.check_recursive_collection_and_inst(list)
+
+    def test_recursive_tuple_and_inst(self):
+        self.check_recursive_collection_and_inst(tuple)
+
+    def test_recursive_dict_and_inst(self):
+        self.check_recursive_collection_and_inst(dict.fromkeys)
+
+    def test_recursive_set_and_inst(self):
+        self.check_recursive_collection_and_inst(set)
+
+    def test_recursive_frozenset_and_inst(self):
+        self.check_recursive_collection_and_inst(frozenset)
+
+    def test_recursive_list_subclass_and_inst(self):
+        self.check_recursive_collection_and_inst(MyList)
+
+    def test_recursive_tuple_subclass_and_inst(self):
+        self.check_recursive_collection_and_inst(MyTuple)
+
+    def test_recursive_dict_subclass_and_inst(self):
+        self.check_recursive_collection_and_inst(MyDict.fromkeys)
+
     if have_unicode:
         def test_unicode(self):
             endcases = [u'', u'<\\u>', u'<\\\u1234>', u'<\n>',
index f6b3347543c109e13b08bb74a82d43d6ed6a8404..0a1eb43a31a746990dcf7a82a2416d2979169a0d 100644 (file)
@@ -1,6 +1,7 @@
 import cPickle
 import cStringIO
 import io
+import functools
 import unittest
 from test.pickletester import (AbstractUnpickleTests,
                                AbstractPickleTests,
@@ -151,31 +152,6 @@ class cPickleFastPicklerTests(AbstractPickleTests):
         finally:
             self.close(f)
 
-    def test_recursive_list(self):
-        self.assertRaises(ValueError,
-                          AbstractPickleTests.test_recursive_list,
-                          self)
-
-    def test_recursive_tuple(self):
-        self.assertRaises(ValueError,
-                          AbstractPickleTests.test_recursive_tuple,
-                          self)
-
-    def test_recursive_inst(self):
-        self.assertRaises(ValueError,
-                          AbstractPickleTests.test_recursive_inst,
-                          self)
-
-    def test_recursive_dict(self):
-        self.assertRaises(ValueError,
-                          AbstractPickleTests.test_recursive_dict,
-                          self)
-
-    def test_recursive_multi(self):
-        self.assertRaises(ValueError,
-                          AbstractPickleTests.test_recursive_multi,
-                          self)
-
     def test_nonrecursive_deep(self):
         # If it's not cyclic, it should pickle OK even if the nesting
         # depth exceeds PY_CPICKLE_FAST_LIMIT.  That happens to be
@@ -187,6 +163,19 @@ class cPickleFastPicklerTests(AbstractPickleTests):
         b = self.loads(self.dumps(a))
         self.assertEqual(a, b)
 
+for name in dir(AbstractPickleTests):
+    if name.startswith('test_recursive_'):
+        func = getattr(AbstractPickleTests, name)
+        if '_subclass' in name and '_and_inst' not in name:
+            assert_args = RuntimeError, 'maximum recursion depth exceeded'
+        else:
+            assert_args = ValueError, "can't pickle cyclic objects"
+        def wrapper(self, func=func, assert_args=assert_args):
+            with self.assertRaisesRegexp(*assert_args):
+                func(self)
+        functools.update_wrapper(wrapper, func)
+        setattr(cPickleFastPicklerTests, name, wrapper)
+
 class cStringIOCPicklerFastTests(cStringIOMixin, cPickleFastPicklerTests):
     pass
 
index 6f056b2f33580f16f325698279fbc6c33783129f..f9163d6123298585aef7ce440561325b7d86235f 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -46,6 +46,8 @@ Core and Builtins
 Library
 -------
 
+- Issue #892902: Fixed pickling recursive objects.
+
 - Issue #18010: Fix the pydoc GUI's search function to handle exceptions
   from importing packages.
 
index 89448a6f419002fff3b1a302628c2ca11139c5a3..0e9372360f34d27834b47407caf9787bc6614e22 100644 (file)
@@ -2533,6 +2533,27 @@ save_reduce(Picklerobject *self, PyObject *args, PyObject *fn, PyObject *ob)
     /* Memoize. */
     /* XXX How can ob be NULL? */
     if (ob != NULL) {
+        /* If the object is already in the memo, this means it is
+           recursive. In this case, throw away everything we put on the
+           stack, and fetch the object back from the memo. */
+        if (Py_REFCNT(ob) > 1 && !self->fast) {
+            PyObject *py_ob_id = PyLong_FromVoidPtr(ob);
+            if (!py_ob_id)
+                return -1;
+            if (PyDict_GetItem(self->memo, py_ob_id)) {
+                const char pop_op = POP;
+                if (self->write_func(self, &pop_op, 1) < 0 ||
+                    get(self, py_ob_id) < 0) {
+                    Py_DECREF(py_ob_id);
+                    return -1;
+                }
+                Py_DECREF(py_ob_id);
+                return 0;
+            }
+            Py_DECREF(py_ob_id);
+            if (PyErr_Occurred())
+                return -1;
+        }
         if (state && !PyDict_Check(state)) {
             if (put2(self, ob) < 0)
                 return -1;