]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add dns.rdata.Rdata.replace() 443/head
authorBrian Wellington <bwelling@xbill.org>
Thu, 2 Apr 2020 18:04:08 +0000 (11:04 -0700)
committerBrian Wellington <bwelling@xbill.org>
Thu, 2 Apr 2020 18:04:08 +0000 (11:04 -0700)
Now that Rdata instances are immutable, there needs to be a way to make
a new Rdata based on an existing one.  replace() creates a clone of the
current Rdata, overriding fields with the specified parameters.

dns/rdata.py
tests/test_rdata.py

index 4e5d36c33def397cf98ac366ea091d5db2bb86f8..ed56535232bf4e463cfa77a833e7615ed7a54f1c 100644 (file)
@@ -21,6 +21,7 @@ from importlib import import_module
 from io import BytesIO
 import base64
 import binascii
+import inspect
 
 import dns.exception
 import dns.name
@@ -272,6 +273,39 @@ class Rdata(object):
     def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
         raise NotImplementedError
 
+    def replace(self, **kwargs):
+        """
+        Create a new Rdata instance based on the instance replace was
+        invoked on. It is possible to pass different parameters to
+        override the corresponding properties of the base Rdata.
+
+        Any field specific to the Rdata type can be replaced, but the
+        *rdtype* and *rdclass* fields cannot.
+
+        Returns an instance of the same Rdata subclass as *self*.
+        """
+
+        # Get the constructor parameters.
+        parameters = inspect.signature(self.__init__).parameters
+
+        # Ensure that all of the arguments correspond to valid fields.
+        # Don't allow rdclass or rdtype to be changed, though.
+        for key in kwargs:
+            if key not in parameters:
+                raise AttributeError("'{}' object has no attribute '{}'"
+                                     .format(self.__class__.__name__, key))
+            if key in ('rdclass', 'rdtype'):
+                raise AttributeError("Cannot overwrite '{}' attribute '{}'"
+                                     .format(self.__class__.__name__, key))
+
+        # Construct the parameter list.  For each field, use the value in
+        # kwargs if present, and the current value otherwise.
+        args = (kwargs.get(key, getattr(self, key)) for key in parameters)
+
+        # Create and return the new object.
+        return self.__class__(*args)
+
+
 class GenericRdata(Rdata):
 
     """Generic Rdata Class
index 7c2c6a5966b41445551f919993197a09489142b9..aed88041f3f70dc35d4aa8ee1f0c29a3c4a1c915 100644 (file)
@@ -47,5 +47,22 @@ class RdataTestCase(unittest.TestCase):
             dns.rdata.register_type(tests.ttxt_module, TTXTTWO, 'TTXTTWO')
         self.assertRaises(dns.rdata.RdatatypeExists, bad)
 
+    def test_replace(self):
+        a1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "1.2.3.4")
+        a2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "2.3.4.5")
+        self.assertEqual(a1.replace(address="2.3.4.5"), a2)
+
+        mx = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX,
+                                  "10 foo.example")
+        name = dns.name.from_text("bar.example")
+        self.assertEqual(mx.replace(preference=20).preference, 20)
+        self.assertEqual(mx.replace(preference=20).exchange, mx.exchange)
+        self.assertEqual(mx.replace(exchange=name).exchange, name)
+        self.assertEqual(mx.replace(exchange=name).preference, mx.preference)
+
+        for invalid_parameter in ("rdclass", "rdtype", "foo", "__class__"):
+            with self.assertRaises(AttributeError):
+                mx.replace(invalid_parameter=1)
+
 if __name__ == '__main__':
     unittest.main()