]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Issue3065: Fixed pickling of named tuples. Added tests.
authorRaymond Hettinger <python@rcn.com>
Mon, 9 Jun 2008 01:28:30 +0000 (01:28 +0000)
committerRaymond Hettinger <python@rcn.com>
Mon, 9 Jun 2008 01:28:30 +0000 (01:28 +0000)
Doc/library/collections.rst
Lib/collections.py
Lib/test/test_collections.py

index 2b6f9b1078a4344536f79a7f917eba27e60b9c86..cbad297a4f603521eec25d0f9e710a0184967b41 100644 (file)
@@ -539,6 +539,9 @@ Example:
                if kwds:
                    raise ValueError('Got unexpected field names: %r' % kwds.keys())
                return result
+   <BLANKLINE>            
+        def __getnewargs__(self): 
+            return tuple(self)
    <BLANKLINE>
            x = property(itemgetter(0))
            y = property(itemgetter(1))
index f6233a7a44896ddca874eb1abdfdc7fd19b86ccb..24088183900a30f43c3242ddc33bd81ba88b6bf6 100644 (file)
@@ -82,7 +82,9 @@ def namedtuple(typename, field_names, verbose=False):
             result = self._make(map(kwds.pop, %(field_names)r, self))
             if kwds:
                 raise ValueError('Got unexpected field names: %%r' %% kwds.keys())
-            return result \n\n''' % locals()
+            return result \n
+        def __getnewargs__(self):
+            return tuple(self) \n\n''' % locals()
     for i, name in enumerate(field_names):
         template += '        %s = property(itemgetter(%d))\n' % (name, i)
     if verbose:
index a770155bdf56554f65fbfc723b5737175c1be6bc..4f823e393ce0efc9e9e676f3a209aeed1a911b14 100644 (file)
@@ -1,12 +1,14 @@
 import unittest, doctest
 from test import test_support
 from collections import namedtuple
+import pickle, cPickle, copy
 from collections import Hashable, Iterable, Iterator
 from collections import Sized, Container, Callable
 from collections import Set, MutableSet
 from collections import Mapping, MutableMapping
 from collections import Sequence, MutableSequence
 
+TestNT = namedtuple('TestNT', 'x y z')    # type used for pickle tests
 
 class TestNamedTuple(unittest.TestCase):
 
@@ -108,7 +110,7 @@ class TestNamedTuple(unittest.TestCase):
         self.assertEqual(Dot(1)._replace(d=999), (999,))
         self.assertEqual(Dot(1)._fields, ('d',))
 
-        n = 10000
+        n = 5000
         import string, random
         names = list(set(''.join([random.choice(string.ascii_letters)
                                   for j in range(10)]) for i in range(n)))
@@ -130,6 +132,23 @@ class TestNamedTuple(unittest.TestCase):
         self.assertEqual(b2, tuple(b2_expected))
         self.assertEqual(b._fields, tuple(names))
 
+    def test_pickle(self):
+        p = TestNT(x=10, y=20, z=30)
+        for module in pickle, cPickle:
+            loads = getattr(module, 'loads')
+            dumps = getattr(module, 'dumps')
+            for protocol in -1, 0, 1, 2:
+                q = loads(dumps(p, protocol))
+                self.assertEqual(p, q)
+                self.assertEqual(p._fields, q._fields)
+
+    def test_copy(self):
+        p = TestNT(x=10, y=20, z=30)
+        for copier in copy.copy, copy.deepcopy:
+            q = copier(p)
+            self.assertEqual(p, q)
+            self.assertEqual(p._fields, q._fields)
+
 class TestOneTrickPonyABCs(unittest.TestCase):
 
     def test_Hashable(self):