From: Bob Halley Date: Sun, 26 Sep 2021 22:39:03 +0000 (-0700) Subject: draft continue_on_error X-Git-Tag: v2.2.0rc1~47^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=95677ee0b1dcefd0f1d847f55cbe05a19fe72895;p=thirdparty%2Fdnspython.git draft continue_on_error --- diff --git a/dns/message.py b/dns/message.py index 75faee28..4c94d69a 100644 --- a/dns/message.py +++ b/dns/message.py @@ -132,6 +132,7 @@ class Message: self.origin = None self.tsig_ctx = None self.index = {} + self.errors = [] @property def question(self): @@ -877,7 +878,7 @@ class _WireReader: def __init__(self, wire, initialize_message, question_only=False, one_rr_per_rrset=False, ignore_trailing=False, - keyring=None, multi=False): + keyring=None, multi=False, continue_on_error=False): self.parser = dns.wire.Parser(wire) self.message = None self.initialize_message = initialize_message @@ -886,6 +887,9 @@ class _WireReader: self.ignore_trailing = ignore_trailing self.keyring = keyring self.multi = multi + self.continue_on_error = continue_on_error + self.last_good = 0 + self.errors = [] def _get_question(self, section_number, qcount): """Read the next *qcount* records from the wire data and add them to @@ -901,12 +905,13 @@ class _WireReader: rdtype) self.message.find_rrset(section, qname, rdclass, rdtype, create=True, force_unique=True) + self.last_good = self.parser.current def _get_section(self, section_number, count): """Read the next I{count} records from the wire data and add them to the specified section. - section: the section of the message to which to add records + section_number: the section of the message to which to add records count: the number of records to read """ @@ -929,55 +934,75 @@ class _WireReader: (rdclass, rdtype, deleting, empty) = \ self.message._parse_rr_header(section_number, name, rdclass, rdtype) - if empty: - if rdlen > 0: - raise dns.exception.FormError - rd = None - covers = dns.rdatatype.NONE - else: - with self.parser.restrict_to(rdlen): - rd = dns.rdata.from_wire_parser(rdclass, rdtype, - self.parser, - self.message.origin) - covers = rd.covers() - if self.message.xfr and rdtype == dns.rdatatype.SOA: - force_unique = True - if rdtype == dns.rdatatype.OPT: - self.message.opt = dns.rrset.from_rdata(name, ttl, rd) - elif rdtype == dns.rdatatype.TSIG: - if self.keyring is None: - raise UnknownTSIGKey('got signed message without keyring') - if isinstance(self.keyring, dict): - key = self.keyring.get(absolute_name) - if isinstance(key, bytes): - key = dns.tsig.Key(absolute_name, key, rd.algorithm) - elif callable(self.keyring): - key = self.keyring(self.message, absolute_name) + try: + if empty: + if rdlen > 0: + raise dns.exception.FormError + rd = None + covers = dns.rdatatype.NONE else: - key = self.keyring - if key is None: - raise UnknownTSIGKey("key '%s' unknown" % name) - self.message.keyring = key - self.message.tsig_ctx = \ - dns.tsig.validate(self.parser.wire, - key, - absolute_name, - rd, - int(time.time()), - self.message.request_mac, - rr_start, - self.message.tsig_ctx, - self.multi) - self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) - else: - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, - force_unique) - if rd is not None: - if ttl > 0x7fffffff: - ttl = 0 - rrset.add(rd, ttl) + self.last_good = self.parser.current + with self.parser.restrict_to(rdlen): + rd = dns.rdata.from_wire_parser(rdclass, rdtype, + self.parser, + self.message.origin) + covers = rd.covers() + if self.message.xfr and rdtype == dns.rdatatype.SOA: + force_unique = True + if rdtype == dns.rdatatype.OPT: + self.message.opt = dns.rrset.from_rdata(name, ttl, rd) + elif rdtype == dns.rdatatype.TSIG: + if self.keyring is None: + raise UnknownTSIGKey('got signed message without ' + 'keyring') + if isinstance(self.keyring, dict): + key = self.keyring.get(absolute_name) + if isinstance(key, bytes): + key = dns.tsig.Key(absolute_name, key, rd.algorithm) + elif callable(self.keyring): + key = self.keyring(self.message, absolute_name) + else: + key = self.keyring + if key is None: + raise UnknownTSIGKey("key '%s' unknown" % name) + self.message.keyring = key + self.message.tsig_ctx = \ + dns.tsig.validate(self.parser.wire, + key, + absolute_name, + rd, + int(time.time()), + self.message.request_mac, + rr_start, + self.message.tsig_ctx, + self.multi) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, + rd) + else: + rrset = self.message.find_rrset(section, name, + rdclass, rdtype, covers, + deleting, True, + force_unique) + if rd is not None: + if ttl > 0x7fffffff: + ttl = 0 + rrset.add(rd, ttl) + except Exception as e: + if self.continue_on_error: + self.errors.append((self.last_good, str(e), e)) + try: + self.parser.seek(self.last_good + rdlen) + except dns.exception.FormError: + # seek was past the end + self.parser.seek(self.parser.end) + if i != count - 1: + senum = MessageSection(section_number) + self.errors.append((self.end, 'not enough RRs in ' + f'section {senum}', + None)) + return + else: + raise def read(self): """Read a wire format DNS message and build a dns.message.Message @@ -993,23 +1018,31 @@ class _WireReader: self.initialize_message(self.message) self.one_rr_per_rrset = \ self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) - self._get_question(MessageSection.QUESTION, qcount) - if self.question_only: - return self.message - self._get_section(MessageSection.ANSWER, ancount) - self._get_section(MessageSection.AUTHORITY, aucount) - self._get_section(MessageSection.ADDITIONAL, adcount) - if not self.ignore_trailing and self.parser.remaining() != 0: - raise TrailingJunk - if self.multi and self.message.tsig_ctx and not self.message.had_tsig: - self.message.tsig_ctx.update(self.parser.wire) + try: + self._get_question(MessageSection.QUESTION, qcount) + if self.question_only: + return self.message + self._get_section(MessageSection.ANSWER, ancount) + self._get_section(MessageSection.AUTHORITY, aucount) + self._get_section(MessageSection.ADDITIONAL, adcount) + if not self.ignore_trailing and self.parser.remaining() != 0: + raise TrailingJunk + if self.multi and self.message.tsig_ctx and \ + not self.message.had_tsig: + self.message.tsig_ctx.update(self.parser.wire) + except Exception as e: + if self.continue_on_error: + self.errors.append((self.last_good, str(e), e)) + else: + raise return self.message def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, tsig_ctx=None, multi=False, question_only=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False): + ignore_trailing=False, raise_on_truncation=False, + continue_on_error=False): """Convert a DNS wire format message into a message object. @@ -1044,6 +1077,11 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the TC bit is set. + + *continue_on_error*, a ``bool``. If ``True``, try to continue parsing + even if errors occur. Erroneous rdata will be ignored, but records + of the errors will be added to the message ``errors`` field. This option + is recommended only for DNS debugging. The default is ``False``. Raises ``dns.message.ShortHeader`` if the message is less than 12 octets long. @@ -1070,7 +1108,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, message.tsig_ctx = tsig_ctx reader = _WireReader(wire, initialize_message, question_only, - one_rr_per_rrset, ignore_trailing, keyring, multi) + one_rr_per_rrset, ignore_trailing, keyring, multi, + continue_on_error) try: m = reader.read() except dns.exception.FormError: @@ -1083,6 +1122,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, # have to do this check here too. if m.flags & dns.flags.TC and raise_on_truncation: raise Truncated(message=m) + if continue_on_error: + m.errors = reader.errors return m