]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Update _clone protocol for immutable rdatasets.
authorBob Halley <halley@dnspython.org>
Mon, 17 Aug 2020 13:27:07 +0000 (06:27 -0700)
committerBob Halley <halley@dnspython.org>
Mon, 17 Aug 2020 13:27:07 +0000 (06:27 -0700)
dns/rdataset.py
dns/set.py
tests/test_rdataset.py

index 1f372cd61242a80de8267d2d3fa1385cb135b402..10cb252fa9864dce254186ab6947f59f175be7bb 100644 (file)
@@ -312,6 +312,8 @@ class ImmutableRdataset(Rdataset):
 
     """An immutable DNS rdataset."""
 
+    _clone_class = Rdataset
+
     def __init__(self, rdataset):
         """Create an immutable rdataset from the specified rdataset."""
 
@@ -352,6 +354,21 @@ class ImmutableRdataset(Rdataset):
     def clear(self):
         raise TypeError('immutable')
 
+    def __copy__(self):
+        return ImmutableRdataset(super().copy())
+
+    def copy(self):
+        return ImmutableRdataset(super().copy())
+
+    def union(self, other):
+        return ImmutableRdataset(super().union(other))
+
+    def intersection(self, other):
+        return ImmutableRdataset(super().intersection(other))
+
+    def difference(self, other):
+        return ImmutableRdataset(super().difference(other))
+
 
 def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
                    origin=None, relativize=True, relativize_to=None):
index 0982d787d4582bf371456b8cc5e03ffd27f4b201..1fd4d0ae79919716e17a3225025a7255aa1ea733 100644 (file)
@@ -84,9 +84,13 @@ class Set:
         subclasses.
         """
 
-        cls = self.__class__
+        if hasattr(self, '_clone_class'):
+            cls = self._clone_class
+        else:
+            cls = self.__class__
         obj = cls.__new__(cls)
-        obj.items = self.items.copy()
+        obj.items = odict()
+        obj.items.update(self.items)
         return obj
 
     def __copy__(self):
index 88b48400bd2cf19a571c83d2cc133330568ba76f..4710e2a9b11d59b3bf9d36fd0227d54ba66b7cde 100644 (file)
@@ -151,5 +151,14 @@ class ImmutableRdatasetTestCase(unittest.TestCase):
         with self.assertRaises(TypeError):
             irds.clear()
 
+    def test_cloning(self):
+        rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2')
+        rds1 = dns.rdataset.ImmutableRdataset(rds1)
+        rds2 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.2', '10.0.0.3')
+        rds2 = dns.rdataset.ImmutableRdataset(rds2)
+        expected = dns.rdataset.from_text('in', 'a', 300, '10.0.0.2')
+        intersection = rds1.intersection(rds2)
+        self.assertEqual(intersection, expected)
+
 if __name__ == '__main__':
     unittest.main()