From: Bob Halley Date: Sat, 8 Aug 2020 14:19:23 +0000 (-0700) Subject: Make SVCB and HTTPS immutable. X-Git-Tag: v2.1.0rc1~104 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=89d7e549b07646cc1cd1a84b0fb943722e2700fb;p=thirdparty%2Fdnspython.git Make SVCB and HTTPS immutable. --- diff --git a/dns/rdtypes/svcbbase.py b/dns/rdtypes/svcbbase.py index 22011c8b..212b5cdd 100644 --- a/dns/rdtypes/svcbbase.py +++ b/dns/rdtypes/svcbbase.py @@ -7,6 +7,7 @@ import struct import dns.enum import dns.exception +import dns.immutable import dns.ipv4 import dns.ipv6 import dns.name @@ -140,6 +141,14 @@ def _unescape(value, list_mode=False): class Param: """Abstract base class for SVCB parameters""" + def __setattr__(self, name, value): + # Params are immutable + raise TypeError("object doesn't support attribute assignment") + + def __delattr__(self, name): + # Params are immutable + raise TypeError("object doesn't support attribute deletion") + @classmethod def emptiness(cls): return Emptiness.NEVER @@ -148,7 +157,7 @@ class GenericParam(Param): """Generic SVCB parameter """ def __init__(self, value): - self.value = value + object.__setattr__(self, 'value', value) @classmethod def emptiness(cls): @@ -179,14 +188,16 @@ class GenericParam(Param): class MandatoryParam(Param): def __init__(self, keys): # check for duplicates - self.keys = sorted([_validate_key(key)[0] for key in keys]) + keys = sorted([_validate_key(key)[0] for key in keys]) prior_k = None - for k in self.keys: + for k in keys: if k == prior_k: raise ValueError(f'duplicate key {k}') prior_k = k if k == ParamKey.MANDATORY: raise ValueError('listed the mandatory key as mandatory') + keys = dns.immutable.constify(keys) + object.__setattr__(self, 'keys', keys) @classmethod def from_value(cls, value): @@ -219,7 +230,7 @@ class ALPNParam(Param): raise dns.exception.FormError('empty ALPN') if len(id) > 255: raise ValueError('ALPN id too long') - self.ids = ids + object.__setattr__(self, 'ids', dns.immutable.constify(ids)) @classmethod def from_value(cls, value): @@ -272,13 +283,13 @@ class NoDefaultALPNParam(Param): class PortParam(Param): def __init__(self, port): - self.port = port + if port < 0 or port > 65535: + raise ValueError('port out-of-range') + object.__setattr__(self, 'port', port) @classmethod def from_value(cls, value): value = int(value) - if value < 0 or value > 65535: - raise ValueError('port out-of-range') return cls(value) def to_text(self): @@ -295,14 +306,14 @@ class PortParam(Param): class IPv4HintParam(Param): def __init__(self, addresses): - self.addresses = addresses + for address in addresses: + # check validity + dns.ipv4.inet_aton(address) + object.__setattr__(self, 'addresses', dns.immutable.constify(addresses)) @classmethod def from_value(cls, value): addresses = value.split(',') - for address in addresses: - # check validity - dns.ipv4.inet_aton(address) return cls(addresses) def to_text(self): @@ -323,14 +334,14 @@ class IPv4HintParam(Param): class IPv6HintParam(Param): def __init__(self, addresses): - self.addresses = addresses + for address in addresses: + # check validity + dns.ipv6.inet_aton(address) + object.__setattr__(self, 'addresses', dns.immutable.constify(addresses)) @classmethod def from_value(cls, value): addresses = value.split(',') - for address in addresses: - # check validity - dns.ipv6.inet_aton(address) return cls(addresses) def to_text(self): @@ -351,7 +362,7 @@ class IPv6HintParam(Param): class ECHConfigParam(Param): def __init__(self, echconfig): - self.echconfig = echconfig + object.__setattr__(self, 'echconfig', echconfig) @classmethod def from_value(cls, value): @@ -414,7 +425,7 @@ class SVCBBase(dns.rdata.Rdata): super().__init__(rdclass, rdtype) object.__setattr__(self, 'priority', priority) object.__setattr__(self, 'target', target) - object.__setattr__(self, 'params', params) + object.__setattr__(self, 'params', dns.immutable.constify(params)) # Make sure any paramater listed as mandatory is present in the # record. mandatory = params.get(ParamKey.MANDATORY) diff --git a/tests/test_svcb.py b/tests/test_svcb.py index a11a13bc..9d9bda2f 100644 --- a/tests/test_svcb.py +++ b/tests/test_svcb.py @@ -272,3 +272,14 @@ class SVCBTestCase(unittest.TestCase): wire = bytes.fromhex('0000000000000400010003') with self.assertRaises(dns.exception.FormError): dns.rdata.from_wire('in', 'svcb', wire, 0, len(wire)) + + def test_immutability(self): + alpn = dns.rdtypes.svcbbase.ALPNParam.from_value(['h2', 'h3']) + with self.assertRaises(TypeError): + alpn.ids[0] = 'foo' + with self.assertRaises(TypeError): + del alpn.ids[0] + with self.assertRaises(TypeError): + alpn.ids = 'foo' + with self.assertRaises(TypeError): + del alpn.ids