From: Bob Halley Date: Wed, 26 Aug 2020 13:20:32 +0000 (-0700) Subject: more constructor checking work X-Git-Tag: v2.1.0rc1~43 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b4826fb2d890b6ab6936425ed0cb89379b5d5bd9;p=thirdparty%2Fdnspython.git more constructor checking work --- diff --git a/dns/rdata.py b/dns/rdata.py index 3a499548..60a0d49e 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -65,6 +65,7 @@ def _base64ify(data, chunksize=_chunksize): return _wordbreak(base64.b64encode(data), chunksize) + __escaped = b'"\\' def _escapify(qstring): @@ -347,7 +348,7 @@ class Rdata: return dns.rdatatype.RdataType.make(value) @classmethod - def _as_bytes(cls, value, encode=False, max_length=None): + def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True): if encode and isinstance(value, str): value = value.encode() elif isinstance(value, bytearray): @@ -356,6 +357,8 @@ class Rdata: raise ValueError('not bytes') if max_length is not None and len(value) > max_length: raise ValueError('too long') + if not empty_ok and len(value) == 0: + raise ValueError('empty bytes not allowed') return value @classmethod @@ -449,6 +452,17 @@ class Rdata: else: raise ValueError('not a TTL') + @classmethod + def _as_tuple(cls, value, as_value): + try: + # For user convenience, if value is a singleton of the list + # element type, wrap it in a tuple. + return (as_value(value),) + except Exception: + # Otherwise, check each element of the iterable *value* + # against *as_value*. + return tuple(as_value(v) for v in value) + class GenericRdata(Rdata): diff --git a/dns/rdtypes/ANY/CSYNC.py b/dns/rdtypes/ANY/CSYNC.py index 268fd88e..0a7925aa 100644 --- a/dns/rdtypes/ANY/CSYNC.py +++ b/dns/rdtypes/ANY/CSYNC.py @@ -41,11 +41,9 @@ class CSYNC(dns.rdata.Rdata): super().__init__(rdclass, rdtype) self.serial = self._as_uint32(serial) self.flags = self._as_uint16(flags) - if isinstance(windows, Bitmap): - bitmap = windows - else: - bitmap = Bitmap(windows) - self.windows = tuple(bitmap.windows) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = windows.windows def to_text(self, origin=None, relativize=True, **kw): text = Bitmap(self.windows).to_text() diff --git a/dns/rdtypes/ANY/HIP.py b/dns/rdtypes/ANY/HIP.py index 2901fddd..e887359b 100644 --- a/dns/rdtypes/ANY/HIP.py +++ b/dns/rdtypes/ANY/HIP.py @@ -39,7 +39,7 @@ class HIP(dns.rdata.Rdata): self.hit = self._as_bytes(hit, True, 255) self.algorithm = self._as_uint8(algorithm) self.key = self._as_bytes(key, True) - self.servers = tuple([self._as_name(s) for s in servers]) + self.servers = self._as_tuple(servers, self._as_name) def to_text(self, origin=None, relativize=True, **kw): hit = binascii.hexlify(self.hit).decode() diff --git a/dns/rdtypes/ANY/NSEC.py b/dns/rdtypes/ANY/NSEC.py index 62e44999..c3621ac4 100644 --- a/dns/rdtypes/ANY/NSEC.py +++ b/dns/rdtypes/ANY/NSEC.py @@ -38,11 +38,9 @@ class NSEC(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, next, windows): super().__init__(rdclass, rdtype) self.next = self._as_name(next) - if isinstance(windows, Bitmap): - bitmap = windows - else: - bitmap = Bitmap(windows) - self.windows = tuple(bitmap.windows) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = windows.windows def to_text(self, origin=None, relativize=True, **kw): next = self.next.choose_relativity(origin, relativize) diff --git a/dns/rdtypes/ANY/NSEC3.py b/dns/rdtypes/ANY/NSEC3.py index 48a76eeb..8089f680 100644 --- a/dns/rdtypes/ANY/NSEC3.py +++ b/dns/rdtypes/ANY/NSEC3.py @@ -58,11 +58,9 @@ class NSEC3(dns.rdata.Rdata): self.iterations = self._as_uint16(iterations) self.salt = self._as_bytes(salt, True, 255) self.next = self._as_bytes(next, True, 255) - if isinstance(windows, Bitmap): - bitmap = windows - else: - bitmap = Bitmap(windows) - self.windows = tuple(bitmap.windows) + if not isinstance(windows, Bitmap): + windows = Bitmap(windows) + self.windows = windows.windows def to_text(self, origin=None, relativize=True, **kw): next = base64.b32encode(self.next).translate( diff --git a/dns/rdtypes/ANY/OPT.py b/dns/rdtypes/ANY/OPT.py index cac54bd8..69b8fe75 100644 --- a/dns/rdtypes/ANY/OPT.py +++ b/dns/rdtypes/ANY/OPT.py @@ -45,10 +45,11 @@ class OPT(dns.rdata.Rdata): """ super().__init__(rdclass, rdtype) - for option in options: + def as_option(option): if not isinstance(option, dns.edns.Option): raise ValueError('option is not a dns.edns.option') - self.options = tuple(options) + return option + self.options = self._as_tuple(options, as_option) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): for opt in self.options: diff --git a/dns/rdtypes/svcbbase.py b/dns/rdtypes/svcbbase.py index ae2a949c..4a3a2179 100644 --- a/dns/rdtypes/svcbbase.py +++ b/dns/rdtypes/svcbbase.py @@ -225,11 +225,8 @@ class MandatoryParam(Param): @dns.immutable.immutable class ALPNParam(Param): def __init__(self, ids): - for id in ids: - id = dns.rdata.Rdata._as_bytes(id, True, 255) - if len(id) == 0: - raise dns.exception.FormError('empty ALPN') - self.ids = tuple(ids) + self.ids = dns.rdata.Rdata._as_tuple( + ids, lambda x: dns.rdata.Rdata._as_bytes(x, True, 255, False)) @classmethod def from_value(cls, value): @@ -307,10 +304,8 @@ class PortParam(Param): @dns.immutable.immutable class IPv4HintParam(Param): def __init__(self, addresses): - for address in addresses: - # check validity - dns.ipv4.inet_aton(address) - self.addresses = tuple(addresses) + self.addresses = dns.rdata.Rdata._as_tuple( + addresses, dns.rdata.Rdata._as_ipv4_address) @classmethod def from_value(cls, value): @@ -336,10 +331,8 @@ class IPv4HintParam(Param): @dns.immutable.immutable class IPv6HintParam(Param): def __init__(self, addresses): - for address in addresses: - # check validity - dns.ipv6.inet_aton(address) - self.addresses = tuple(addresses) + self.addresses = dns.rdata.Rdata._as_tuple( + addresses, dns.rdata.Rdata._as_ipv6_address) @classmethod def from_value(cls, value): @@ -429,7 +422,11 @@ class SVCBBase(dns.rdata.Rdata): super().__init__(rdclass, rdtype) self.priority = self._as_uint16(priority) self.target = self._as_name(target) - self.params = dns.immutable.constify(params) + for k, v in params.items(): + k = ParamKey.make(k) + if not isinstance(v, Param) and v is not None: + raise ValueError("not a Param") + self.params = dns.immutable.Dict(params) # Make sure any paramater listed as mandatory is present in the # record. mandatory = params.get(ParamKey.MANDATORY) diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py index 37bf9617..68071ee0 100644 --- a/dns/rdtypes/txtbase.py +++ b/dns/rdtypes/txtbase.py @@ -42,13 +42,8 @@ class TXTBase(dns.rdata.Rdata): *strings*, a tuple of ``bytes`` """ super().__init__(rdclass, rdtype) - if isinstance(strings, (bytes, str)): - strings = (strings,) - encoded_strings = [] - for string in strings: - string = self._as_bytes(string, True, 255) - encoded_strings.append(string) - self.strings = tuple(encoded_strings) + self.strings = self._as_tuple(strings, + lambda x: self._as_bytes(x, True, 255)) def to_text(self, origin=None, relativize=True, **kw): txt = ''