]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
draft continue_on_error
authorBob Halley <halley@dnspython.org>
Sun, 26 Sep 2021 22:39:03 +0000 (15:39 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 23 Oct 2021 23:38:34 +0000 (16:38 -0700)
dns/message.py

index 75faee289fbc9264ca5a9f1b43bccb00e33a9938..4c94d69a4498981122555b7fcef43c1eb0b7ad69 100644 (file)
@@ -132,6 +132,7 @@ class Message:
         self.origin = None
         self.tsig_ctx = None
         self.index = {}
+        self.errors = []
 
     @property
     def question(self):
@@ -877,7 +878,7 @@ class _WireReader:
 
     def __init__(self, wire, initialize_message, question_only=False,
                  one_rr_per_rrset=False, ignore_trailing=False,
-                 keyring=None, multi=False):
+                 keyring=None, multi=False, continue_on_error=False):
         self.parser = dns.wire.Parser(wire)
         self.message = None
         self.initialize_message = initialize_message
@@ -886,6 +887,9 @@ class _WireReader:
         self.ignore_trailing = ignore_trailing
         self.keyring = keyring
         self.multi = multi
+        self.continue_on_error = continue_on_error
+        self.last_good = 0
+        self.errors = []
 
     def _get_question(self, section_number, qcount):
         """Read the next *qcount* records from the wire data and add them to
@@ -901,12 +905,13 @@ class _WireReader:
                                               rdtype)
             self.message.find_rrset(section, qname, rdclass, rdtype,
                                     create=True, force_unique=True)
+            self.last_good = self.parser.current
 
     def _get_section(self, section_number, count):
         """Read the next I{count} records from the wire data and add them to
         the specified section.
 
-        section: the section of the message to which to add records
+        section_number: the section of the message to which to add records
         count: the number of records to read
         """
 
@@ -929,55 +934,75 @@ class _WireReader:
                 (rdclass, rdtype, deleting, empty) = \
                     self.message._parse_rr_header(section_number,
                                                   name, rdclass, rdtype)
-            if empty:
-                if rdlen > 0:
-                    raise dns.exception.FormError
-                rd = None
-                covers = dns.rdatatype.NONE
-            else:
-                with self.parser.restrict_to(rdlen):
-                    rd = dns.rdata.from_wire_parser(rdclass, rdtype,
-                                                    self.parser,
-                                                    self.message.origin)
-                covers = rd.covers()
-            if self.message.xfr and rdtype == dns.rdatatype.SOA:
-                force_unique = True
-            if rdtype == dns.rdatatype.OPT:
-                self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
-            elif rdtype == dns.rdatatype.TSIG:
-                if self.keyring is None:
-                    raise UnknownTSIGKey('got signed message without keyring')
-                if isinstance(self.keyring, dict):
-                    key = self.keyring.get(absolute_name)
-                    if isinstance(key, bytes):
-                        key = dns.tsig.Key(absolute_name, key, rd.algorithm)
-                elif callable(self.keyring):
-                    key = self.keyring(self.message, absolute_name)
+            try:
+                if empty:
+                    if rdlen > 0:
+                        raise dns.exception.FormError
+                    rd = None
+                    covers = dns.rdatatype.NONE
                 else:
-                    key = self.keyring
-                if key is None:
-                    raise UnknownTSIGKey("key '%s' unknown" % name)
-                self.message.keyring = key
-                self.message.tsig_ctx = \
-                    dns.tsig.validate(self.parser.wire,
-                                      key,
-                                      absolute_name,
-                                      rd,
-                                      int(time.time()),
-                                      self.message.request_mac,
-                                      rr_start,
-                                      self.message.tsig_ctx,
-                                      self.multi)
-                self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
-            else:
-                rrset = self.message.find_rrset(section, name,
-                                                rdclass, rdtype, covers,
-                                                deleting, True,
-                                                force_unique)
-                if rd is not None:
-                    if ttl > 0x7fffffff:
-                        ttl = 0
-                    rrset.add(rd, ttl)
+                    self.last_good = self.parser.current
+                    with self.parser.restrict_to(rdlen):
+                        rd = dns.rdata.from_wire_parser(rdclass, rdtype,
+                                                        self.parser,
+                                                        self.message.origin)
+                    covers = rd.covers()
+                if self.message.xfr and rdtype == dns.rdatatype.SOA:
+                    force_unique = True
+                if rdtype == dns.rdatatype.OPT:
+                    self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
+                elif rdtype == dns.rdatatype.TSIG:
+                    if self.keyring is None:
+                        raise UnknownTSIGKey('got signed message without '
+                                             'keyring')
+                    if isinstance(self.keyring, dict):
+                        key = self.keyring.get(absolute_name)
+                        if isinstance(key, bytes):
+                            key = dns.tsig.Key(absolute_name, key, rd.algorithm)
+                    elif callable(self.keyring):
+                        key = self.keyring(self.message, absolute_name)
+                    else:
+                        key = self.keyring
+                    if key is None:
+                        raise UnknownTSIGKey("key '%s' unknown" % name)
+                    self.message.keyring = key
+                    self.message.tsig_ctx = \
+                        dns.tsig.validate(self.parser.wire,
+                                        key,
+                                        absolute_name,
+                                        rd,
+                                        int(time.time()),
+                                        self.message.request_mac,
+                                        rr_start,
+                                        self.message.tsig_ctx,
+                                        self.multi)
+                    self.message.tsig = dns.rrset.from_rdata(absolute_name, 0,
+                                                             rd)
+                else:
+                    rrset = self.message.find_rrset(section, name,
+                                                    rdclass, rdtype, covers,
+                                                    deleting, True,
+                                                    force_unique)
+                    if rd is not None:
+                        if ttl > 0x7fffffff:
+                            ttl = 0
+                        rrset.add(rd, ttl)
+            except Exception as e:
+                if self.continue_on_error:
+                    self.errors.append((self.last_good, str(e), e))
+                    try:
+                        self.parser.seek(self.last_good + rdlen)
+                    except dns.exception.FormError:
+                        # seek was past the end
+                        self.parser.seek(self.parser.end)
+                        if i != count - 1:
+                            senum = MessageSection(section_number)
+                            self.errors.append((self.end, 'not enough RRs in '
+                                                f'section {senum}',
+                                                None))
+                        return
+                else:
+                    raise
 
     def read(self):
         """Read a wire format DNS message and build a dns.message.Message
@@ -993,23 +1018,31 @@ class _WireReader:
         self.initialize_message(self.message)
         self.one_rr_per_rrset = \
             self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
-        self._get_question(MessageSection.QUESTION, qcount)
-        if self.question_only:
-            return self.message
-        self._get_section(MessageSection.ANSWER, ancount)
-        self._get_section(MessageSection.AUTHORITY, aucount)
-        self._get_section(MessageSection.ADDITIONAL, adcount)
-        if not self.ignore_trailing and self.parser.remaining() != 0:
-            raise TrailingJunk
-        if self.multi and self.message.tsig_ctx and not self.message.had_tsig:
-            self.message.tsig_ctx.update(self.parser.wire)
+        try:
+            self._get_question(MessageSection.QUESTION, qcount)
+            if self.question_only:
+                return self.message
+            self._get_section(MessageSection.ANSWER, ancount)
+            self._get_section(MessageSection.AUTHORITY, aucount)
+            self._get_section(MessageSection.ADDITIONAL, adcount)
+            if not self.ignore_trailing and self.parser.remaining() != 0:
+                raise TrailingJunk
+            if self.multi and self.message.tsig_ctx and \
+                not self.message.had_tsig:
+                self.message.tsig_ctx.update(self.parser.wire)
+        except Exception as e:
+            if self.continue_on_error:
+                self.errors.append((self.last_good, str(e), e))
+            else:
+                raise
         return self.message
 
 
 def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
               tsig_ctx=None, multi=False,
               question_only=False, one_rr_per_rrset=False,
-              ignore_trailing=False, raise_on_truncation=False):
+              ignore_trailing=False, raise_on_truncation=False,
+              continue_on_error=False):
     """Convert a DNS wire format message into a message
     object.
 
@@ -1044,6 +1077,11 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
 
     *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
     the TC bit is set.
+    
+    *continue_on_error*, a ``bool``.  If ``True``, try to continue parsing
+    even if errors occur.  Erroneous rdata will be ignored, but records
+    of the errors will be added to the message ``errors`` field.  This option
+    is recommended only for DNS debugging.  The default is ``False``.
 
     Raises ``dns.message.ShortHeader`` if the message is less than 12 octets
     long.
@@ -1070,7 +1108,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
         message.tsig_ctx = tsig_ctx
 
     reader = _WireReader(wire, initialize_message, question_only,
-                         one_rr_per_rrset, ignore_trailing, keyring, multi)
+                         one_rr_per_rrset, ignore_trailing, keyring, multi,
+                         continue_on_error)
     try:
         m = reader.read()
     except dns.exception.FormError:
@@ -1083,6 +1122,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     # have to do this check here too.
     if m.flags & dns.flags.TC and raise_on_truncation:
         raise Truncated(message=m)
+    if continue_on_error:
+        m.errors = reader.errors
 
     return m