]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Attempt to refactor per-opcode validation. 521/head
authorBrian Wellington <bwelling@xbill.org>
Fri, 26 Jun 2020 20:59:59 +0000 (13:59 -0700)
committerBrian Wellington <bwelling@xbill.org>
Fri, 26 Jun 2020 20:59:59 +0000 (13:59 -0700)
Instead of validating rrsets and sections after parsing them, check the
class/type for each record before parsing it.  This is more generic,
because it moves all of the update logic out of the common code.  It's
also more flexible, as it allows the update logic to specify that
meta-records are empty.

dns/message.py
dns/update.py

index fcb8372de2f864f02eebb041c4e6735d34a07963..dd278f66381e66ad79a8dd2536857cdb1f5fc5db 100644 (file)
@@ -610,13 +610,10 @@ class Message:
         # What the caller picked is fine.
         return value
 
-    def _validate_rrset(self, section, rrset):
-        if rrset.rdclass == dns.rdataclass.ANY or \
-           rrset.rdclass == dns.rdataclass.NONE:
+    def _parse_rr_header(self, reader, section, rdclass, rdtype):
+        if dns.rdataclass.is_metaclass(rdclass):
             raise dns.exception.FormError
-
-    def _finish_section(self, section):
-        pass
+        return (rdclass, rdtype, None, False)
 
 
 class QueryMessage(Message):
@@ -653,7 +650,6 @@ class _WireReader:
     question_only: Are we only reading the question?
     one_rr_per_rrset: Put each RR into its own RRset?
     ignore_trailing: Ignore trailing junk at end of request?
-    zone_rdclass: The class of the zone in messages which are
     DNS dynamic updates.
     """
 
@@ -666,13 +662,13 @@ class _WireReader:
         self.question_only = question_only
         self.one_rr_per_rrset = one_rr_per_rrset
         self.ignore_trailing = ignore_trailing
-        self.zone_rdclass = dns.rdataclass.IN
 
-    def _get_question(self, qcount):
+    def _get_question(self, section_number, qcount):
         """Read the next *qcount* records from the wire data and add them to
         the question section.
         """
 
+        section = self.message.sections[section_number]
         for i in range(qcount):
             (qname, used) = dns.name.from_wire(self.wire, self.current)
             if self.message.origin is not None:
@@ -682,12 +678,11 @@ class _WireReader:
                 struct.unpack('!HH',
                               self.wire[self.current:self.current + 4])
             self.current += 4
-            rrset = self.message.find_rrset(self.message.question, qname,
-                                            rdclass, rdtype, create=True,
-                                            force_unique=True)
-        for rrset in self.message.sections[MessageSection.QUESTION]:
-            self.message._validate_rrset(MessageSection.QUESTION, rrset)
-        self.message._finish_section(MessageSection.QUESTION)
+            (rdclass, rdtype, _, _) = \
+                self.message._parse_rr_header(self, section_number,
+                                              rdclass, rdtype)
+            rrset = self.message.find_rrset(section, qname, rdclass, rdtype,
+                                            create=True, force_unique=True)
 
     def _get_section(self, section_number, count):
         """Read the next I{count} records from the wire data and add them to
@@ -758,21 +753,17 @@ class _WireReader:
                                       self.message.first)
                 self.message.had_tsig = True
             else:
-                if ttl > 0x7fffffff:
-                    ttl = 0
-                if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE):
-                    deleting = rdclass
-                    rdclass = self.zone_rdclass
-                else:
-                    deleting = None
-                if deleting == dns.rdataclass.ANY or \
-                   (deleting == dns.rdataclass.NONE and
-                        section is self.message.answer):
-                    covers = dns.rdatatype.NONE
+                (rdclass, rdtype, deleting, empty) = \
+                    self.message._parse_rr_header(self, section_number,
+                                                  rdclass, rdtype)
+                if empty:
+                    if rdlen > 0:
+                        raise dns.exception.FormError
                     rd = None
+                    covers = dns.rdatatype.NONE
                 else:
-                    rd = dns.rdata.from_wire(rdclass, rdtype, self.wire,
-                                             self.current, rdlen,
+                    rd = dns.rdata.from_wire(rdclass, rdtype,
+                                             self.wire, self.current, rdlen,
                                              self.message.origin)
                     covers = rd.covers()
                 if self.message.xfr and rdtype == dns.rdatatype.SOA:
@@ -781,11 +772,10 @@ class _WireReader:
                                                 rdclass, rdtype, covers,
                                                 deleting, True, force_unique)
                 if rd is not None:
+                    if ttl > 0x7fffffff:
+                        ttl = 0
                     rrset.add(rd, ttl)
             self.current += rdlen
-        for rrset in self.message.sections[section_number]:
-            self.message._validate_rrset(section_number, rrset)
-        self.message._finish_section(section_number)
 
     def read(self):
         """Read a wire format DNS message and build a dns.message.Message
@@ -803,7 +793,7 @@ class _WireReader:
         self.initialize_message(self.message)
         self.one_rr_per_rrset = \
             self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
-        self._get_question(qcount)
+        self._get_question(MessageSection.QUESTION, qcount)
         if self.question_only:
             return
         self._get_section(MessageSection.ANSWER, ancount)
@@ -909,7 +899,6 @@ class _TextReader:
 
     tok: the tokenizer.
     message: The message object being built.
-    zone_rdclass: The class of the zone in messages which are
     DNS dynamic updates.
     last_name: The most recently read name when building a message object.
     one_rr_per_rrset: Put each RR into its own RRset?
@@ -923,7 +912,6 @@ class _TextReader:
         self.message = None
         self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec)
         self.last_name = None
-        self.zone_rdclass = dns.rdataclass.IN
         self.one_rr_per_rrset = one_rr_per_rrset
         self.origin = origin
         self.relativize = relativize
@@ -1003,9 +991,10 @@ class _TextReader:
             rdclass = dns.rdataclass.IN
         # Type
         rdtype = dns.rdatatype.from_text(token.value)
+        (rdclass, rdtype, _, _) = \
+            self.message._parse_rr_header(self, section_number, rdclass, rdtype)
         self.message.find_rrset(section, name, rdclass, rdtype, create=True,
                                 force_unique=True)
-        self.zone_rdclass = rdclass
         self.tok.get_eol()
 
     def _rr_line(self, section_number):
@@ -1014,7 +1003,6 @@ class _TextReader:
         """
 
         section = self.message.sections[section_number]
-        deleting = None
         # Name
         token = self.tok.get(want_leading=True)
         if not token.is_whitespace():
@@ -1041,16 +1029,17 @@ class _TextReader:
             token = self.tok.get()
             if not token.is_identifier():
                 raise dns.exception.SyntaxError
-            if rdclass == dns.rdataclass.ANY or rdclass == dns.rdataclass.NONE:
-                deleting = rdclass
-                rdclass = self.zone_rdclass
         except dns.exception.SyntaxError:
             raise dns.exception.SyntaxError
         except Exception:
             rdclass = dns.rdataclass.IN
         # Type
         rdtype = dns.rdatatype.from_text(token.value)
+        (rdclass, rdtype, deleting, empty) = \
+            self.message._parse_rr_header(self, section_number, rdclass, rdtype)
         token = self.tok.get()
+        if empty and not token.is_eol_or_eof():
+            raise dns.exception.SyntaxError
         if not token.is_eol_or_eof():
             self.tok.unget(token)
             rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
@@ -1124,10 +1113,6 @@ class _TextReader:
             line_method(section_number)
         if not self.message:
             self.message = self._make_message()
-        for i in range(4):
-            for rrset in self.message.sections[i]:
-                self.message._validate_rrset(i, rrset)
-            self.message._finish_section(i)
         return self.message
 
 
index 9615a73274815c91769be2c605a3128c7d3709bf..e21a283e3ed60465d1fa6d85bc8ece092d6795d7 100644 (file)
@@ -300,17 +300,24 @@ class UpdateMessage(dns.message.Message):
         # Updates are always one_rr_per_rrset
         return True
 
-    def _validate_rrset(self, section, rrset):
+    def _parse_rr_header(self, reader, section, rdclass, rdtype):
+        deleting = None
+        empty = False
         if section == UpdateSection.ZONE:
-            if rrset.rdtype != dns.rdatatype.SOA:
+            if dns.rdataclass.is_metaclass(rdclass) or \
+               rdtype != dns.rdatatype.SOA or \
+               getattr(reader, 'zone_rdclass', None):
                 raise dns.exception.FormError
-
-    def _finish_section(self, section):
-        if section == UpdateSection.ZONE and len(self.zone) != 1:
-            raise dns.exception.FormError
-        self.zone_rdclass = self.zone[0].rdclass
-        # We do NOT want to set origin here, as that would cause
-        # from_wire() relativization.
+            reader.zone_rdclass = rdclass
+        else:
+            if not getattr(reader, 'zone_rdclass', None):
+                raise dns.exception.FormError
+            if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE):
+                deleting = rdclass
+                rdclass = reader.zone_rdclass
+                empty = (deleting == dns.rdataclass.ANY or
+                         section == UpdateSection.PREREQ)
+        return (rdclass, rdtype, deleting, empty)
 
 # backwards compatibility
 Update = UpdateMessage