From e5545ade37d15ebf87fac626e602ac1ad9852be2 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Mon, 17 Aug 2020 06:27:07 -0700 Subject: [PATCH] Update _clone protocol for immutable rdatasets. --- dns/rdataset.py | 17 +++++++++++++++++ dns/set.py | 8 ++++++-- tests/test_rdataset.py | 9 +++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/dns/rdataset.py b/dns/rdataset.py index 1f372cd6..10cb252f 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -312,6 +312,8 @@ class ImmutableRdataset(Rdataset): """An immutable DNS rdataset.""" + _clone_class = Rdataset + def __init__(self, rdataset): """Create an immutable rdataset from the specified rdataset.""" @@ -352,6 +354,21 @@ class ImmutableRdataset(Rdataset): def clear(self): raise TypeError('immutable') + def __copy__(self): + return ImmutableRdataset(super().copy()) + + def copy(self): + return ImmutableRdataset(super().copy()) + + def union(self, other): + return ImmutableRdataset(super().union(other)) + + def intersection(self, other): + return ImmutableRdataset(super().intersection(other)) + + def difference(self, other): + return ImmutableRdataset(super().difference(other)) + def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, origin=None, relativize=True, relativize_to=None): diff --git a/dns/set.py b/dns/set.py index 0982d787..1fd4d0ae 100644 --- a/dns/set.py +++ b/dns/set.py @@ -84,9 +84,13 @@ class Set: subclasses. """ - cls = self.__class__ + if hasattr(self, '_clone_class'): + cls = self._clone_class + else: + cls = self.__class__ obj = cls.__new__(cls) - obj.items = self.items.copy() + obj.items = odict() + obj.items.update(self.items) return obj def __copy__(self): diff --git a/tests/test_rdataset.py b/tests/test_rdataset.py index 88b48400..4710e2a9 100644 --- a/tests/test_rdataset.py +++ b/tests/test_rdataset.py @@ -151,5 +151,14 @@ class ImmutableRdatasetTestCase(unittest.TestCase): with self.assertRaises(TypeError): irds.clear() + def test_cloning(self): + rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2') + rds1 = dns.rdataset.ImmutableRdataset(rds1) + rds2 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.2', '10.0.0.3') + rds2 = dns.rdataset.ImmutableRdataset(rds2) + expected = dns.rdataset.from_text('in', 'a', 300, '10.0.0.2') + intersection = rds1.intersection(rds2) + self.assertEqual(intersection, expected) + if __name__ == '__main__': unittest.main() -- 2.47.3