]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix rdata pickling.
authorBob Halley <halley@dnspython.org>
Thu, 18 Jun 2020 23:09:47 +0000 (16:09 -0700)
committerBob Halley <halley@dnspython.org>
Thu, 18 Jun 2020 23:09:47 +0000 (16:09 -0700)
Coverage testing showed that while rdatas would pickle and unpickle
apparently successfully, in fact only the slots from the deepest class
in the inheritance chain would be restored.  So, e.g., a restored A rdata
would have an address attribute but no rdclass or rdtype attributes,
and so things like rdata comparison would break.

This change preserves the whole set of slots, from all ancestors as
well as the object, as a dictionary.

dns/rdata.py

index cc44d00243e63ed6ae9c59abce71d63ddc7c1c8e..d3fd6a6cbdb3ec1453cc69425289364c80db0f85 100644 (file)
@@ -128,10 +128,21 @@ class Rdata:
         raise TypeError("object doesn't support attribute deletion")
 
     def __getstate__(self):
-        return tuple(getattr(self, slot) for slot in self.__slots__)
+        # We used to try to do a tuple of all slots here, but it
+        # doesn't work as self._all_slots isn't available at
+        # __setstate__() time.  Before that we tried to store a tuple
+        # of __slots__, but that didn't work as it didn't store the
+        # slots defined by ancestors.  This older way didn't fail
+        # outright, but ended up with partially broken objects, e.g.
+        # if you unpickled an A RR it wouldn't have rdclass and rdtype
+        # attributes, and would compare badly.
+        state = {}
+        for slot in self._all_slots:
+            state[slot] = getattr(self, slot)
+        return state
 
     def __setstate__(self, state):
-        for slot, val in zip(self.__slots__, state):
+        for slot, val in state.items():
             object.__setattr__(self, slot, val)
 
     def covers(self):
@@ -367,6 +378,12 @@ class GenericRdata(Rdata):
 _rdata_classes = {}
 _module_prefix = 'dns.rdtypes'
 
+def _get_all_slots(cls):
+    all_slots = []
+    for scls in cls.__mro__:
+        all_slots.extend(getattr(scls, '__slots__', []))
+    return all_slots
+
 def get_rdata_class(rdclass, rdtype):
     cls = _rdata_classes.get((rdclass, rdtype))
     if not cls:
@@ -379,12 +396,17 @@ def get_rdata_class(rdclass, rdtype):
                 mod = import_module('.'.join([_module_prefix,
                                               rdclass_text, rdtype_text]))
                 cls = getattr(mod, rdtype_text)
+                # initialize cls._all_slots to save effort pickling, as
+                # we don't want to compute the list of all ancestor
+                # slots every __getstate__().
+                cls._all_slots = _get_all_slots(cls)
                 _rdata_classes[(rdclass, rdtype)] = cls
             except ImportError:
                 try:
                     mod = import_module('.'.join([_module_prefix,
                                                   'ANY', rdtype_text]))
                     cls = getattr(mod, rdtype_text)
+                    cls._all_slots = _get_all_slots(cls)
                     _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls
                     _rdata_classes[(rdclass, rdtype)] = cls
                 except ImportError: