From 50752b18d423ef0755a4a08cd8ca531698c1165e Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Thu, 2 Apr 2020 11:04:08 -0700 Subject: [PATCH] Add dns.rdata.Rdata.replace() Now that Rdata instances are immutable, there needs to be a way to make a new Rdata based on an existing one. replace() creates a clone of the current Rdata, overriding fields with the specified parameters. --- dns/rdata.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_rdata.py | 17 +++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/dns/rdata.py b/dns/rdata.py index 4e5d36c3..ed565352 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -21,6 +21,7 @@ from importlib import import_module from io import BytesIO import base64 import binascii +import inspect import dns.exception import dns.name @@ -272,6 +273,39 @@ class Rdata(object): def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): raise NotImplementedError + def replace(self, **kwargs): + """ + Create a new Rdata instance based on the instance replace was + invoked on. It is possible to pass different parameters to + override the corresponding properties of the base Rdata. + + Any field specific to the Rdata type can be replaced, but the + *rdtype* and *rdclass* fields cannot. + + Returns an instance of the same Rdata subclass as *self*. + """ + + # Get the constructor parameters. + parameters = inspect.signature(self.__init__).parameters + + # Ensure that all of the arguments correspond to valid fields. + # Don't allow rdclass or rdtype to be changed, though. + for key in kwargs: + if key not in parameters: + raise AttributeError("'{}' object has no attribute '{}'" + .format(self.__class__.__name__, key)) + if key in ('rdclass', 'rdtype'): + raise AttributeError("Cannot overwrite '{}' attribute '{}'" + .format(self.__class__.__name__, key)) + + # Construct the parameter list. For each field, use the value in + # kwargs if present, and the current value otherwise. + args = (kwargs.get(key, getattr(self, key)) for key in parameters) + + # Create and return the new object. + return self.__class__(*args) + + class GenericRdata(Rdata): """Generic Rdata Class diff --git a/tests/test_rdata.py b/tests/test_rdata.py index 7c2c6a59..aed88041 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -47,5 +47,22 @@ class RdataTestCase(unittest.TestCase): dns.rdata.register_type(tests.ttxt_module, TTXTTWO, 'TTXTTWO') self.assertRaises(dns.rdata.RdatatypeExists, bad) + def test_replace(self): + a1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "1.2.3.4") + a2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, "2.3.4.5") + self.assertEqual(a1.replace(address="2.3.4.5"), a2) + + mx = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, + "10 foo.example") + name = dns.name.from_text("bar.example") + self.assertEqual(mx.replace(preference=20).preference, 20) + self.assertEqual(mx.replace(preference=20).exchange, mx.exchange) + self.assertEqual(mx.replace(exchange=name).exchange, name) + self.assertEqual(mx.replace(exchange=name).preference, mx.preference) + + for invalid_parameter in ("rdclass", "rdtype", "foo", "__class__"): + with self.assertRaises(AttributeError): + mx.replace(invalid_parameter=1) + if __name__ == '__main__': unittest.main() -- 2.47.3