]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
If using a frozen class with slots, add __getstate__ and __setstate__ to set the...
authorEric V. Smith <ericvsmith@users.noreply.github.com>
Sat, 1 May 2021 17:27:30 +0000 (13:27 -0400)
committerGitHub <noreply@github.com>
Sat, 1 May 2021 17:27:30 +0000 (13:27 -0400)
Lib/dataclasses.py
Lib/test/test_dataclasses.py

index 5e5716316f09541a5371d3087ecb3d00986fc583..363d0b66d208e4db3decd1fd6012cd63736f7eec 100644 (file)
@@ -1087,14 +1087,28 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
                            tuple(f.name for f in std_init_fields))
 
     if slots:
-        cls = _add_slots(cls)
+        cls = _add_slots(cls, frozen)
 
     abc.update_abstractmethods(cls)
 
     return cls
 
 
-def _add_slots(cls):
+# _dataclass_getstate and _dataclass_setstate are needed for pickling frozen
+# classes with slots.  These could be slighly more performant if we generated
+# the code instead of iterating over fields.  But that can be a project for
+# another day, if performance becomes an issue.
+def _dataclass_getstate(self):
+    return [getattr(self, f.name) for f in fields(self)]
+
+
+def _dataclass_setstate(self, state):
+    for field, value in zip(fields(self), state):
+        # use setattr because dataclass may be frozen
+        object.__setattr__(self, field.name, value)
+
+
+def _add_slots(cls, is_frozen):
     # Need to create a new class, since we can't set __slots__
     #  after a class has been created.
 
@@ -1120,6 +1134,11 @@ def _add_slots(cls):
     if qualname is not None:
         cls.__qualname__ = qualname
 
+    if is_frozen:
+        # Need this for pickling frozen classes with slots.
+        cls.__getstate__ = _dataclass_getstate
+        cls.__setstate__ = _dataclass_setstate
+
     return cls
 
 
index 2fa0ae0126bf88afa5a0910a932ae040258cd127..16ee4c7705d8cc8fa7fd60eeb6c50cc1ec73a4c8 100644 (file)
@@ -2833,6 +2833,19 @@ class TestSlots(unittest.TestCase):
         self.assertFalse(hasattr(A, "__slots__"))
         self.assertTrue(hasattr(B, "__slots__"))
 
+    # Can't be local to test_frozen_pickle.
+    @dataclass(frozen=True, slots=True)
+    class FrozenSlotsClass:
+        foo: str
+        bar: int
+
+    def test_frozen_pickle(self):
+        # bpo-43999
+
+        assert self.FrozenSlotsClass.__slots__ == ("foo", "bar")
+        p = pickle.dumps(self.FrozenSlotsClass("a", 1))
+        assert pickle.loads(p) == self.FrozenSlotsClass("a", 1)
+
 
 class TestDescriptors(unittest.TestCase):
     def test_set_name(self):