]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Make SVCB and HTTPS immutable.
authorBob Halley <halley@dnspython.org>
Sat, 8 Aug 2020 14:19:23 +0000 (07:19 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 8 Aug 2020 14:19:23 +0000 (07:19 -0700)
dns/rdtypes/svcbbase.py
tests/test_svcb.py

index 22011c8bbdc5b0a5cf2afa520735c0bcc022e7c0..212b5cdd068a5c46a045e17199ae44385ed9d9cf 100644 (file)
@@ -7,6 +7,7 @@ import struct
 
 import dns.enum
 import dns.exception
+import dns.immutable
 import dns.ipv4
 import dns.ipv6
 import dns.name
@@ -140,6 +141,14 @@ def _unescape(value, list_mode=False):
 class Param:
     """Abstract base class for SVCB parameters"""
 
+    def __setattr__(self, name, value):
+        # Params are immutable
+        raise TypeError("object doesn't support attribute assignment")
+
+    def __delattr__(self, name):
+        # Params are immutable
+        raise TypeError("object doesn't support attribute deletion")
+
     @classmethod
     def emptiness(cls):
         return Emptiness.NEVER
@@ -148,7 +157,7 @@ class GenericParam(Param):
     """Generic SVCB parameter
     """
     def __init__(self, value):
-        self.value = value
+        object.__setattr__(self, 'value', value)
 
     @classmethod
     def emptiness(cls):
@@ -179,14 +188,16 @@ class GenericParam(Param):
 class MandatoryParam(Param):
     def __init__(self, keys):
         # check for duplicates
-        self.keys = sorted([_validate_key(key)[0] for key in keys])
+        keys = sorted([_validate_key(key)[0] for key in keys])
         prior_k = None
-        for k in self.keys:
+        for k in keys:
             if k == prior_k:
                 raise ValueError(f'duplicate key {k}')
             prior_k = k
             if k == ParamKey.MANDATORY:
                 raise ValueError('listed the mandatory key as mandatory')
+        keys = dns.immutable.constify(keys)
+        object.__setattr__(self, 'keys', keys)
 
     @classmethod
     def from_value(cls, value):
@@ -219,7 +230,7 @@ class ALPNParam(Param):
                 raise dns.exception.FormError('empty ALPN')
             if len(id) > 255:
                 raise ValueError('ALPN id too long')
-        self.ids = ids
+        object.__setattr__(self, 'ids', dns.immutable.constify(ids))
 
     @classmethod
     def from_value(cls, value):
@@ -272,13 +283,13 @@ class NoDefaultALPNParam(Param):
 
 class PortParam(Param):
     def __init__(self, port):
-        self.port = port
+        if port < 0 or port > 65535:
+            raise ValueError('port out-of-range')
+        object.__setattr__(self, 'port', port)
 
     @classmethod
     def from_value(cls, value):
         value = int(value)
-        if value < 0 or value > 65535:
-            raise ValueError('port out-of-range')
         return cls(value)
 
     def to_text(self):
@@ -295,14 +306,14 @@ class PortParam(Param):
 
 class IPv4HintParam(Param):
     def __init__(self, addresses):
-        self.addresses = addresses
+        for address in addresses:
+            # check validity
+            dns.ipv4.inet_aton(address)
+        object.__setattr__(self, 'addresses', dns.immutable.constify(addresses))
 
     @classmethod
     def from_value(cls, value):
         addresses = value.split(',')
-        for address in addresses:
-            # check validity
-            dns.ipv4.inet_aton(address)
         return cls(addresses)
 
     def to_text(self):
@@ -323,14 +334,14 @@ class IPv4HintParam(Param):
 
 class IPv6HintParam(Param):
     def __init__(self, addresses):
-        self.addresses = addresses
+        for address in addresses:
+            # check validity
+            dns.ipv6.inet_aton(address)
+        object.__setattr__(self, 'addresses', dns.immutable.constify(addresses))
 
     @classmethod
     def from_value(cls, value):
         addresses = value.split(',')
-        for address in addresses:
-            # check validity
-            dns.ipv6.inet_aton(address)
         return cls(addresses)
 
     def to_text(self):
@@ -351,7 +362,7 @@ class IPv6HintParam(Param):
 
 class ECHConfigParam(Param):
     def __init__(self, echconfig):
-        self.echconfig = echconfig
+        object.__setattr__(self, 'echconfig', echconfig)
 
     @classmethod
     def from_value(cls, value):
@@ -414,7 +425,7 @@ class SVCBBase(dns.rdata.Rdata):
         super().__init__(rdclass, rdtype)
         object.__setattr__(self, 'priority', priority)
         object.__setattr__(self, 'target', target)
-        object.__setattr__(self, 'params', params)
+        object.__setattr__(self, 'params', dns.immutable.constify(params))
         # Make sure any paramater listed as mandatory is present in the
         # record.
         mandatory = params.get(ParamKey.MANDATORY)
index a11a13bcb328416284e0143e372a475eab7a557d..9d9bda2fde3509cad6bf319dfccd15b24c18026d 100644 (file)
@@ -272,3 +272,14 @@ class SVCBTestCase(unittest.TestCase):
         wire = bytes.fromhex('0000000000000400010003')
         with self.assertRaises(dns.exception.FormError):
             dns.rdata.from_wire('in', 'svcb', wire, 0, len(wire))
+
+    def test_immutability(self):
+        alpn = dns.rdtypes.svcbbase.ALPNParam.from_value(['h2', 'h3'])
+        with self.assertRaises(TypeError):
+            alpn.ids[0] = 'foo'
+        with self.assertRaises(TypeError):
+            del alpn.ids[0]
+        with self.assertRaises(TypeError):
+            alpn.ids = 'foo'
+        with self.assertRaises(TypeError):
+            del alpn.ids