From: Bob Halley Date: Thu, 2 Jul 2020 15:23:52 +0000 (-0700) Subject: Rework wire format processing. X-Git-Tag: v2.0.0rc2~11^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8335194878d419c9698b33ea3d269068387b245d;p=thirdparty%2Fdnspython.git Rework wire format processing. Wire format data is now done via a dns.wire.Parser, which does all of the bookkeeping and also provides convenience routines (e.g. get_uint16() or get_name()). --- diff --git a/dns/__init__.py b/dns/__init__.py index bb87ff47..b944701d 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -54,7 +54,7 @@ __all__ = [ 'rdtypes', 'update', 'version', - 'wiredata', + 'wire', 'zone', ] diff --git a/dns/edns.py b/dns/edns.py index 82e7abcb..4ad3a5c1 100644 --- a/dns/edns.py +++ b/dns/edns.py @@ -72,21 +72,16 @@ class Option: raise NotImplementedError # pragma: no cover @classmethod - def from_wire(cls, otype, wire, current, olen): + def from_wire_parser(cls, otype, parser): """Build an EDNS option object from wire format. *otype*, an ``int``, is the option type. - *wire*, a ``bytes``, is the wire-format message. - - *current*, an ``int``, is the offset in *wire* of the beginning - of the rdata. - - *olen*, an ``int``, is the length of the wire-format option data + *parser*, a ``dns.wire.Parser``, the parser, which should be + restructed to the option length. Returns a ``dns.edns.Option``. """ - raise NotImplementedError # pragma: no cover def _cmp(self, other): @@ -163,8 +158,8 @@ class GenericOption(Option): return "Generic %d" % self.otype @classmethod - def from_wire(cls, otype, wire, current, olen): - return cls(otype, wire[current: current + olen]) + def from_wire_parser(cls, otype, parser): + return cls(otype, parser.get_remaining()) def __str__(self): return self.to_text() @@ -281,18 +276,16 @@ class ECSOption(Option): return value @classmethod - def from_wire(cls, otype, wire, cur, olen): - family, src, scope = struct.unpack('!HBB', wire[cur:cur + 4]) - cur += 4 - + def from_wire_parser(cls, otype, parser): + family, src, scope = parser.get_struct('!HBB') addrlen = int(math.ceil(src / 8.0)) - + prefix = parser.get_bytes(addrlen) if family == 1: pad = 4 - addrlen - addr = dns.ipv4.inet_ntoa(wire[cur:cur + addrlen] + b'\x00' * pad) + addr = dns.ipv4.inet_ntoa(prefix + b'\x00' * pad) elif family == 2: pad = 16 - addrlen - addr = dns.ipv6.inet_ntoa(wire[cur:cur + addrlen] + b'\x00' * pad) + addr = dns.ipv6.inet_ntoa(prefix + b'\x00' * pad) else: raise ValueError('unsupported family') @@ -301,6 +294,7 @@ class ECSOption(Option): def __str__(self): return self.to_text() + _type_to_class = { OptionType.ECS: ECSOption } @@ -318,6 +312,21 @@ def get_option_class(otype): return cls +def option_from_wire_parser(otype, parser): + """Build an EDNS option object from wire format. + + *otype*, an ``int``, is the option type. + + *parser*, a ``dns.wire.Parser``, the parser, which should be + restricted to the option length. + + Returns an instance of a subclass of ``dns.edns.Option``. + """ + cls = get_option_class(otype) + otype = OptionType.make(otype) + return cls.from_wire_parser(otype, parser) + + def option_from_wire(otype, wire, current, olen): """Build an EDNS option object from wire format. @@ -332,7 +341,6 @@ def option_from_wire(otype, wire, current, olen): Returns an instance of a subclass of ``dns.edns.Option``. """ - - cls = get_option_class(otype) - otype = OptionType.make(otype) - return cls.from_wire(otype, wire, current, olen) + parser = dns.wire.Parser(wire, current) + with parser.restrict_to(olen): + return option_from_wire_parser(otype, parser) diff --git a/dns/message.py b/dns/message.py index fdaec026..4b322128 100644 --- a/dns/message.py +++ b/dns/message.py @@ -19,9 +19,9 @@ import contextlib import io -import struct import time +import dns.wire import dns.edns import dns.enum import dns.exception @@ -36,7 +36,6 @@ import dns.rdatatype import dns.rrset import dns.renderer import dns.tsig -import dns.wiredata import dns.rdtypes.ANY.OPT import dns.rdtypes.ANY.TSIG @@ -727,11 +726,8 @@ class _WireReader: """Wire format reader. - wire: a binary, is the wire-format message. + parser: the binary parser message: The message object being built - current: When building a message object from wire format, this - variable contains the offset from the beginning of wire of the next octet - to be read. initialize_message: Callback to set message parsing options question_only: Are we only reading the question? one_rr_per_rrset: Put each RR into its own RRset? @@ -744,9 +740,8 @@ class _WireReader: def __init__(self, wire, initialize_message, question_only=False, one_rr_per_rrset=False, ignore_trailing=False, keyring=None, multi=False): - self.wire = dns.wiredata.maybe_wrap(wire) + self.parser = dns.wire.Parser(wire) self.message = None - self.current = 0 self.initialize_message = initialize_message self.question_only = question_only self.one_rr_per_rrset = one_rr_per_rrset @@ -761,14 +756,8 @@ class _WireReader: section = self.message.sections[section_number] for i in range(qcount): - (qname, used) = dns.name.from_wire(self.wire, self.current) - if self.message.origin is not None: - qname = qname.relativize(self.message.origin) - self.current += used - (rdtype, rdclass) = \ - struct.unpack('!HH', - self.wire[self.current:self.current + 4]) - self.current += 4 + qname = self.parser.get_name(self.message.origin) + (rdtype, rdclass) = self.parser.get_struct('!HH') (rdclass, rdtype, _, _) = \ self.message._parse_rr_header(section_number, qname, rdclass, rdtype) @@ -786,16 +775,13 @@ class _WireReader: section = self.message.sections[section_number] force_unique = self.one_rr_per_rrset for i in range(count): - rr_start = self.current - (name, used) = dns.name.from_wire(self.wire, self.current) - absolute_name = name + rr_start = self.parser.current + absolute_name = self.parser.get_name() if self.message.origin is not None: - name = name.relativize(self.message.origin) - self.current += used - (rdtype, rdclass, ttl, rdlen) = \ - struct.unpack('!HHIH', - self.wire[self.current:self.current + 10]) - self.current += 10 + name = absolute_name.relativize(self.message.origin) + else: + name = absolute_name + (rdtype, rdclass, ttl, rdlen) = self.parser.get_struct('!HHIH') if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG): (rdclass, rdtype, deleting, empty) = \ self.message._parse_special_rr_header(section_number, @@ -811,9 +797,10 @@ class _WireReader: rd = None covers = dns.rdatatype.NONE else: - rd = dns.rdata.from_wire(rdclass, rdtype, - self.wire, self.current, rdlen, - self.message.origin) + with self.parser.restrict_to(rdlen): + rd = dns.rdata.from_wire_parser(rdclass, rdtype, + self.parser, + self.message.origin) covers = rd.covers() if self.message.xfr and rdtype == dns.rdatatype.SOA: force_unique = True @@ -832,7 +819,7 @@ class _WireReader: raise UnknownTSIGKey("key '%s' unknown" % name) self.message.keyring = key self.message.tsig_ctx = \ - dns.tsig.validate(self.wire, + dns.tsig.validate(self.parser.wire, key, absolute_name, rd, @@ -851,18 +838,15 @@ class _WireReader: if ttl > 0x7fffffff: ttl = 0 rrset.add(rd, ttl) - self.current += rdlen def read(self): """Read a wire format DNS message and build a dns.message.Message object.""" - l = len(self.wire) - if l < 12: + if self.parser.remaining() < 12: raise ShortHeader (id, flags, qcount, ancount, aucount, adcount) = \ - struct.unpack('!HHHHHH', self.wire[:12]) - self.current = 12 + self.parser.get_struct('!HHHHHH') factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) self.message = factory(id=id) self.message.flags = flags @@ -875,10 +859,10 @@ class _WireReader: self._get_section(MessageSection.ANSWER, ancount) self._get_section(MessageSection.AUTHORITY, aucount) self._get_section(MessageSection.ADDITIONAL, adcount) - if not self.ignore_trailing and self.current != l: + if not self.ignore_trailing and self.parser.remaining() != 0: raise TrailingJunk if self.multi and self.message.tsig_ctx and not self.message.had_tsig: - self.message.tsig_ctx.update(self.wire) + self.message.tsig_ctx.update(self.parser.wire) return self.message diff --git a/dns/name.py b/dns/name.py index ee87c752..477d0b74 100644 --- a/dns/name.py +++ b/dns/name.py @@ -28,8 +28,8 @@ try: except ImportError: # pragma: no cover have_idna_2008 = False +import dns.wire import dns.exception -import dns.wiredata # fullcompare() result values @@ -966,6 +966,39 @@ def from_text(text, origin=root, idna_codec=None): return Name(labels) +def from_wire_parser(parser): + """Convert possibly compressed wire format into a Name. + + *parser* is a dns.wire.Parser. + + Raises ``dns.name.BadPointer`` if a compression pointer did not + point backwards in the message. + + Raises ``dns.name.BadLabelType`` if an invalid label type was encountered. + + Returns a ``dns.name.Name`` + """ + + labels = [] + biggest_pointer = parser.current + with parser.restore_furthest(): + count = parser.get_uint8() + while count != 0: + if count < 64: + labels.append(parser.get_bytes(count)) + elif count >= 192: + current = (count & 0x3f) * 256 + parser.get_uint8() + if current >= biggest_pointer: + raise BadPointer + biggest_pointer = current + parser.seek(current) + else: + raise BadLabelType + count = parser.get_uint8() + labels.append(b'') + return Name(labels) + + def from_wire(message, current): """Convert possibly compressed wire format into a Name. @@ -987,32 +1020,6 @@ def from_wire(message, current): if not isinstance(message, bytes): raise ValueError("input to from_wire() must be a byte string") - message = dns.wiredata.maybe_wrap(message) - labels = [] - biggest_pointer = current - hops = 0 - count = message[current] - current += 1 - cused = 1 - while count != 0: - if count < 64: - labels.append(message[current: current + count].unwrap()) - current += count - if hops == 0: - cused += count - elif count >= 192: - current = (count & 0x3f) * 256 + message[current] - if hops == 0: - cused += 1 - if current >= biggest_pointer: - raise BadPointer - biggest_pointer = current - hops += 1 - else: - raise BadLabelType - count = message[current] - current += 1 - if hops == 0: - cused += 1 - labels.append('') - return (Name(labels), cused) + parser = dns.wire.Parser(message, current) + name = from_wire_parser(parser) + return (name, parser.current - current) diff --git a/dns/rdata.py b/dns/rdata.py index 2de1763a..64d20245 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -24,12 +24,12 @@ import io import inspect import itertools +import dns.wire import dns.exception import dns.name import dns.rdataclass import dns.rdatatype import dns.tokenizer -import dns.wiredata _hex_chunksize = 32 @@ -377,8 +377,8 @@ class GenericRdata(Rdata): file.write(self.data) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - return cls(rdclass, rdtype, wire[current: current + rdlen]) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + return cls(rdclass, rdtype, parser.get_remaining()) _rdata_classes = {} _module_prefix = 'dns.rdtypes' @@ -474,6 +474,36 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True, relativize_to) +def from_wire_parser(rdclass, rdtype, parser, origin=None): + """Build an rdata object from wire format + + This function attempts to dynamically load a class which + implements the specified rdata class and type. If there is no + class-and-type-specific implementation, the GenericRdata class + is used. + + Once a class is chosen, its from_wire() class method is called + with the parameters to this function. + + *rdclass*, an ``int``, the rdataclass. + + *rdtype*, an ``int``, the rdatatype. + + *parser*, a ``dns.wire.Parser``, the parser, which should be + restricted to the rdata length. + + *origin*, a ``dns.name.Name`` (or ``None``). If not ``None``, + then names will be relativized to this origin. + + Returns an instance of the chosen Rdata subclass. + """ + + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + cls = get_rdata_class(rdclass, rdtype) + return cls.from_wire_parser(rdclass, rdtype, parser, origin) + + def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): """Build an rdata object from wire format @@ -501,12 +531,9 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): Returns an instance of the chosen Rdata subclass. """ - - wire = dns.wiredata.maybe_wrap(wire) - rdclass = dns.rdataclass.RdataClass.make(rdclass) - rdtype = dns.rdatatype.RdataType.make(rdtype) - cls = get_rdata_class(rdclass, rdtype) - return cls.from_wire(rdclass, rdtype, wire, current, rdlen, origin) + parser = dns.wire.Parser(wire, current) + with parser.restrict_to(rdlen): + return from_wire_parser(rdclass, rdtype, parser, origin) class RdatatypeExists(dns.exception.DNSException): diff --git a/dns/rdtypes/ANY/AMTRELAY.py b/dns/rdtypes/ANY/AMTRELAY.py index 8fb18c19..4e012a27 100644 --- a/dns/rdtypes/ANY/AMTRELAY.py +++ b/dns/rdtypes/ANY/AMTRELAY.py @@ -70,18 +70,10 @@ class AMTRELAY(dns.rdata.Rdata): canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 2: - raise dns.exception.FormError - (precedence, relay_type) = struct.unpack('!BB', - wire[current: current + 2]) - current += 2 - rdlen -= 2 + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (precedence, relay_type) = parser.get_struct('!BB') discovery_optional = bool(relay_type >> 7) relay_type &= 0x7f - (relay, cused) = Relay(relay_type).from_wire(wire, current, rdlen, - origin) - current += cused - rdlen -= cused + relay = Relay(relay_type).from_wire_parser(parser, origin) return cls(rdclass, rdtype, precedence, discovery_optional, relay_type, relay) diff --git a/dns/rdtypes/ANY/CAA.py b/dns/rdtypes/ANY/CAA.py index d498107e..b7edae87 100644 --- a/dns/rdtypes/ANY/CAA.py +++ b/dns/rdtypes/ANY/CAA.py @@ -62,9 +62,8 @@ class CAA(dns.rdata.Rdata): file.write(self.value) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (flags, l) = struct.unpack('!BB', wire[current: current + 2]) - current += 2 - tag = wire[current: current + l] - value = wire[current + l:current + rdlen - 2] + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + flags = parser.get_uint8() + tag = parser.get_counted_bytes() + value = parser.get_remaining() return cls(rdclass, rdtype, flags, tag, value) diff --git a/dns/rdtypes/ANY/CERT.py b/dns/rdtypes/ANY/CERT.py index 98f8b67c..62df241c 100644 --- a/dns/rdtypes/ANY/CERT.py +++ b/dns/rdtypes/ANY/CERT.py @@ -96,13 +96,8 @@ class CERT(dns.rdata.Rdata): file.write(self.certificate) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - prefix = wire[current: current + 5].unwrap() - current += 5 - rdlen -= 5 - if rdlen < 0: - raise dns.exception.FormError - (certificate_type, key_tag, algorithm) = struct.unpack("!HHB", prefix) - certificate = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (certificate_type, key_tag, algorithm) = parser.get_struct("!HHB") + certificate = parser.get_remaining() return cls(rdclass, rdtype, certificate_type, key_tag, algorithm, certificate) diff --git a/dns/rdtypes/ANY/CSYNC.py b/dns/rdtypes/ANY/CSYNC.py index 66f2f31b..c62dad8a 100644 --- a/dns/rdtypes/ANY/CSYNC.py +++ b/dns/rdtypes/ANY/CSYNC.py @@ -93,26 +93,14 @@ class CSYNC(dns.rdata.Rdata): file.write(bitmap) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 6: - raise dns.exception.FormError("CSYNC too short") - (serial, flags) = struct.unpack("!IH", wire[current: current + 6]) - current += 6 - rdlen -= 6 + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (serial, flags) = parser.get_struct("!IH") windows = [] - while rdlen > 0: - if rdlen < 3: - raise dns.exception.FormError("CSYNC too short") - window = wire[current] - octets = wire[current + 1] + while parser.remaining() > 0: + window = parser.get_uint8() + octets = parser.get_uint8() if octets == 0 or octets > 32: raise dns.exception.FormError("bad CSYNC octets") - current += 2 - rdlen -= 2 - if rdlen < octets: - raise dns.exception.FormError("bad CSYNC bitmap length") - bitmap = bytearray(wire[current: current + octets].unwrap()) - current += octets - rdlen -= octets + bitmap = parser.get_bytes(octets) windows.append((window, bitmap)) return cls(rdclass, rdtype, serial, flags, windows) diff --git a/dns/rdtypes/ANY/GPOS.py b/dns/rdtypes/ANY/GPOS.py index e8b69eed..03677fd2 100644 --- a/dns/rdtypes/ANY/GPOS.py +++ b/dns/rdtypes/ANY/GPOS.py @@ -111,29 +111,10 @@ class GPOS(dns.rdata.Rdata): file.write(self.altitude) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - latitude = wire[current: current + l].unwrap() - current += l - rdlen -= l - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - longitude = wire[current: current + l].unwrap() - current += l - rdlen -= l - l = wire[current] - current += 1 - rdlen -= 1 - if l != rdlen: - raise dns.exception.FormError - altitude = wire[current: current + l].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + latitude = parser.get_counted_bytes() + longitude = parser.get_counted_bytes() + altitude = parser.get_counted_bytes() return cls(rdclass, rdtype, latitude, longitude, altitude) @property diff --git a/dns/rdtypes/ANY/HINFO.py b/dns/rdtypes/ANY/HINFO.py index dc982f11..587e0ad1 100644 --- a/dns/rdtypes/ANY/HINFO.py +++ b/dns/rdtypes/ANY/HINFO.py @@ -64,19 +64,7 @@ class HINFO(dns.rdata.Rdata): file.write(self.os) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - cpu = wire[current:current + l].unwrap() - current += l - rdlen -= l - l = wire[current] - current += 1 - rdlen -= 1 - if l != rdlen: - raise dns.exception.FormError - os = wire[current: current + l].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + cpu = parser.get_counted_bytes() + os = parser.get_counted_bytes() return cls(rdclass, rdtype, cpu, os) diff --git a/dns/rdtypes/ANY/HIP.py b/dns/rdtypes/ANY/HIP.py index 51e42ac9..1c774bbf 100644 --- a/dns/rdtypes/ANY/HIP.py +++ b/dns/rdtypes/ANY/HIP.py @@ -77,24 +77,12 @@ class HIP(dns.rdata.Rdata): server.to_wire(file, None, origin, False) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (lh, algorithm, lk) = struct.unpack('!BBH', - wire[current: current + 4]) - current += 4 - rdlen -= 4 - hit = wire[current: current + lh].unwrap() - current += lh - rdlen -= lh - key = wire[current: current + lk].unwrap() - current += lk - rdlen -= lk + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (lh, algorithm, lk) = parser.get_struct('!BBH') + hit = parser.get_bytes(lh) + key = parser.get_bytes(lk) servers = [] - while rdlen > 0: - (server, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - current += cused - rdlen -= cused - if origin is not None: - server = server.relativize(origin) + while parser.remaining() > 0: + server = parser.get_name(origin) servers.append(server) return cls(rdclass, rdtype, hit, algorithm, key, servers) diff --git a/dns/rdtypes/ANY/ISDN.py b/dns/rdtypes/ANY/ISDN.py index 37332321..6834b3c7 100644 --- a/dns/rdtypes/ANY/ISDN.py +++ b/dns/rdtypes/ANY/ISDN.py @@ -74,22 +74,10 @@ class ISDN(dns.rdata.Rdata): file.write(self.subaddress) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - address = wire[current: current + l].unwrap() - current += l - rdlen -= l - if rdlen > 0: - l = wire[current] - current += 1 - rdlen -= 1 - if l != rdlen: - raise dns.exception.FormError - subaddress = wire[current: current + l].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_counted_bytes() + if parser.remaining() > 0: + subaddress = parser.get_counted_bytes() else: - subaddress = '' + subaddress = b'' return cls(rdclass, rdtype, address, subaddress) diff --git a/dns/rdtypes/ANY/LOC.py b/dns/rdtypes/ANY/LOC.py index 31065bba..eb00a1cd 100644 --- a/dns/rdtypes/ANY/LOC.py +++ b/dns/rdtypes/ANY/LOC.py @@ -293,9 +293,9 @@ class LOC(dns.rdata.Rdata): file.write(wire) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): (version, size, hprec, vprec, latitude, longitude, altitude) = \ - struct.unpack("!BBBBIII", wire[current: current + rdlen]) + parser.get_struct("!BBBBIII") if latitude < _MIN_LATITUDE or latitude > _MAX_LATITUDE: raise dns.exception.FormError("bad latitude") if latitude > 0x80000000: diff --git a/dns/rdtypes/ANY/NSEC.py b/dns/rdtypes/ANY/NSEC.py index 0f79d94a..8c1da5ae 100644 --- a/dns/rdtypes/ANY/NSEC.py +++ b/dns/rdtypes/ANY/NSEC.py @@ -93,26 +93,13 @@ class NSEC(dns.rdata.Rdata): file.write(bitmap) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (next, cused) = dns.name.from_wire(wire[: current + rdlen], current) - current += cused - rdlen -= cused + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + next = parser.get_name(origin) windows = [] - while rdlen > 0: - if rdlen < 3: - raise dns.exception.FormError("NSEC too short") - window = wire[current] - octets = wire[current + 1] - if octets == 0 or octets > 32: + while parser.remaining() > 0: + window = parser.get_uint8() + bitmap = parser.get_counted_bytes() + if len(bitmap) == 0 or len(bitmap) > 32: raise dns.exception.FormError("bad NSEC octets") - current += 2 - rdlen -= 2 - if rdlen < octets: - raise dns.exception.FormError("bad NSEC bitmap length") - bitmap = bytearray(wire[current: current + octets].unwrap()) - current += octets - rdlen -= octets windows.append((window, bitmap)) - if origin is not None: - next = next.relativize(origin) return cls(rdclass, rdtype, next, windows) diff --git a/dns/rdtypes/ANY/NSEC3.py b/dns/rdtypes/ANY/NSEC3.py index 208f9e8c..32dfe3e0 100644 --- a/dns/rdtypes/ANY/NSEC3.py +++ b/dns/rdtypes/ANY/NSEC3.py @@ -138,36 +138,16 @@ class NSEC3(dns.rdata.Rdata): file.write(bitmap) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (algorithm, flags, iterations, slen) = \ - struct.unpack('!BBHB', wire[current: current + 5]) - - current += 5 - rdlen -= 5 - salt = wire[current: current + slen].unwrap() - current += slen - rdlen -= slen - nlen = wire[current] - current += 1 - rdlen -= 1 - next = wire[current: current + nlen].unwrap() - current += nlen - rdlen -= nlen + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (algorithm, flags, iterations) = parser.get_struct('!BBH') + salt = parser.get_counted_bytes() + next = parser.get_counted_bytes() windows = [] - while rdlen > 0: - if rdlen < 3: - raise dns.exception.FormError("NSEC3 too short") - window = wire[current] - octets = wire[current + 1] - if octets == 0 or octets > 32: + while parser.remaining() > 0: + window = parser.get_uint8() + bitmap = parser.get_counted_bytes() + if len(bitmap) == 0 or len(bitmap) > 32: raise dns.exception.FormError("bad NSEC3 octets") - current += 2 - rdlen -= 2 - if rdlen < octets: - raise dns.exception.FormError("bad NSEC3 bitmap length") - bitmap = bytearray(wire[current: current + octets].unwrap()) - current += octets - rdlen -= octets windows.append((window, bitmap)) return cls(rdclass, rdtype, algorithm, flags, iterations, salt, next, windows) diff --git a/dns/rdtypes/ANY/NSEC3PARAM.py b/dns/rdtypes/ANY/NSEC3PARAM.py index d10321b3..8ac76271 100644 --- a/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/dns/rdtypes/ANY/NSEC3PARAM.py @@ -67,15 +67,7 @@ class NSEC3PARAM(dns.rdata.Rdata): file.write(self.salt) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (algorithm, flags, iterations, slen) = \ - struct.unpack('!BBHB', - wire[current: current + 5]) - current += 5 - rdlen -= 5 - salt = wire[current: current + slen].unwrap() - current += slen - rdlen -= slen - if rdlen != 0: - raise dns.exception.FormError + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (algorithm, flags, iterations) = parser.get_struct('!BBH') + salt = parser.get_counted_bytes() return cls(rdclass, rdtype, algorithm, flags, iterations, salt) diff --git a/dns/rdtypes/ANY/OPENPGPKEY.py b/dns/rdtypes/ANY/OPENPGPKEY.py index eebe2a98..f632132e 100644 --- a/dns/rdtypes/ANY/OPENPGPKEY.py +++ b/dns/rdtypes/ANY/OPENPGPKEY.py @@ -45,6 +45,6 @@ class OPENPGPKEY(dns.rdata.Rdata): file.write(self.key) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - key = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + key = parser.get_remaining() return cls(rdclass, rdtype, key) diff --git a/dns/rdtypes/ANY/OPT.py b/dns/rdtypes/ANY/OPT.py index 0a0e7afe..c48aa12f 100644 --- a/dns/rdtypes/ANY/OPT.py +++ b/dns/rdtypes/ANY/OPT.py @@ -52,19 +52,12 @@ class OPT(dns.rdata.Rdata): return ' '.join(opt.to_text() for opt in self.options) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): options = [] - while rdlen > 0: - if rdlen < 4: - raise dns.exception.FormError - (otype, olen) = struct.unpack('!HH', wire[current:current + 4]) - current += 4 - rdlen -= 4 - if olen > rdlen: - raise dns.exception.FormError - opt = dns.edns.option_from_wire(otype, wire, current, olen) - current += olen - rdlen -= olen + while parser.remaining() > 0: + (otype, olen) = parser.get_struct('!HH') + with parser.restrict_to(olen): + opt = dns.edns.option_from_wire_parser(otype, parser) options.append(opt) return cls(rdclass, rdtype, options) diff --git a/dns/rdtypes/ANY/RP.py b/dns/rdtypes/ANY/RP.py index fa3aaac9..7446de6d 100644 --- a/dns/rdtypes/ANY/RP.py +++ b/dns/rdtypes/ANY/RP.py @@ -51,18 +51,7 @@ class RP(dns.rdata.Rdata): self.txt.to_wire(file, None, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (mbox, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - current += cused - rdlen -= cused - if rdlen <= 0: - raise dns.exception.FormError - (txt, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - mbox = mbox.relativize(origin) - txt = txt.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + mbox = parser.get_name(origin) + txt = parser.get_name(origin) return cls(rdclass, rdtype, mbox, txt) diff --git a/dns/rdtypes/ANY/RRSIG.py b/dns/rdtypes/ANY/RRSIG.py index ddc6141c..2077d905 100644 --- a/dns/rdtypes/ANY/RRSIG.py +++ b/dns/rdtypes/ANY/RRSIG.py @@ -115,16 +115,8 @@ class RRSIG(dns.rdata.Rdata): file.write(self.signature) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - header = struct.unpack('!HBBIIIH', wire[current: current + 18]) - current += 18 - rdlen -= 18 - (signer, cused) = dns.name.from_wire(wire[: current + rdlen], current) - current += cused - rdlen -= cused - if origin is not None: - signer = signer.relativize(origin) - signature = wire[current: current + rdlen].unwrap() - return cls(rdclass, rdtype, header[0], header[1], header[2], - header[3], header[4], header[5], header[6], signer, - signature) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct('!HBBIIIH') + signer = parser.get_name(origin) + signature = parser.get_remaining() + return cls(rdclass, rdtype, *header, signer, signature) diff --git a/dns/rdtypes/ANY/SOA.py b/dns/rdtypes/ANY/SOA.py index cd59ec0d..e93274ed 100644 --- a/dns/rdtypes/ANY/SOA.py +++ b/dns/rdtypes/ANY/SOA.py @@ -71,20 +71,7 @@ class SOA(dns.rdata.Rdata): file.write(five_ints) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (mname, cused) = dns.name.from_wire(wire[: current + rdlen], current) - current += cused - rdlen -= cused - (rname, cused) = dns.name.from_wire(wire[: current + rdlen], current) - current += cused - rdlen -= cused - if rdlen != 20: - raise dns.exception.FormError - five_ints = struct.unpack('!IIIII', - wire[current: current + rdlen]) - if origin is not None: - mname = mname.relativize(origin) - rname = rname.relativize(origin) - return cls(rdclass, rdtype, mname, rname, - five_ints[0], five_ints[1], five_ints[2], five_ints[3], - five_ints[4]) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + mname = parser.get_name(origin) + rname = parser.get_name(origin) + return cls(rdclass, rdtype, mname, rname, *parser.get_struct('!IIIII')) diff --git a/dns/rdtypes/ANY/SSHFP.py b/dns/rdtypes/ANY/SSHFP.py index 49a1239b..a3cc0039 100644 --- a/dns/rdtypes/ANY/SSHFP.py +++ b/dns/rdtypes/ANY/SSHFP.py @@ -58,9 +58,7 @@ class SSHFP(dns.rdata.Rdata): file.write(self.fingerprint) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - header = struct.unpack("!BB", wire[current: current + 2]) - current += 2 - rdlen -= 2 - fingerprint = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("BB") + fingerprint = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], fingerprint) diff --git a/dns/rdtypes/ANY/TLSA.py b/dns/rdtypes/ANY/TLSA.py index 3ffaaca6..9c9c8662 100644 --- a/dns/rdtypes/ANY/TLSA.py +++ b/dns/rdtypes/ANY/TLSA.py @@ -61,9 +61,7 @@ class TLSA(dns.rdata.Rdata): file.write(self.cert) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - header = struct.unpack("!BBB", wire[current: current + 3]) - current += 3 - rdlen -= 3 - cert = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("BBB") + cert = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], header[2], cert) diff --git a/dns/rdtypes/ANY/TSIG.py b/dns/rdtypes/ANY/TSIG.py index 002c2dbc..0d4b48b6 100644 --- a/dns/rdtypes/ANY/TSIG.py +++ b/dns/rdtypes/ANY/TSIG.py @@ -80,33 +80,12 @@ class TSIG(dns.rdata.Rdata): file.write(self.other) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (algorithm, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - current += cused - rdlen -= cused - if rdlen < 10: - raise dns.exception.FormError - (time_hi, time_lo, fudge, mac_len) = \ - struct.unpack('!HIHH', wire[current: current + 10]) - current += 10 - rdlen -= 10 + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + algorithm = parser.get_name(origin) + (time_hi, time_lo, fudge) = parser.get_struct('!HIH') time_signed = (time_hi << 32) + time_lo - if rdlen < mac_len: - raise dns.exception.FormError - mac = wire[current: current + mac_len].unwrap() - current += mac_len - rdlen -= mac_len - if rdlen < 6: - raise dns.exception.FormError - (original_id, error, other_len) = \ - struct.unpack('!HHH', wire[current: current + 6]) - current += 6 - rdlen -= 6 - if rdlen < other_len: - raise dns.exception.FormError - other = wire[current: current + other_len].unwrap() - current += other_len - rdlen -= other_len + mac = parser.get_counted_bytes(2) + (original_id, error) = parser.get_struct('!HH') + other = parser.get_counted_bytes(2) return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, original_id, error, other) diff --git a/dns/rdtypes/ANY/URI.py b/dns/rdtypes/ANY/URI.py index 77f89645..84296f52 100644 --- a/dns/rdtypes/ANY/URI.py +++ b/dns/rdtypes/ANY/URI.py @@ -63,14 +63,9 @@ class URI(dns.rdata.Rdata): file.write(self.target) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 5: - raise dns.exception.FormError('URI RR is shorter than 5 octets') - - (priority, weight) = struct.unpack('!HH', wire[current: current + 4]) - current += 4 - rdlen -= 4 - target = wire[current: current + rdlen] - current += rdlen - + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (priority, weight) = parser.get_struct('!HH') + target = parser.get_remaining() + if len(target) == 0: + raise dns.exception.FormError('URI target may not be empty') return cls(rdclass, rdtype, priority, weight, target) diff --git a/dns/rdtypes/ANY/X25.py b/dns/rdtypes/ANY/X25.py index ac61849e..214f1dca 100644 --- a/dns/rdtypes/ANY/X25.py +++ b/dns/rdtypes/ANY/X25.py @@ -54,11 +54,6 @@ class X25(dns.rdata.Rdata): file.write(self.address) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - l = wire[current] - current += 1 - rdlen -= 1 - if l != rdlen: - raise dns.exception.FormError - address = wire[current: current + l].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_counted_bytes() return cls(rdclass, rdtype, address) diff --git a/dns/rdtypes/CH/A.py b/dns/rdtypes/CH/A.py index ca349320..b738ac6c 100644 --- a/dns/rdtypes/CH/A.py +++ b/dns/rdtypes/CH/A.py @@ -18,7 +18,7 @@ import dns.rdtypes.mxbase import struct -class A(dns.rdtypes.mxbase.MXBase): +class A(dns.rdata.Rdata): """A record for Chaosnet""" @@ -27,8 +27,8 @@ class A(dns.rdtypes.mxbase.MXBase): __slots__ = ['domain', 'address'] - def __init__(self, rdclass, rdtype, address, domain): - super().__init__(rdclass, rdtype, address, domain) + def __init__(self, rdclass, rdtype, domain, address): + super().__init__(rdclass, rdtype) object.__setattr__(self, 'domain', domain) object.__setattr__(self, 'address', address) @@ -42,7 +42,7 @@ class A(dns.rdtypes.mxbase.MXBase): domain = tok.get_name(origin, relativize, relativize_to) address = tok.get_uint16(base=8) tok.get_eol() - return cls(rdclass, rdtype, address, domain) + return cls(rdclass, rdtype, domain, address) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): self.domain.to_wire(file, compress, origin, canonicalize) @@ -50,13 +50,7 @@ class A(dns.rdtypes.mxbase.MXBase): file.write(pref) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (domain, cused) = dns.name.from_wire(wire[:current + rdlen - 2], - current) - current += cused - (address,) = struct.unpack('!H', wire[current:current + 2]) - if cused + 2 != rdlen: - raise dns.exception.FormError - if origin is not None: - domain = domain.relativize(origin) - return cls(rdclass, rdtype, address, domain) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + domain = parser.get_name(origin) + address = parser.get_uint16() + return cls(rdclass, rdtype, domain, address) diff --git a/dns/rdtypes/IN/A.py b/dns/rdtypes/IN/A.py index 461ce927..8b71e329 100644 --- a/dns/rdtypes/IN/A.py +++ b/dns/rdtypes/IN/A.py @@ -47,6 +47,6 @@ class A(dns.rdata.Rdata): file.write(dns.ipv4.inet_aton(self.address)) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = dns.ipv4.inet_ntoa(wire[current: current + rdlen]) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = dns.ipv4.inet_ntoa(parser.get_remaining()) return cls(rdclass, rdtype, address) diff --git a/dns/rdtypes/IN/AAAA.py b/dns/rdtypes/IN/AAAA.py index 9751f82c..08f9d679 100644 --- a/dns/rdtypes/IN/AAAA.py +++ b/dns/rdtypes/IN/AAAA.py @@ -47,6 +47,6 @@ class AAAA(dns.rdata.Rdata): file.write(dns.ipv6.inet_aton(self.address)) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = dns.ipv6.inet_ntoa(wire[current: current + rdlen]) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = dns.ipv6.inet_ntoa(parser.get_remaining()) return cls(rdclass, rdtype, address) diff --git a/dns/rdtypes/IN/APL.py b/dns/rdtypes/IN/APL.py index 7149a648..ab7fe4bc 100644 --- a/dns/rdtypes/IN/APL.py +++ b/dns/rdtypes/IN/APL.py @@ -111,26 +111,18 @@ class APL(dns.rdata.Rdata): item.to_wire(file) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): items = [] - while 1: - if rdlen == 0: - break - if rdlen < 4: - raise dns.exception.FormError - header = struct.unpack('!HBB', wire[current: current + 4]) + while parser.remaining() > 0: + header = parser.get_struct('!HBB') afdlen = header[2] if afdlen > 127: negation = True afdlen -= 128 else: negation = False - current += 4 - rdlen -= 4 - if rdlen < afdlen: - raise dns.exception.FormError - address = wire[current: current + afdlen].unwrap() + address = parser.get_bytes(afdlen) l = len(address) if header[0] == 1: if l < 4: @@ -146,8 +138,6 @@ class APL(dns.rdata.Rdata): # seems better than throwing an exception # address = codecs.encode(address, 'hex_codec') - current += afdlen - rdlen -= afdlen item = APLItem(header[0], negation, address, header[1]) items.append(item) return cls(rdclass, rdtype, items) diff --git a/dns/rdtypes/IN/DHCID.py b/dns/rdtypes/IN/DHCID.py index da834500..6f66eb89 100644 --- a/dns/rdtypes/IN/DHCID.py +++ b/dns/rdtypes/IN/DHCID.py @@ -46,6 +46,6 @@ class DHCID(dns.rdata.Rdata): file.write(self.data) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - data = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + data = parser.get_remaining() return cls(rdclass, rdtype, data) diff --git a/dns/rdtypes/IN/IPSECKEY.py b/dns/rdtypes/IN/IPSECKEY.py index 6d0f4259..182ad2cb 100644 --- a/dns/rdtypes/IN/IPSECKEY.py +++ b/dns/rdtypes/IN/IPSECKEY.py @@ -72,19 +72,10 @@ class IPSECKEY(dns.rdata.Rdata): file.write(self.key) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 3: - raise dns.exception.FormError - header = struct.unpack('!BBB', wire[current: current + 3]) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct('!BBB') gateway_type = header[1] - current += 3 - rdlen -= 3 - (gateway, cused) = Gateway(gateway_type).from_wire(wire, current, - rdlen, origin) - current += cused - rdlen -= cused - key = wire[current: current + rdlen].unwrap() - if origin is not None and gateway_type == 3: - gateway = gateway.relativize(origin) + gateway = Gateway(gateway_type).from_wire_parser(parser, origin) + key = parser.get_remaining() return cls(rdclass, rdtype, header[0], gateway_type, header[2], gateway, key) diff --git a/dns/rdtypes/IN/NAPTR.py b/dns/rdtypes/IN/NAPTR.py index d8f2958c..48d43562 100644 --- a/dns/rdtypes/IN/NAPTR.py +++ b/dns/rdtypes/IN/NAPTR.py @@ -85,26 +85,12 @@ class NAPTR(dns.rdata.Rdata): self.replacement.to_wire(file, compress, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (order, preference) = struct.unpack('!HH', wire[current: current + 4]) - current += 4 - rdlen -= 4 + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (order, preference) = parser.get_struct('!HH') strings = [] for i in range(3): - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen or rdlen < 0: - raise dns.exception.FormError - s = wire[current: current + l].unwrap() - current += l - rdlen -= l + s = parser.get_counted_bytes() strings.append(s) - (replacement, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - replacement = replacement.relativize(origin) + replacement = parser.get_name(origin) return cls(rdclass, rdtype, order, preference, strings[0], strings[1], strings[2], replacement) diff --git a/dns/rdtypes/IN/NSAP.py b/dns/rdtypes/IN/NSAP.py index a42e5928..227465fa 100644 --- a/dns/rdtypes/IN/NSAP.py +++ b/dns/rdtypes/IN/NSAP.py @@ -54,6 +54,6 @@ class NSAP(dns.rdata.Rdata): file.write(self.address) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = parser.get_remaining() return cls(rdclass, rdtype, address) diff --git a/dns/rdtypes/IN/PX.py b/dns/rdtypes/IN/PX.py index e8e0adac..946d79f8 100644 --- a/dns/rdtypes/IN/PX.py +++ b/dns/rdtypes/IN/PX.py @@ -57,22 +57,8 @@ class PX(dns.rdata.Rdata): self.mapx400.to_wire(file, None, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (preference, ) = struct.unpack('!H', wire[current: current + 2]) - current += 2 - rdlen -= 2 - (map822, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused > rdlen: - raise dns.exception.FormError - current += cused - rdlen -= cused - if origin is not None: - map822 = map822.relativize(origin) - (mapx400, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - mapx400 = mapx400.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + map822 = parser.get_name(origin) + mapx400 = parser.get_name(origin) return cls(rdclass, rdtype, preference, map822, mapx400) diff --git a/dns/rdtypes/IN/SRV.py b/dns/rdtypes/IN/SRV.py index 02603f24..485153f4 100644 --- a/dns/rdtypes/IN/SRV.py +++ b/dns/rdtypes/IN/SRV.py @@ -58,15 +58,7 @@ class SRV(dns.rdata.Rdata): self.target.to_wire(file, compress, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (priority, weight, port) = struct.unpack('!HHH', - wire[current: current + 6]) - current += 6 - rdlen -= 6 - (target, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - target = target.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + (priority, weight, port) = parser.get_struct('!HHH') + target = parser.get_name(origin) return cls(rdclass, rdtype, priority, weight, port, target) diff --git a/dns/rdtypes/IN/WKS.py b/dns/rdtypes/IN/WKS.py index a5641358..d66d8583 100644 --- a/dns/rdtypes/IN/WKS.py +++ b/dns/rdtypes/IN/WKS.py @@ -89,10 +89,8 @@ class WKS(dns.rdata.Rdata): file.write(self.bitmap) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = dns.ipv4.inet_ntoa(wire[current: current + 4]) - protocol, = struct.unpack('!B', wire[current + 4: current + 5]) - current += 5 - rdlen -= 5 - bitmap = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + address = dns.ipv4.inet_ntoa(parser.get_bytes(4)) + protocol = parser.get_uint8() + bitmap = parser.get_remaining() return cls(rdclass, rdtype, address, protocol, bitmap) diff --git a/dns/rdtypes/dnskeybase.py b/dns/rdtypes/dnskeybase.py index 31fa0ecf..0243d6f3 100644 --- a/dns/rdtypes/dnskeybase.py +++ b/dns/rdtypes/dnskeybase.py @@ -67,12 +67,8 @@ class DNSKEYBase(dns.rdata.Rdata): file.write(self.key) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - if rdlen < 4: - raise dns.exception.FormError - header = struct.unpack('!HBB', wire[current: current + 4]) - current += 4 - rdlen -= 4 - key = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct('!HBB') + key = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], header[2], key) diff --git a/dns/rdtypes/dnskeybase.pyi b/dns/rdtypes/dnskeybase.pyi index 2588cf37..1b999cfd 100644 --- a/dns/rdtypes/dnskeybase.pyi +++ b/dns/rdtypes/dnskeybase.pyi @@ -31,7 +31,7 @@ class DNSKEYBase(rdata.Rdata): ... @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + def from_parser(cls, rdclass, rdtype, parser, origin=None): ... def flags_to_text_set(self) -> Set[str]: diff --git a/dns/rdtypes/dsbase.py b/dns/rdtypes/dsbase.py index dec16f3f..d7850bee 100644 --- a/dns/rdtypes/dsbase.py +++ b/dns/rdtypes/dsbase.py @@ -61,9 +61,7 @@ class DSBase(dns.rdata.Rdata): file.write(self.digest) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - header = struct.unpack("!HBB", wire[current: current + 4]) - current += 4 - rdlen -= 4 - digest = wire[current: current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + header = parser.get_struct("!HBB") + digest = parser.get_remaining() return cls(rdclass, rdtype, header[0], header[1], header[2], digest) diff --git a/dns/rdtypes/euibase.py b/dns/rdtypes/euibase.py index 7756538c..c1677a81 100644 --- a/dns/rdtypes/euibase.py +++ b/dns/rdtypes/euibase.py @@ -63,6 +63,6 @@ class EUIBase(dns.rdata.Rdata): file.write(self.eui) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - eui = wire[current:current + rdlen].unwrap() + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + eui = parser.get_bytes(cls.byte_len) return cls(rdclass, rdtype, eui) diff --git a/dns/rdtypes/mxbase.py b/dns/rdtypes/mxbase.py index 63fd8697..d6a6efed 100644 --- a/dns/rdtypes/mxbase.py +++ b/dns/rdtypes/mxbase.py @@ -53,16 +53,9 @@ class MXBase(dns.rdata.Rdata): self.exchange.to_wire(file, compress, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (preference, ) = struct.unpack('!H', wire[current: current + 2]) - current += 2 - rdlen -= 2 - (exchange, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - exchange = exchange.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + preference = parser.get_uint16() + exchange = parser.get_name(origin) return cls(rdclass, rdtype, preference, exchange) diff --git a/dns/rdtypes/nsbase.py b/dns/rdtypes/nsbase.py index 28daca20..93d3ee53 100644 --- a/dns/rdtypes/nsbase.py +++ b/dns/rdtypes/nsbase.py @@ -47,13 +47,8 @@ class NSBase(dns.rdata.Rdata): self.target.to_wire(file, compress, origin, canonicalize) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - (target, cused) = dns.name.from_wire(wire[: current + rdlen], - current) - if cused != rdlen: - raise dns.exception.FormError - if origin is not None: - target = target.relativize(origin) + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + target = parser.get_name(origin) return cls(rdclass, rdtype, target) diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py index ec061a18..ad0093da 100644 --- a/dns/rdtypes/txtbase.py +++ b/dns/rdtypes/txtbase.py @@ -84,16 +84,9 @@ class TXTBase(dns.rdata.Rdata): file.write(s) @classmethod - def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): strings = [] - while rdlen > 0: - l = wire[current] - current += 1 - rdlen -= 1 - if l > rdlen: - raise dns.exception.FormError - s = wire[current: current + l].unwrap() - current += l - rdlen -= l + while parser.remaining() > 0: + s = parser.get_counted_bytes() strings.append(s) return cls(rdclass, rdtype, strings) diff --git a/dns/rdtypes/util.py b/dns/rdtypes/util.py index 5d1f6c92..3dc636d0 100644 --- a/dns/rdtypes/util.py +++ b/dns/rdtypes/util.py @@ -78,14 +78,14 @@ class Gateway: else: raise ValueError(self._invalid_type()) - def from_wire(self, wire, current, rdlen, origin=None): + def from_wire_parser(self, parser, origin=None): if self.type == 0: - return (None, 0) + return None elif self.type == 1: - return (dns.ipv4.inet_ntoa(wire[current: current + 4]), 4) + return dns.ipv4.inet_ntoa(parser.get_bytes(4)) elif self.type == 2: - return (dns.ipv6.inet_ntoa(wire[current: current + 16]), 16) + return dns.ipv6.inet_ntoa(parser.get_bytes(16)) elif self.type == 3: - return dns.name.from_wire(wire[: current + rdlen], current) + return parser.get_name(origin) else: raise dns.exception.FormError(self._invalid_type()) diff --git a/dns/wire.py b/dns/wire.py new file mode 100644 index 00000000..a3149605 --- /dev/null +++ b/dns/wire.py @@ -0,0 +1,82 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import contextlib +import struct + +import dns.exception +import dns.name + +class Parser: + def __init__(self, wire, current=0): + self.wire = wire + self.current = 0 + self.end = len(self.wire) + if current: + self.seek(current) + self.furthest = current + + def remaining(self): + return self.end - self.current + + def get_bytes(self, size): + if size > self.remaining(): + raise dns.exception.FormError + output = self.wire[self.current:self.current + size] + self.current += size + self.furthest = max(self.furthest, self.current) + return output + + def get_counted_bytes(self, length_size=1): + length = int.from_bytes(self.get_bytes(length_size), 'big') + return self.get_bytes(length) + + def get_remaining(self): + return self.get_bytes(self.remaining()) + + def get_uint8(self): + return struct.unpack('!B', self.get_bytes(1))[0] + + def get_uint16(self): + return struct.unpack('!H', self.get_bytes(2))[0] + + def get_uint32(self): + return struct.unpack('!I', self.get_bytes(4))[0] + + def get_struct(self, format): + return struct.unpack(format, self.get_bytes(struct.calcsize(format))) + + def get_name(self, origin=None): + name = dns.name.from_wire_parser(self) + if origin: + name = name.relativize(origin) + return name + + def seek(self, where): + # Note that seeking to the end is OK! (If you try to read + # after such a seek, you'll get an exception as expected.) + if where < 0 or where > self.end: + raise dns.exception.FormError + self.current = where + + @contextlib.contextmanager + def restrict_to(self, size): + if size > self.remaining(): + raise dns.exception.FormError + saved_end = self.end + try: + self.end = self.current + size + yield + # We make this check here and not in the finally as we + # don't want to raise if we're already raising for some + # other reason. + if self.current != self.end: + raise dns.exception.FormError + finally: + self.end = saved_end + + @contextlib.contextmanager + def restore_furthest(self): + try: + yield None + finally: + self.current = self.furthest diff --git a/dns/wiredata.py b/dns/wiredata.py deleted file mode 100644 index 51f12fc3..00000000 --- a/dns/wiredata.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -# Copyright (C) 2011,2017 Nominum, Inc. -# -# Permission to use, copy, modify, and distribute this software and its -# documentation for any purpose with or without fee is hereby granted, -# provided that the above copyright notice and this permission notice -# appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT -# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -"""DNS Wire Data Helper""" - -import dns.exception - - -class WireData(bytes): - # WireData is a binary type with stricter slicing - - def __getitem__(self, key): - try: - if isinstance(key, slice): - # make sure we are not going outside of valid ranges, - # do stricter control of boundaries than python does - # by default - - for index in (key.start, key.stop): - if index is None: - continue - elif abs(index) > len(self): - raise dns.exception.FormError - - return WireData(super().__getitem__(key)) - return super().__getitem__(key) - except IndexError: - raise dns.exception.FormError - - def unwrap(self): - return bytes(self) - - -def maybe_wrap(wire): - if isinstance(wire, WireData): - return wire - elif isinstance(wire, bytes): - return WireData(wire) - elif isinstance(wire, str): - return WireData(wire.encode()) - raise ValueError("unhandled type %s" % type(wire)) diff --git a/tests/test_rdata.py b/tests/test_rdata.py index 0ed38b76..ca238953 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -23,6 +23,7 @@ import pickle import struct import unittest +import dns.wire import dns.exception import dns.name import dns.rdata @@ -401,13 +402,19 @@ class RdataTestCase(unittest.TestCase): def test_opt_short_lengths(self): def bad1(): - opt = OPT.from_wire(4096, dns.rdatatype.OPT, - binascii.unhexlify('f00102'), 0, 3) + parser = dns.wire.Parser(bytes.fromhex('f00102')) + opt = OPT.from_wire_parser(4096, dns.rdatatype.OPT, parser) self.assertRaises(dns.exception.FormError, bad1) def bad2(): - opt = OPT.from_wire(4096, dns.rdatatype.OPT, - binascii.unhexlify('f00100030000'), 0, 6) + parser = dns.wire.Parser(bytes.fromhex('f00100030000')) + opt = OPT.from_wire_parser(4096, dns.rdatatype.OPT, parser) self.assertRaises(dns.exception.FormError, bad2) + def test_from_wire_parser(self): + wire = bytes.fromhex('01020304') + rdata = dns.rdata.from_wire('in', 'a', wire, 0, 4) + print(rdata) + self.assertEqual(rdata, dns.rdata.from_text('in', 'a', '1.2.3.4')) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_rdtypeanyeui.py b/tests/test_rdtypeanyeui.py index e8645820..08527273 100644 --- a/tests/test_rdtypeanyeui.py +++ b/tests/test_rdtypeanyeui.py @@ -94,25 +94,19 @@ class RdtypeAnyEUI48TestCase(unittest.TestCase): '''Valid wire format.''' eui = b'\x01\x23\x45\x67\x89\xab' pad_len = 100 - wire = dns.wiredata.WireData(b'x' * pad_len + eui + b'y' * pad_len * 2) - inst = dns.rdtypes.ANY.EUI48.EUI48.from_wire(dns.rdataclass.IN, - dns.rdatatype.EUI48, - wire, - pad_len, - len(eui)) + wire = b'x' * pad_len + eui + b'y' * pad_len * 2 + inst = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI48, + wire, pad_len, len(eui)) self.assertEqual(inst.eui, eui) def testFromWireLength(self): '''Valid wire format.''' eui = b'\x01\x23\x45\x67\x89' pad_len = 100 - wire = dns.wiredata.WireData(b'x' * pad_len + eui + b'y' * pad_len * 2) + wire = b'x' * pad_len + eui + b'y' * pad_len * 2 with self.assertRaises(dns.exception.FormError): - dns.rdtypes.ANY.EUI48.EUI48.from_wire(dns.rdataclass.IN, - dns.rdatatype.EUI48, - wire, - pad_len, - len(eui)) + dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI48, + wire, pad_len, len(eui)) class RdtypeAnyEUI64TestCase(unittest.TestCase): @@ -191,25 +185,19 @@ class RdtypeAnyEUI64TestCase(unittest.TestCase): '''Valid wire format.''' eui = b'\x01\x23\x45\x67\x89\xab\xcd\xef' pad_len = 100 - wire = dns.wiredata.WireData(b'x' * pad_len + eui + b'y' * pad_len * 2) - inst = dns.rdtypes.ANY.EUI64.EUI64.from_wire(dns.rdataclass.IN, - dns.rdatatype.EUI64, - wire, - pad_len, - len(eui)) + wire = b'x' * pad_len + eui + b'y' * pad_len * 2 + inst = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI64, + wire, pad_len, len(eui)) self.assertEqual(inst.eui, eui) def testFromWireLength(self): '''Valid wire format.''' eui = b'\x01\x23\x45\x67\x89' pad_len = 100 - wire = dns.wiredata.WireData(b'x' * pad_len + eui + b'y' * pad_len * 2) + wire = b'x' * pad_len + eui + b'y' * pad_len * 2 with self.assertRaises(dns.exception.FormError): - dns.rdtypes.ANY.EUI64.EUI64.from_wire(dns.rdataclass.IN, - dns.rdatatype.EUI64, - wire, - pad_len, - len(eui)) + dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.EUI64, + wire, pad_len, len(eui)) if __name__ == '__main__': diff --git a/tests/test_wire.py b/tests/test_wire.py new file mode 100644 index 00000000..ac4401ce --- /dev/null +++ b/tests/test_wire.py @@ -0,0 +1,55 @@ + +import unittest + +import dns.exception +import dns.wire +import dns.name + + +class BinaryTestCase(unittest.TestCase): + + def test_basic(self): + wire = bytes.fromhex('0102010203040102') + p = dns.wire.Parser(wire) + self.assertEqual(p.get_uint16(), 0x0102) + with p.restrict_to(5): + self.assertEqual(p.get_uint32(), 0x01020304) + self.assertEqual(p.get_uint8(), 0x01) + self.assertEqual(p.remaining(), 0) + with self.assertRaises(dns.exception.FormError): + p.get_uint16() + self.assertEqual(p.remaining(), 1) + self.assertEqual(p.get_uint8(), 0x02) + with self.assertRaises(dns.exception.FormError): + p.get_uint8() + + def test_name(self): + # www.dnspython.org NS IN question + wire = b'\x03www\x09dnspython\x03org\x00\x00\x02\x00\x01' + expected = dns.name.from_text('www.dnspython.org') + p = dns.wire.Parser(wire) + self.assertEqual(p.get_name(), expected) + self.assertEqual(p.get_uint16(), 2) + self.assertEqual(p.get_uint16(), 1) + self.assertEqual(p.remaining(), 0) + + def test_relativized_name(self): + # www.dnspython.org NS IN question + wire = b'\x03www\x09dnspython\x03org\x00\x00\x02\x00\x01' + origin = dns.name.from_text('dnspython.org') + expected = dns.name.from_text('www', None) + p = dns.wire.Parser(wire) + self.assertEqual(p.get_name(origin), expected) + self.assertEqual(p.remaining(), 4) + + def test_compressed_name(self): + # www.dnspython.org NS IN question + wire = b'\x09dnspython\x03org\x00\x03www\xc0\x00' + expected1 = dns.name.from_text('dnspython.org') + expected2 = dns.name.from_text('www.dnspython.org') + p = dns.wire.Parser(wire) + self.assertEqual(p.get_name(), expected1) + self.assertEqual(p.get_name(), expected2) + self.assertEqual(p.remaining(), 0) + # verify the unseek() + self.assertEqual(p.current, len(wire)) diff --git a/tests/test_wiredata.py b/tests/test_wiredata.py deleted file mode 100644 index 9274259a..00000000 --- a/tests/test_wiredata.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (C) 2016 -# Author: Martin Basti -# -# Permission to use, copy, modify, and distribute this software and its -# documentation for any purpose with or without fee is hereby granted, -# provided that the above copyright notice and this permission notice -# appear in all copies. - -import unittest - -from dns.exception import FormError -from dns.wiredata import WireData, maybe_wrap - - -class WireDataSlicingTestCase(unittest.TestCase): - - def testSliceAll(self): - """Get all data""" - inst = WireData(b'0123456789') - self.assertEqual(inst[:], WireData(b'0123456789')) - - def testSliceAllExplicitlyDefined(self): - """Get all data""" - inst = WireData(b'0123456789') - self.assertEqual(inst[0:10], WireData(b'0123456789')) - - def testSliceLowerHalf(self): - """Get lower half of data""" - inst = WireData(b'0123456789') - self.assertEqual(inst[:5], WireData(b'01234')) - - def testSliceLowerHalfWithNegativeIndex(self): - """Get lower half of data""" - inst = WireData(b'0123456789') - self.assertEqual(inst[:-5], WireData(b'01234')) - - def testSliceUpperHalf(self): - """Get upper half of data""" - inst = WireData(b'0123456789') - self.assertEqual(inst[5:], WireData(b'56789')) - - def testSliceMiddle(self): - """Get data from middle""" - inst = WireData(b'0123456789') - self.assertEqual(inst[3:6], WireData(b'345')) - - def testSliceMiddleWithNegativeIndex(self): - """Get data from middle""" - inst = WireData(b'0123456789') - self.assertEqual(inst[-6:-3], WireData(b'456')) - - def testSliceMiddleWithMixedIndex(self): - """Get data from middle""" - inst = WireData(b'0123456789') - self.assertEqual(inst[-8:3], WireData(b'2')) - self.assertEqual(inst[5:-3], WireData(b'56')) - - def testGetOne(self): - """Get data one by one item""" - data = b'0123456789' - inst = WireData(data) - for i, byte in enumerate(data): - self.assertEqual(inst[i], byte) - for i in range(-1, len(data) * -1, -1): - self.assertEqual(inst[i], data[i]) - - def testEmptySlice(self): - """Test empty slice""" - data = b'0123456789' - inst = WireData(data) - for i, byte in enumerate(data): - self.assertEqual(inst[i:i], b'') - for i in range(-1, len(data) * -1, -1): - self.assertEqual(inst[i:i], b'') - self.assertEqual(inst[-3:-6], b'') - - def testSliceStartOutOfLowerBorder(self): - """Get data from out of lower border""" - inst = WireData(b'0123456789') - with self.assertRaises(FormError): - inst[-11:] # pylint: disable=pointless-statement - - def testSliceStopOutOfLowerBorder(self): - """Get data from out of lower border""" - inst = WireData(b'0123456789') - with self.assertRaises(FormError): - inst[:-11] # pylint: disable=pointless-statement - - def testSliceBothOutOfLowerBorder(self): - """Get data from out of lower border""" - inst = WireData(b'0123456789') - with self.assertRaises(FormError): - inst[-12:-11] # pylint: disable=pointless-statement - - def testSliceStartOutOfUpperBorder(self): - """Get data from out of upper border""" - inst = WireData(b'0123456789') - with self.assertRaises(FormError): - inst[11:] # pylint: disable=pointless-statement - - def testSliceStopOutOfUpperBorder(self): - """Get data from out of upper border""" - inst = WireData(b'0123456789') - with self.assertRaises(FormError): - inst[:11] # pylint: disable=pointless-statement - - def testSliceBothOutOfUpperBorder(self): - """Get data from out of lower border""" - inst = WireData(b'0123456789') - with self.assertRaises(FormError): - inst[10:20] # pylint: disable=pointless-statement - - def testGetOneOutOfLowerBorder(self): - """Get item outside of range""" - inst = WireData(b'0123456789') - with self.assertRaises(FormError): - inst[-11] # pylint: disable=pointless-statement - - def testGetOneOutOfUpperBorder(self): - """Get item outside of range""" - inst = WireData(b'0123456789') - with self.assertRaises(FormError): - inst[10] # pylint: disable=pointless-statement - - def testIteration(self): - bval = b'0123' - inst = WireData(bval) - l = list(inst) - self.assertEqual(l, [x for x in bval]) - - def testBadWrap(self): - def bad(): - w = maybe_wrap(123) - self.assertRaises(ValueError, bad)