]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
more constructor checking work
authorBob Halley <halley@dnspython.org>
Wed, 26 Aug 2020 13:20:32 +0000 (06:20 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 26 Aug 2020 13:20:32 +0000 (06:20 -0700)
dns/rdata.py
dns/rdtypes/ANY/CSYNC.py
dns/rdtypes/ANY/HIP.py
dns/rdtypes/ANY/NSEC.py
dns/rdtypes/ANY/NSEC3.py
dns/rdtypes/ANY/OPT.py
dns/rdtypes/svcbbase.py
dns/rdtypes/txtbase.py

index 3a4995486a29753134a3ce783684bf8a25bd6714..60a0d49e5ac2a2606453771e0b2db0d63371120a 100644 (file)
@@ -65,6 +65,7 @@ def _base64ify(data, chunksize=_chunksize):
 
     return _wordbreak(base64.b64encode(data), chunksize)
 
+
 __escaped = b'"\\'
 
 def _escapify(qstring):
@@ -347,7 +348,7 @@ class Rdata:
         return dns.rdatatype.RdataType.make(value)
 
     @classmethod
-    def _as_bytes(cls, value, encode=False, max_length=None):
+    def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True):
         if encode and isinstance(value, str):
             value = value.encode()
         elif isinstance(value, bytearray):
@@ -356,6 +357,8 @@ class Rdata:
             raise ValueError('not bytes')
         if max_length is not None and len(value) > max_length:
             raise ValueError('too long')
+        if not empty_ok and len(value) == 0:
+            raise ValueError('empty bytes not allowed')
         return value
 
     @classmethod
@@ -449,6 +452,17 @@ class Rdata:
         else:
             raise ValueError('not a TTL')
 
+    @classmethod
+    def _as_tuple(cls, value, as_value):
+        try:
+            # For user convenience, if value is a singleton of the list
+            # element type, wrap it in a tuple.
+            return (as_value(value),)
+        except Exception:
+            # Otherwise, check each element of the iterable *value*
+            # against *as_value*.
+            return tuple(as_value(v) for v in value)
+
 
 class GenericRdata(Rdata):
 
index 268fd88ee27a2a24868d0eda7c22630292febb14..0a7925aa178331001693001e2c2e50c3f145926d 100644 (file)
@@ -41,11 +41,9 @@ class CSYNC(dns.rdata.Rdata):
         super().__init__(rdclass, rdtype)
         self.serial = self._as_uint32(serial)
         self.flags = self._as_uint16(flags)
-        if isinstance(windows, Bitmap):
-            bitmap = windows
-        else:
-            bitmap = Bitmap(windows)
-        self.windows = tuple(bitmap.windows)
+        if not isinstance(windows, Bitmap):
+            windows = Bitmap(windows)
+        self.windows = windows.windows
 
     def to_text(self, origin=None, relativize=True, **kw):
         text = Bitmap(self.windows).to_text()
index 2901fddd25bb17a9f707e4e3d81ac52497fe3018..e887359b78856d38caba88f3d1886fcfbb8f876d 100644 (file)
@@ -39,7 +39,7 @@ class HIP(dns.rdata.Rdata):
         self.hit = self._as_bytes(hit, True, 255)
         self.algorithm = self._as_uint8(algorithm)
         self.key = self._as_bytes(key, True)
-        self.servers = tuple([self._as_name(s) for s in servers])
+        self.servers = self._as_tuple(servers, self._as_name)
 
     def to_text(self, origin=None, relativize=True, **kw):
         hit = binascii.hexlify(self.hit).decode()
index 62e449993bd653a7ebeffd3d6661eb137487b049..c3621ac4b06f1ebb3c7dd773ae2a81d93ed766c8 100644 (file)
@@ -38,11 +38,9 @@ class NSEC(dns.rdata.Rdata):
     def __init__(self, rdclass, rdtype, next, windows):
         super().__init__(rdclass, rdtype)
         self.next = self._as_name(next)
-        if isinstance(windows, Bitmap):
-            bitmap = windows
-        else:
-            bitmap = Bitmap(windows)
-        self.windows = tuple(bitmap.windows)
+        if not isinstance(windows, Bitmap):
+            windows = Bitmap(windows)
+        self.windows = windows.windows
 
     def to_text(self, origin=None, relativize=True, **kw):
         next = self.next.choose_relativity(origin, relativize)
index 48a76eebf4a699ecfd38f5e579f8677a2be250de..8089f68096756ec17cef37c32ba4073db14ed5ee 100644 (file)
@@ -58,11 +58,9 @@ class NSEC3(dns.rdata.Rdata):
         self.iterations = self._as_uint16(iterations)
         self.salt = self._as_bytes(salt, True, 255)
         self.next = self._as_bytes(next, True, 255)
-        if isinstance(windows, Bitmap):
-            bitmap = windows
-        else:
-            bitmap = Bitmap(windows)
-        self.windows = tuple(bitmap.windows)
+        if not isinstance(windows, Bitmap):
+            windows = Bitmap(windows)
+        self.windows = windows.windows
 
     def to_text(self, origin=None, relativize=True, **kw):
         next = base64.b32encode(self.next).translate(
index cac54bd8aef5b4fd0b16ac356acb5a9dcd5bf100..69b8fe75d51a1acdb2393e31dfe5e58db7f17361 100644 (file)
@@ -45,10 +45,11 @@ class OPT(dns.rdata.Rdata):
         """
 
         super().__init__(rdclass, rdtype)
-        for option in options:
+        def as_option(option):
             if not isinstance(option, dns.edns.Option):
                 raise ValueError('option is not a dns.edns.option')
-        self.options = tuple(options)
+            return option
+        self.options = self._as_tuple(options, as_option)
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
         for opt in self.options:
index ae2a949cb3a00bace989b9d83e22d4fc97e8c0ea..4a3a2179cfb1ea1dbb6dedea97c51f4895aa6b73 100644 (file)
@@ -225,11 +225,8 @@ class MandatoryParam(Param):
 @dns.immutable.immutable
 class ALPNParam(Param):
     def __init__(self, ids):
-        for id in ids:
-            id = dns.rdata.Rdata._as_bytes(id, True, 255)
-            if len(id) == 0:
-                raise dns.exception.FormError('empty ALPN')
-        self.ids = tuple(ids)
+        self.ids = dns.rdata.Rdata._as_tuple(
+            ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False))
 
     @classmethod
     def from_value(cls, value):
@@ -307,10 +304,8 @@ class PortParam(Param):
 @dns.immutable.immutable
 class IPv4HintParam(Param):
     def __init__(self, addresses):
-        for address in addresses:
-            # check validity
-            dns.ipv4.inet_aton(address)
-        self.addresses = tuple(addresses)
+        self.addresses = dns.rdata.Rdata._as_tuple(
+            addresses, dns.rdata.Rdata._as_ipv4_address)
 
     @classmethod
     def from_value(cls, value):
@@ -336,10 +331,8 @@ class IPv4HintParam(Param):
 @dns.immutable.immutable
 class IPv6HintParam(Param):
     def __init__(self, addresses):
-        for address in addresses:
-            # check validity
-            dns.ipv6.inet_aton(address)
-        self.addresses = tuple(addresses)
+        self.addresses = dns.rdata.Rdata._as_tuple(
+            addresses, dns.rdata.Rdata._as_ipv6_address)
 
     @classmethod
     def from_value(cls, value):
@@ -429,7 +422,11 @@ class SVCBBase(dns.rdata.Rdata):
         super().__init__(rdclass, rdtype)
         self.priority = self._as_uint16(priority)
         self.target = self._as_name(target)
-        self.params = dns.immutable.constify(params)
+        for k, v in params.items():
+            k = ParamKey.make(k)
+            if not isinstance(v, Param) and v is not None:
+                raise ValueError("not a Param")
+        self.params = dns.immutable.Dict(params)
         # Make sure any paramater listed as mandatory is present in the
         # record.
         mandatory = params.get(ParamKey.MANDATORY)
index 37bf96170f67abbb5bb51ef795b05b2ebbace4ff..68071ee0abda00da55edb71162841605edad756e 100644 (file)
@@ -42,13 +42,8 @@ class TXTBase(dns.rdata.Rdata):
         *strings*, a tuple of ``bytes``
         """
         super().__init__(rdclass, rdtype)
-        if isinstance(strings, (bytes, str)):
-            strings = (strings,)
-        encoded_strings = []
-        for string in strings:
-            string = self._as_bytes(string, True, 255)
-            encoded_strings.append(string)
-        self.strings = tuple(encoded_strings)
+        self.strings = self._as_tuple(strings,
+                                      lambda x: self._as_bytes(x, True, 255))
 
     def to_text(self, origin=None, relativize=True, **kw):
         txt = ''