From e1f6a890688d342458fd46c65eebf4885ec800d4 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Thu, 18 Jun 2020 16:09:47 -0700 Subject: [PATCH] Fix rdata pickling. Coverage testing showed that while rdatas would pickle and unpickle apparently successfully, in fact only the slots from the deepest class in the inheritance chain would be restored. So, e.g., a restored A rdata would have an address attribute but no rdclass or rdtype attributes, and so things like rdata comparison would break. This change preserves the whole set of slots, from all ancestors as well as the object, as a dictionary. --- dns/rdata.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/dns/rdata.py b/dns/rdata.py index cc44d002..d3fd6a6c 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -128,10 +128,21 @@ class Rdata: raise TypeError("object doesn't support attribute deletion") def __getstate__(self): - return tuple(getattr(self, slot) for slot in self.__slots__) + # We used to try to do a tuple of all slots here, but it + # doesn't work as self._all_slots isn't available at + # __setstate__() time. Before that we tried to store a tuple + # of __slots__, but that didn't work as it didn't store the + # slots defined by ancestors. This older way didn't fail + # outright, but ended up with partially broken objects, e.g. + # if you unpickled an A RR it wouldn't have rdclass and rdtype + # attributes, and would compare badly. + state = {} + for slot in self._all_slots: + state[slot] = getattr(self, slot) + return state def __setstate__(self, state): - for slot, val in zip(self.__slots__, state): + for slot, val in state.items(): object.__setattr__(self, slot, val) def covers(self): @@ -367,6 +378,12 @@ class GenericRdata(Rdata): _rdata_classes = {} _module_prefix = 'dns.rdtypes' +def _get_all_slots(cls): + all_slots = [] + for scls in cls.__mro__: + all_slots.extend(getattr(scls, '__slots__', [])) + return all_slots + def get_rdata_class(rdclass, rdtype): cls = _rdata_classes.get((rdclass, rdtype)) if not cls: @@ -379,12 +396,17 @@ def get_rdata_class(rdclass, rdtype): mod = import_module('.'.join([_module_prefix, rdclass_text, rdtype_text])) cls = getattr(mod, rdtype_text) + # initialize cls._all_slots to save effort pickling, as + # we don't want to compute the list of all ancestor + # slots every __getstate__(). + cls._all_slots = _get_all_slots(cls) _rdata_classes[(rdclass, rdtype)] = cls except ImportError: try: mod = import_module('.'.join([_module_prefix, 'ANY', rdtype_text])) cls = getattr(mod, rdtype_text) + cls._all_slots = _get_all_slots(cls) _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls _rdata_classes[(rdclass, rdtype)] = cls except ImportError: -- 2.47.3