]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Simplify; add a MessageError class. 694/head
authorBob Halley <halley@dnspython.org>
Sun, 24 Oct 2021 11:53:23 +0000 (04:53 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 24 Oct 2021 11:53:23 +0000 (04:53 -0700)
dns/message.py
tests/test_message.py

index e5a3c59b18cda8b79d860d20ea068da41ca84183..8e6f5cc4bb2530c84205633ce74fd92376d4ae9a 100644 (file)
@@ -108,6 +108,12 @@ class MessageSection(dns.enum.IntEnum):
         return 3
 
 
+class MessageError:
+    def __init__(self, exception, offset):
+        self.exception = exception
+        self.offset = offset
+
+
 DEFAULT_EDNS_PAYLOAD = 1232
 MAX_CHAIN = 16
 
@@ -874,6 +880,9 @@ class _WireReader:
     ignore_trailing: Ignore trailing junk at end of request?
     multi: Is this message part of a multi-message sequence?
     DNS dynamic updates.
+    continue_on_error: try to extract as much information as possible from
+    the message, accumulating MessageErrors in the *errors* attribute instead of
+    raising them.
     """
 
     def __init__(self, wire, initialize_message, question_only=False,
@@ -905,6 +914,9 @@ class _WireReader:
             self.message.find_rrset(section, qname, rdclass, rdtype,
                                     create=True, force_unique=True)
 
+    def _add_error(self, e):
+        self.errors.append(MessageError(e, self.parser.current))
+
     def _get_section(self, section_number, count):
         """Read the next I{count} records from the wire data and add them to
         the specified section.
@@ -987,19 +999,8 @@ class _WireReader:
                         rrset.add(rd, ttl)
             except Exception as e:
                 if self.continue_on_error:
-                    self.errors.append((self.parser.current, str(e), e))
-                    try:
-                        self.parser.seek(rdata_start + rdlen)
-                    except dns.exception.FormError:
-                        # seek was past the end
-                        self.parser.seek(self.parser.end)
-                    if self.parser.current == self.parser.end and \
-                        i != count - 1:
-                        senum = MessageSection(section_number)
-                        self.errors.append((self.parser.end, 'not enough RRs in '
-                                            f'section {senum:d}',
-                                            None))
-                        return
+                    self._add_error(e)
+                    self.parser.seek(rdata_start + rdlen)
                 else:
                     raise
 
@@ -1031,7 +1032,7 @@ class _WireReader:
                 self.message.tsig_ctx.update(self.parser.wire)
         except Exception as e:
             if self.continue_on_error:
-                self.errors.append((self.last_good, str(e), e))
+                self._add_error(e)
             else:
                 raise
         return self.message
@@ -1042,57 +1043,57 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
               question_only=False, one_rr_per_rrset=False,
               ignore_trailing=False, raise_on_truncation=False,
               continue_on_error=False):
-    """Convert a DNS wire format message into a message
-    object.
+    """Convert a DNS wire format message into a message object.
 
-    *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use
-    if the message is signed.
+    *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the
+    message is signed.
 
-    *request_mac*, a ``bytes``.  If the message is a response to a
-    TSIG-signed request, *request_mac* should be set to the MAC of
-    that request.
+    *request_mac*, a ``bytes``.  If the message is a response to a TSIG-signed
+    request, *request_mac* should be set to the MAC of that request.
 
-    *xfr*, a ``bool``, should be set to ``True`` if this message is part of
-    zone transfer.
+    *xfr*, a ``bool``, should be set to ``True`` if this message is part of a
+    zone transfer.
 
-    *origin*, a ``dns.name.Name`` or ``None``.  If the message is part
-    of a zone transfer, *origin* should be the origin name of the
-    zone.  If not ``None``, names will be relativized to the origin.
+    *origin*, a ``dns.name.Name`` or ``None``.  If the message is part of a zone
+    transfer, *origin* should be the origin name of the zone.  If not ``None``,
+    names will be relativized to the origin.
 
     *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
     ongoing TSIG context, used when validating zone transfers.
 
-    *multi*, a ``bool``, should be set to ``True`` if this message is
-    part of a multiple message sequence.
+    *multi*, a ``bool``, should be set to ``True`` if this message is part of a
+    multiple message sequence.
+
+    *question_only*, a ``bool``.  If ``True``, read only up to the end of the
+    question section.
 
-    *question_only*, a ``bool``.  If ``True``, read only up to
-    the end of the question section.
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
 
-    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its
-    own RRset.
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing junk at end of
+    the message.
 
-    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
-    junk at end of the message.
+    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if the
+    TC bit is set.
 
-    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
-    the TC bit is set.
-    
-    *continue_on_error*, a ``bool``.  If ``True``, try to continue parsing
-    even if errors occur.  Erroneous rdata will be ignored, but records
-    of the errors will be added to the message ``errors`` field.  This option
-    is recommended only for DNS debugging.  The default is ``False``.
+    *continue_on_error*, a ``bool``.  If ``True``, try to continue parsing even
+    if errors occur.  Erroneous rdata will be ignored.  Errors will be
+    accumulated as a list of MessageError objects in the message's ``errors``
+    attribute.  This option is recommended only for DNS analysis tools, or for
+    use in a server as part of an error handling path.  The default is
+    ``False``.
 
     Raises ``dns.message.ShortHeader`` if the message is less than 12 octets
     long.
 
-    Raises ``dns.message.TrailingJunk`` if there were octets in the message
-    past the end of the proper DNS message, and *ignore_trailing* is ``False``.
+    Raises ``dns.message.TrailingJunk`` if there were octets in the message past
+    the end of the proper DNS message, and *ignore_trailing* is ``False``.
 
-    Raises ``dns.message.BadEDNS`` if an OPT record was in the
-    wrong section, or occurred more than once.
+    Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or
+    occurred more than once.
 
-    Raises ``dns.message.BadTSIG`` if a TSIG record was not the last
-    record of the additional data section.
+    Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of
+    the additional data section.
 
     Raises ``dns.message.Truncated`` if the TC flag is set and
     *raise_on_truncation* is ``True``.
index 1ea87aa0e8a58097eaef093ba072d98e8385e9ab..190385af932d0edf3b7bc16a5404df5aeea319d1 100644 (file)
@@ -707,9 +707,12 @@ www.dnspython.org. 300 IN AAAA ::1
         bad_wire = bad_wire.replace(b'\x00\x00\x00\x05', b'\xff' * 4)
         m = dns.message.from_wire(bad_wire, continue_on_error=True)
         self.assertEqual(len(m.errors), 3)
-        self.assertEqual(m.errors[0][:2], (69, 'value too large'))
-        self.assertEqual(m.errors[1][:2], (97, 'IPv6 addresses are 16 bytes long'))
-        self.assertEqual(m.errors[2][:2], (97, 'not enough RRs in section 1'))
+        print(m.errors)
+        self.assertEqual(str(m.errors[0].exception), 'value too large')
+        self.assertEqual(str(m.errors[1].exception),
+                         'IPv6 addresses are 16 bytes long')
+        self.assertEqual(str(m.errors[2].exception),
+                         'DNS message is malformed.')
         expected_message = dns.message.from_text(
 """id 1234
 opcode QUERY