]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
The cached _all_slots technique of the earlier pickle fix didn't work for
authorBob Halley <halley@dnspython.org>
Tue, 23 Jun 2020 01:07:25 +0000 (18:07 -0700)
committerBob Halley <halley@dnspython.org>
Tue, 23 Jun 2020 01:08:29 +0000 (18:08 -0700)
GenericRdata or for directly imported types.  This fix just computes the
all slots relatively efficiently every __getstate__().

dns/rdata.py

index d3fd6a6cbdb3ec1453cc69425289364c80db0f85..2de1763a4d703dba7742a7436a6cb324d113a8dd 100644 (file)
@@ -22,6 +22,7 @@ import base64
 import binascii
 import io
 import inspect
+import itertools
 
 import dns.exception
 import dns.name
@@ -127,6 +128,10 @@ class Rdata:
         # Rdatas are immutable
         raise TypeError("object doesn't support attribute deletion")
 
+    def _get_all_slots(self):
+        return itertools.chain.from_iterable(getattr(cls, '__slots__', [])
+                                             for cls in self.__class__.__mro__)
+
     def __getstate__(self):
         # We used to try to do a tuple of all slots here, but it
         # doesn't work as self._all_slots isn't available at
@@ -137,7 +142,7 @@ class Rdata:
         # if you unpickled an A RR it wouldn't have rdclass and rdtype
         # attributes, and would compare badly.
         state = {}
-        for slot in self._all_slots:
+        for slot in self._get_all_slots():
             state[slot] = getattr(self, slot)
         return state
 
@@ -378,12 +383,6 @@ class GenericRdata(Rdata):
 _rdata_classes = {}
 _module_prefix = 'dns.rdtypes'
 
-def _get_all_slots(cls):
-    all_slots = []
-    for scls in cls.__mro__:
-        all_slots.extend(getattr(scls, '__slots__', []))
-    return all_slots
-
 def get_rdata_class(rdclass, rdtype):
     cls = _rdata_classes.get((rdclass, rdtype))
     if not cls:
@@ -396,17 +395,12 @@ def get_rdata_class(rdclass, rdtype):
                 mod = import_module('.'.join([_module_prefix,
                                               rdclass_text, rdtype_text]))
                 cls = getattr(mod, rdtype_text)
-                # initialize cls._all_slots to save effort pickling, as
-                # we don't want to compute the list of all ancestor
-                # slots every __getstate__().
-                cls._all_slots = _get_all_slots(cls)
                 _rdata_classes[(rdclass, rdtype)] = cls
             except ImportError:
                 try:
                     mod = import_module('.'.join([_module_prefix,
                                                   'ANY', rdtype_text]))
                     cls = getattr(mod, rdtype_text)
-                    cls._all_slots = _get_all_slots(cls)
                     _rdata_classes[(dns.rdataclass.ANY, rdtype)] = cls
                     _rdata_classes[(rdclass, rdtype)] = cls
                 except ImportError: