]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
new message class hierarchy and conversion of wire and text readers
authorBob Halley <halley@dnspython.org>
Fri, 26 Jun 2020 03:13:18 +0000 (20:13 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 26 Jun 2020 03:13:18 +0000 (20:13 -0700)
dns/message.py
dns/update.py

index 3d92c11fde95ad544996c942cefffc9d825b9e91..6272bbd4d5b93c1aed4ba4a476fba7882846e6dd 100644 (file)
@@ -78,6 +78,7 @@ class Truncated(dns.exception.DNSException):
         """
         return self.kwargs['message']
 
+
 class MessageSection(dns.enum.IntEnum):
     """Message sections"""
     QUESTION = 0
@@ -403,7 +404,8 @@ class Message:
         method.
 
         *origin*, a ``dns.name.Name`` or ``None``, the origin to be appended
-        to any relative names.
+        to any relative names.  If ``None``, and the message has an origin
+        attribute that is not ``None``, then it will be used.
 
         *max_size*, an ``int``, the maximum size of the wire format
         output; default is 0, which means "the message's request
@@ -420,6 +422,8 @@ class Message:
         Returns a ``bytes``.
         """
 
+        if origin is None and self.origin is not None:
+            origin = self.origin
         if max_size == 0:
             if self.request_payload != 0:
                 max_size = self.request_payload
@@ -601,6 +605,22 @@ class Message:
         self.flags &= 0x87FF
         self.flags |= dns.opcode.to_flags(opcode)
 
+    def _get_one_rr_per_rrset(self, value):
+        # What the caller picked is fine.
+        return value
+
+    def _validate_rrset(self, section, rrset):
+        if rrset.rdclass == dns.rdataclass.ANY or \
+           rrset.rdclass == dns.rdataclass.NONE:
+            raise dns.exception.FormError
+
+    def _finish_section(self, section):
+        pass
+
+
+class QueryMessage(Message):
+    pass
+
 
 class _WireReader:
 
@@ -611,22 +631,20 @@ class _WireReader:
     current: When building a message object from wire format, this
     variable contains the offset from the beginning of wire of the next octet
     to be read.
-    updating: Is the message a dynamic update?
     one_rr_per_rrset: Put each RR into its own RRset?
     ignore_trailing: Ignore trailing junk at end of request?
     zone_rdclass: The class of the zone in messages which are
     DNS dynamic updates.
     """
 
-    def __init__(self, wire, message, question_only=False,
-                 one_rr_per_rrset=False, ignore_trailing=False):
+    def __init__(self, wire, message, question_only, one_rr_per_rrset,
+                 ignore_trailing):
         self.wire = dns.wiredata.maybe_wrap(wire)
         self.message = message
         self.current = 0
-        self.updating = False
         self.zone_rdclass = dns.rdataclass.IN
         self.question_only = question_only
-        self.one_rr_per_rrset = one_rr_per_rrset
+        self.one_rr_per_rrset = message._get_one_rr_per_rrset(one_rr_per_rrset)
         self.ignore_trailing = ignore_trailing
 
     def _get_question(self, qcount):
@@ -634,9 +652,6 @@ class _WireReader:
         the question section.
         """
 
-        if self.updating and qcount > 1:
-            raise dns.exception.FormError
-
         for i in range(qcount):
             (qname, used) = dns.name.from_wire(self.wire, self.current)
             if self.message.origin is not None:
@@ -646,13 +661,13 @@ class _WireReader:
                 struct.unpack('!HH',
                               self.wire[self.current:self.current + 4])
             self.current += 4
-            self.message.find_rrset(self.message.question, qname,
-                                    rdclass, rdtype, create=True,
-                                    force_unique=True)
-            if self.updating:
-                self.zone_rdclass = rdclass
+            rrset = self.message.find_rrset(self.message.question, qname,
+                                            rdclass, rdtype, create=True,
+                                            force_unique=True)
+            self.message._validate_rrset(MessageSection.QUESTION, rrset)
+        self.message._finish_section(MessageSection.QUESTION)
 
-    def _get_section(self, section, count):
+    def _get_section(self, section_number, count):
         """Read the next I{count} records from the wire data and add them to
         the specified section.
 
@@ -660,10 +675,8 @@ class _WireReader:
         count: the number of records to read
         """
 
-        if self.updating or self.one_rr_per_rrset:
-            force_unique = True
-        else:
-            force_unique = False
+        section = self.message.sections[section_number]
+        force_unique = self.one_rr_per_rrset
         seen_opt = False
         for i in range(count):
             rr_start = self.current
@@ -725,8 +738,7 @@ class _WireReader:
             else:
                 if ttl > 0x7fffffff:
                     ttl = 0
-                if self.updating and \
-                   rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE):
+                if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE):
                     deleting = rdclass
                     rdclass = self.zone_rdclass
                 else:
@@ -748,7 +760,14 @@ class _WireReader:
                                                 deleting, True, force_unique)
                 if rd is not None:
                     rrset.add(rd, ttl)
+                # Note this validates the rrset every RR, but this is
+                # simpler than adding logic to remember a "last rrset"
+                # and validate it when the current rrset changes (or when
+                # we get to the end of the section.  If it becomes a
+                # performance problem, we can do that.
+                self.message._validate_rrset(section_number, rrset)
             self.current += rdlen
+        self.message._finish_section(section_number)
 
     def read(self):
         """Read a wire format DNS message and build a dns.message.Message
@@ -761,14 +780,12 @@ class _WireReader:
          aucount, adcount) = struct.unpack('!HHHHHH', self.wire[:12])
         self.current = 12
         self.message.original_id = self.message.id
-        if dns.opcode.is_update(self.message.flags):
-            self.updating = True
         self._get_question(qcount)
         if self.question_only:
             return
-        self._get_section(self.message.answer, ancount)
-        self._get_section(self.message.authority, aucount)
-        self._get_section(self.message.additional, adcount)
+        self._get_section(MessageSection.ANSWER, ancount)
+        self._get_section(MessageSection.AUTHORITY, aucount)
+        self._get_section(MessageSection.ADDITIONAL, adcount)
         if not self.ignore_trailing and self.current != l:
             raise TrailingJunk
         if self.message.multi and self.message.tsig_ctx and \
@@ -776,6 +793,31 @@ class _WireReader:
             self.message.tsig_ctx.update(self.wire)
 
 
+def _maybe_import_update():
+    # We avoid circular imports by doing this here.  We do it in another
+    # function as doing it in _message_factory_from_opcode() makes "dns"
+    # a local symbol, and the first line fails :)
+    import dns.update
+
+
+def _message_factory_from_opcode(opcode):
+    if opcode == dns.opcode.QUERY:
+        return QueryMessage
+    elif opcode == dns.opcode.UPDATE:
+        _maybe_import_update()
+        return dns.update.UpdateMessage
+    else:
+        return Message
+
+
+def _message_factory_from_wire(wire):
+    if len(wire) < 12:
+        raise ShortHeader
+    (flags,) = struct.unpack('!H', wire[2:4])
+    opcode = dns.opcode.from_flags(flags)
+    return _message_factory_from_opcode(opcode)
+
+
 def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
               tsig_ctx=None, multi=False, first=True,
               question_only=False, one_rr_per_rrset=False,
@@ -835,7 +877,7 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     Returns a ``dns.message.Message``.
     """
 
-    m = Message(id=0)
+    m = _message_factory_from_wire(wire)(id=0)
     m.keyring = keyring
     m.request_mac = request_mac
     m.xfr = xfr
@@ -867,20 +909,25 @@ class _TextReader:
 
     tok: the tokenizer.
     message: The message object being built.
-    updating: Is the message a dynamic update?
     zone_rdclass: The class of the zone in messages which are
     DNS dynamic updates.
     last_name: The most recently read name when building a message object.
     one_rr_per_rrset: Put each RR into its own RRset?
     """
 
-    def __init__(self, text, message, idna_codec, one_rr_per_rrset=False):
-        self.message = message
+    def __init__(self, text, idna_codec, one_rr_per_rrset=False):
+        self.message = None
         self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec)
         self.last_name = None
         self.zone_rdclass = dns.rdataclass.IN
-        self.updating = False
         self.one_rr_per_rrset = one_rr_per_rrset
+        self.id = None
+        self.edns = -1
+        self.ednsflags = 0
+        self.payload = None
+        self.rcode = None
+        self.opcode = dns.opcode.QUERY
+        self.flags = 0
 
     def _header_line(self, section):
         """Process one line from the text format header section."""
@@ -888,49 +935,46 @@ class _TextReader:
         token = self.tok.get()
         what = token.value
         if what == 'id':
-            self.message.id = self.tok.get_int()
+            self.id = self.tok.get_int()
         elif what == 'flags':
             while True:
                 token = self.tok.get()
                 if not token.is_identifier():
                     self.tok.unget(token)
                     break
-                self.message.flags = self.message.flags | \
-                    dns.flags.from_text(token.value)
-            if dns.opcode.is_update(self.message.flags):
-                self.updating = True
+                self.flags = self.flags | dns.flags.from_text(token.value)
         elif what == 'edns':
-            self.message.edns = self.tok.get_int()
-            self.message.ednsflags = self.message.ednsflags | \
-                (self.message.edns << 16)
+            self.edns = self.tok.get_int()
+            self.ednsflags = self.ednsflags | (self.edns << 16)
         elif what == 'eflags':
-            if self.message.edns < 0:
-                self.message.edns = 0
+            if self.edns < 0:
+                self.edns = 0
             while True:
                 token = self.tok.get()
                 if not token.is_identifier():
                     self.tok.unget(token)
                     break
-                self.message.ednsflags = self.message.ednsflags | \
+                self.ednsflags = self.ednsflags | \
                     dns.flags.edns_from_text(token.value)
         elif what == 'payload':
-            self.message.payload = self.tok.get_int()
-            if self.message.edns < 0:
-                self.message.edns = 0
+            self.payload = self.tok.get_int()
+            if self.edns < 0:
+                self.edns = 0
         elif what == 'opcode':
             text = self.tok.get_string()
-            self.message.flags = self.message.flags | \
-                dns.opcode.to_flags(dns.opcode.from_text(text))
+            self.opcode = dns.opcode.from_text(text)
+            self.flags = self.flags | dns.opcode.to_flags(self.opcode)
         elif what == 'rcode':
             text = self.tok.get_string()
-            self.message.set_rcode(dns.rcode.from_text(text))
+            self.rcode = dns.rcode.from_text(text)
         else:
             raise UnknownHeaderField
         self.tok.get_eol()
 
-    def _question_line(self, section):
+    def _question_line(self, section_number):
         """Process one line from the text format question section."""
 
+        section = self.message.sections[section_number]
         token = self.tok.get(want_leading=True)
         if not token.is_whitespace():
             self.last_name = self.tok.as_name(token, None)
@@ -950,18 +994,19 @@ class _TextReader:
             rdclass = dns.rdataclass.IN
         # Type
         rdtype = dns.rdatatype.from_text(token.value)
-        self.message.find_rrset(self.message.question, name,
-                                rdclass, rdtype, create=True,
-                                force_unique=True)
-        if self.updating:
-            self.zone_rdclass = rdclass
+        rrset = self.message.find_rrset(self.message.question, name,
+                                        rdclass, rdtype, create=True,
+                                        force_unique=True)
+        self.message._validate_rrset(section_number, rrset)
+        self.zone_rdclass = rdclass
         self.tok.get_eol()
 
-    def _rr_line(self, section):
+    def _rr_line(self, section_number):
         """Process one line from the text format answer, authority, or
         additional data sections.
         """
 
+        section = self.message.sections[section_number]
         deleting = None
         # Name
         token = self.tok.get(want_leading=True)
@@ -1004,19 +1049,36 @@ class _TextReader:
         else:
             rd = None
             covers = dns.rdatatype.NONE
-        force_unique = self.updating or self.one_rr_per_rrset
         rrset = self.message.find_rrset(section, name,
                                         rdclass, rdtype, covers,
-                                        deleting, True, force_unique)
+                                        deleting, True, self.one_rr_per_rrset)
         if rd is not None:
             rrset.add(rd, ttl)
+        self.message._validate_rrset(section_number, rrset)
+
+    def _maybe_instantiate_message(self):
+        if self.message is None:
+            # Time to instantiate the message!
+            factory = _message_factory_from_opcode(self.opcode)
+            self.message = factory(id=self.id)
+            self.message.flags = self.flags
+            if self.edns >= 0:
+                self.message.edns = self.edns
+            if self.ednsflags:
+                self.message.ednsflags = self.ednsflags
+            if self.payload:
+                self.message.payload = self.payload
+            if self.rcode:
+                self.message.set_rcode(self.rcode)
+            self.one_rr_per_rrset = \
+                self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
 
     def read(self):
         """Read a text format DNS message and build a dns.message.Message
         object."""
 
         line_method = self._header_line
-        section = None
+        section_number = None
         while 1:
             token = self.tok.get(True, True)
             if token.is_eol_or_eof():
@@ -1025,22 +1087,24 @@ class _TextReader:
                 u = token.value.upper()
                 if u == 'HEADER':
                     line_method = self._header_line
-                elif u == 'QUESTION' or u == 'ZONE':
-                    line_method = self._question_line
-                    section = self.message.question
-                elif u == 'ANSWER' or u == 'PREREQ':
-                    line_method = self._rr_line
-                    section = self.message.answer
-                elif u == 'AUTHORITY' or u == 'UPDATE':
-                    line_method = self._rr_line
-                    section = self.message.authority
-                elif u == 'ADDITIONAL':
-                    line_method = self._rr_line
-                    section = self.message.additional
+                elif u in {'QUESTION', 'ANSWER', 'AUTHORITY', 'ADDITIONAL',
+                           'ZONE', 'PREREQ', 'UPDATE'}:
+                    # It's ugly, but we have to do the check above because
+                    # if the token is JUST a comment, we want to ignore it,
+                    # and not prehaps prematurely instantiate the message!
+                    self._maybe_instantiate_message()
+                    section_number = self.message._section_enum.make(u)
+                    if section_number == 0:
+                        line_method = self._question_line
+                    else:
+                        line_method = self._rr_line
                 self.tok.get_eol()
                 continue
             self.tok.unget(token)
-            line_method(section)
+            line_method(section_number)
+        for i in range(4):
+            self.message._finish_section(i)
+        return self.message
 
 
 def from_text(text, idna_codec=None, one_rr_per_rrset=False):
@@ -1070,12 +1134,8 @@ def from_text(text, idna_codec=None, one_rr_per_rrset=False):
     # since it's an implementation detail.  The official file
     # interface is from_file().
 
-    m = Message()
-
-    reader = _TextReader(text, m, idna_codec, one_rr_per_rrset)
-    reader.read()
-
-    return m
+    reader = _TextReader(text, idna_codec, one_rr_per_rrset)
+    return reader.read()
 
 
 def from_file(f, idna_codec=None, one_rr_per_rrset=False):
index 2aa8ca56db231e999e8e98c141c778e1cb9dedb0..d57d0cd4e511427c1f78a7a8c01f8e99184f1465 100644 (file)
@@ -41,19 +41,22 @@ class UpdateSection(dns.enum.IntEnum):
 globals().update(UpdateSection.__members__)
 
 
-class Update(dns.message.Message):
+class UpdateMessage(dns.message.Message):
 
     _section_enum = UpdateSection
 
-    def __init__(self, zone, rdclass=dns.rdataclass.IN, keyring=None,
-                 keyname=None, keyalgorithm=dns.tsig.default_algorithm):
+    def __init__(self, zone=None, rdclass=dns.rdataclass.IN, keyring=None,
+                 keyname=None, keyalgorithm=dns.tsig.default_algorithm,
+                 id=None):
         """Initialize a new DNS Update object.
 
         See the documentation of the Message class for a complete
         description of the keyring dictionary.
 
-        *zone*, a ``dns.name.Name`` or ``str``, the zone which is being
-        updated.
+        *zone*, a ``dns.name.Name``, ``str``, or ``None``, the zone
+        which is being updated.  ``None`` should only be used by dnspython's
+        message constructors, as a zone is required for the convenience
+        methods like ``add()``, ``replace()``, etc.
 
         *rdclass*, an ``int`` or ``str``, the class of the zone.
 
@@ -68,16 +71,18 @@ class Update(dns.message.Message):
         to use; defaults to ``None``. The key must be defined in the keyring.
 
         *keyalgorithm*, a ``dns.name.Name``, the TSIG algorithm to use.
+
         """
-        super().__init__()
+        super().__init__(id=id)
         self.flags |= dns.opcode.to_flags(dns.opcode.UPDATE)
         if isinstance(zone, str):
             zone = dns.name.from_text(zone)
         self.origin = zone
         rdclass = dns.rdataclass.RdataClass.make(rdclass)
         self.zone_rdclass = rdclass
-        self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA,
-                        create=True, force_unique=True)
+        if self.origin:
+            self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA,
+                            create=True, force_unique=True)
         if keyring is not None:
             self.use_tsig(keyring, keyname, algorithm=keyalgorithm)
 
@@ -288,23 +293,21 @@ class Update(dns.message.Message):
                             dns.rdatatype.NONE, None,
                             True, True)
 
-    def to_wire(self, origin=None, max_size=65535):
-        """Return a string containing the update in DNS compressed wire
-        format.
-
-        *origin*, a ``dns.name.Name`` or ``None``, the origin to be
-        appended to any relative names.  If *origin* is ``None``, then
-        the origin of the ``dns.update.Update`` message object is used
-        (i.e. the *zone* parameter passed when the Update object was
-        created).
+    def _get_one_rr_per_rrset(self, value):
+        # Updates are always one_rr_per_rrset
+        return True
 
-        *max_size*, an ``int``, the maximum size of the wire format
-        output; default is 0, which means "the message's request
-        payload, if nonzero, or 65535".
+    def _validate_rrset(self, section, rrset):
+        if section == UpdateSection.ZONE:
+            if rrset.rdtype != dns.rdatatype.SOA:
+                raise dns.exception.FormError
 
-        Returns a ``bytes``.
-        """
+    def _finish_section(self, section):
+        if section == UpdateSection.ZONE and len(self.zone) != 1:
+            raise dns.exception.FormError
+        self.zone_rdclass = self.zone[0].rdclass
+        # We do NOT want to set origin here, as that would cause
+        # from_wire() relativization.
 
-        if origin is None:
-            origin = self.origin
-        return super().to_wire(origin, max_size)
+# backwards compatibility
+Update = UpdateMessage