From: Bob Halley Date: Mon, 17 Aug 2020 13:27:07 +0000 (-0700) Subject: Update _clone protocol for immutable rdatasets. X-Git-Tag: v2.1.0rc1~72 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e5545ade37d15ebf87fac626e602ac1ad9852be2;p=thirdparty%2Fdnspython.git Update _clone protocol for immutable rdatasets. --- diff --git a/dns/rdataset.py b/dns/rdataset.py index 1f372cd6..10cb252f 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -312,6 +312,8 @@ class ImmutableRdataset(Rdataset): """An immutable DNS rdataset.""" + _clone_class = Rdataset + def __init__(self, rdataset): """Create an immutable rdataset from the specified rdataset.""" @@ -352,6 +354,21 @@ class ImmutableRdataset(Rdataset): def clear(self): raise TypeError('immutable') + def __copy__(self): + return ImmutableRdataset(super().copy()) + + def copy(self): + return ImmutableRdataset(super().copy()) + + def union(self, other): + return ImmutableRdataset(super().union(other)) + + def intersection(self, other): + return ImmutableRdataset(super().intersection(other)) + + def difference(self, other): + return ImmutableRdataset(super().difference(other)) + def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, origin=None, relativize=True, relativize_to=None): diff --git a/dns/set.py b/dns/set.py index 0982d787..1fd4d0ae 100644 --- a/dns/set.py +++ b/dns/set.py @@ -84,9 +84,13 @@ class Set: subclasses. """ - cls = self.__class__ + if hasattr(self, '_clone_class'): + cls = self._clone_class + else: + cls = self.__class__ obj = cls.__new__(cls) - obj.items = self.items.copy() + obj.items = odict() + obj.items.update(self.items) return obj def __copy__(self): diff --git a/tests/test_rdataset.py b/tests/test_rdataset.py index 88b48400..4710e2a9 100644 --- a/tests/test_rdataset.py +++ b/tests/test_rdataset.py @@ -151,5 +151,14 @@ class ImmutableRdatasetTestCase(unittest.TestCase): with self.assertRaises(TypeError): irds.clear() + def test_cloning(self): + rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2') + rds1 = dns.rdataset.ImmutableRdataset(rds1) + rds2 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.2', '10.0.0.3') + rds2 = dns.rdataset.ImmutableRdataset(rds2) + expected = dns.rdataset.from_text('in', 'a', 300, '10.0.0.2') + intersection = rds1.intersection(rds2) + self.assertEqual(intersection, expected) + if __name__ == '__main__': unittest.main()