]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Refactor OPT handling code into OPT record class. (#520)
authorBrian Wellington <bwelling@xbill.org>
Sat, 27 Jun 2020 01:55:53 +0000 (18:55 -0700)
committerGitHub <noreply@github.com>
Sat, 27 Jun 2020 01:55:53 +0000 (18:55 -0700)
* Create an OPT record class.

* Move OPT logic to one place.

* Store the OPT record on the message object.

This also adds a Renderer.add_rdata() method.

* Add Rdataset.rdata_to_wire() helper.

* Fix conflicts; simplify.

* Fix typo.

* style

* Add a trivial to_text so that repr() works.

* Add _parse_special_rr_header

* More OPT checking.

Pass the name to _parse_rr_header and _parse_special_rr_header, and
check that the OPT record has the root name.

dns/message.py
dns/rdtypes/ANY/OPT.py [new file with mode: 0644]
dns/renderer.py
dns/update.py

index 03d3f4ca0584b43eff8a4dcbdc229eed9326baba..0f4a40125a68e5d8720113bc57105ca10c780b2e 100644 (file)
@@ -37,6 +37,7 @@ import dns.rrset
 import dns.renderer
 import dns.tsig
 import dns.wiredata
+import dns.rdtypes.ANY.OPT
 
 
 class ShortHeader(dns.exception.FormError):
@@ -105,10 +106,7 @@ class Message:
             self.id = id
         self.flags = 0
         self.sections = [[], [], [], []]
-        self.edns = -1
-        self.ednsflags = 0
-        self.payload = 0
-        self.options = []
+        self.opt = None
         self.request_payload = 0
         self.keyring = None
         self.keyname = None
@@ -440,8 +438,8 @@ class Message:
             r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
         for rrset in self.authority:
             r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
-        if self.edns >= 0:
-            r.add_edns(self.edns, self.ednsflags, self.payload, self.options)
+        if self.opt is not None:
+            r.add_rrset(dns.renderer.ADDITIONAL, self.opt)
         for rrset in self.additional:
             r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
         r.write_header()
@@ -508,6 +506,12 @@ class Message:
         self.tsig_error = tsig_error
         self.other_data = other_data
 
+    @staticmethod
+    def _make_opt(flags=0, payload=1280, options=None):
+        opt = dns.rdtypes.ANY.OPT.OPT(payload, dns.rdatatype.OPT,
+                                      options or ())
+        return dns.rrset.from_rdata(dns.name.root, int(flags), opt)
+
     def use_edns(self, edns=0, ednsflags=0, payload=1280, request_payload=None,
                  options=None):
         """Configure EDNS behavior.
@@ -548,12 +552,47 @@ class Message:
             ednsflags |= (edns << 16)
             if options is None:
                 options = []
-        self.edns = edns
-        self.ednsflags = ednsflags
-        self.payload = payload
-        self.options = options
+        if edns >= 0:
+            self.opt = self._make_opt(ednsflags, payload, options)
+        else:
+            self.opt = None
         self.request_payload = request_payload
 
+    @property
+    def edns(self):
+        if self.opt:
+            return (self.ednsflags & 0xff0000) >> 16
+        else:
+            return -1
+
+    @property
+    def ednsflags(self):
+        if self.opt:
+            return self.opt.ttl
+        else:
+            return 0
+
+    @ednsflags.setter
+    def ednsflags(self, v):
+        if self.opt:
+            self.opt.ttl = v
+        else:
+            self.opt = self._make_opt(0, v)
+
+    @property
+    def payload(self):
+        if self.opt:
+            return self.opt[0].payload
+        else:
+            return 0
+
+    @property
+    def options(self):
+        if self.opt:
+            return self.opt[0].options
+        else:
+            return ()
+
     def want_dnssec(self, wanted=True):
         """Enable or disable 'DNSSEC desired' flag in requests.
 
@@ -564,10 +603,8 @@ class Message:
         """
 
         if wanted:
-            if self.edns < 0:
-                self.use_edns()
             self.ednsflags |= dns.flags.DO
-        elif self.edns >= 0:
+        elif self.opt:
             self.ednsflags &= ~dns.flags.DO
 
     def rcode(self):
@@ -609,11 +646,18 @@ class Message:
         # What the caller picked is fine.
         return value
 
-    def _parse_rr_header(self, section, rdclass, rdtype):
+    def _parse_rr_header(self, section, name, rdclass, rdtype):
         if dns.rdataclass.is_metaclass(rdclass):
             raise dns.exception.FormError
         return (rdclass, rdtype, None, False)
 
+    def _parse_special_rr_header(self, section, name, rdclass, rdtype):
+        if rdtype == dns.rdatatype.OPT:
+            if section != MessageSection.ADDITIONAL or self.opt or \
+               name != dns.name.root:
+                raise BadEDNS
+        return (rdclass, rdtype, None, False)
+
 
 class QueryMessage(Message):
     pass
@@ -678,7 +722,8 @@ class _WireReader:
                               self.wire[self.current:self.current + 4])
             self.current += 4
             (rdclass, rdtype, _, _) = \
-                self.message._parse_rr_header(section_number, rdclass, rdtype)
+                self.message._parse_rr_header(section_number, qname, rdclass,
+                                              rdtype)
             self.message.find_rrset(section, qname, rdclass, rdtype,
                                     create=True, force_unique=True)
 
@@ -692,7 +737,6 @@ class _WireReader:
 
         section = self.message.sections[section_number]
         force_unique = self.one_rr_per_rrset
-        seen_opt = False
         for i in range(count):
             rr_start = self.current
             (name, used) = dns.name.from_wire(self.wire, self.current)
@@ -704,27 +748,7 @@ class _WireReader:
                 struct.unpack('!HHIH',
                               self.wire[self.current:self.current + 10])
             self.current += 10
-            if rdtype == dns.rdatatype.OPT:
-                if section is not self.message.additional or seen_opt:
-                    raise BadEDNS
-                self.message.payload = rdclass
-                self.message.ednsflags = ttl
-                self.message.edns = (ttl & 0xff0000) >> 16
-                self.message.options = []
-                current = self.current
-                optslen = rdlen
-                while optslen > 0:
-                    (otype, olen) = \
-                        struct.unpack('!HH',
-                                      self.wire[current:current + 4])
-                    current = current + 4
-                    opt = dns.edns.option_from_wire(
-                        otype, self.wire, current, olen)
-                    self.message.options.append(opt)
-                    current = current + olen
-                    optslen = optslen - 4 - olen
-                seen_opt = True
-            elif rdtype == dns.rdatatype.TSIG:
+            if rdtype == dns.rdatatype.TSIG:
                 if not (section is self.message.additional and
                         i == (count - 1)):
                     raise BadTSIG
@@ -751,9 +775,15 @@ class _WireReader:
                                       self.message.first)
                 self.message.had_tsig = True
             else:
-                (rdclass, rdtype, deleting, empty) = \
-                    self.message._parse_rr_header(section_number,
-                                                  rdclass, rdtype)
+                if rdtype == dns.rdatatype.OPT:
+                    (rdclass, rdtype, deleting, empty) = \
+                        self.message._parse_special_rr_header(section_number,
+                                                              name,
+                                                              rdclass, rdtype)
+                else:
+                    (rdclass, rdtype, deleting, empty) = \
+                        self.message._parse_rr_header(section_number,
+                                                      name, rdclass, rdtype)
                 if empty:
                     if rdlen > 0:
                         raise dns.exception.FormError
@@ -766,13 +796,17 @@ class _WireReader:
                     covers = rd.covers()
                 if self.message.xfr and rdtype == dns.rdatatype.SOA:
                     force_unique = True
-                rrset = self.message.find_rrset(section, name,
-                                                rdclass, rdtype, covers,
-                                                deleting, True, force_unique)
-                if rd is not None:
-                    if ttl > 0x7fffffff:
-                        ttl = 0
-                    rrset.add(rd, ttl)
+                if rdtype == dns.rdatatype.OPT:
+                    self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
+                else:
+                    rrset = self.message.find_rrset(section, name,
+                                                    rdclass, rdtype, covers,
+                                                    deleting, True,
+                                                    force_unique)
+                    if rd is not None:
+                        if ttl > 0x7fffffff:
+                            ttl = 0
+                        rrset.add(rd, ttl)
             self.current += rdlen
 
     def read(self):
@@ -990,7 +1024,7 @@ class _TextReader:
         # Type
         rdtype = dns.rdatatype.from_text(token.value)
         (rdclass, rdtype, _, _) = \
-            self.message._parse_rr_header(section_number, rdclass, rdtype)
+            self.message._parse_rr_header(section_number, name, rdclass, rdtype)
         self.message.find_rrset(section, name, rdclass, rdtype, create=True,
                                 force_unique=True)
         self.tok.get_eol()
@@ -1034,7 +1068,7 @@ class _TextReader:
         # Type
         rdtype = dns.rdatatype.from_text(token.value)
         (rdclass, rdtype, deleting, empty) = \
-            self.message._parse_rr_header(section_number, rdclass, rdtype)
+            self.message._parse_rr_header(section_number, name, rdclass, rdtype)
         token = self.tok.get()
         if empty and not token.is_eol_or_eof():
             raise dns.exception.SyntaxError
@@ -1058,11 +1092,7 @@ class _TextReader:
         message = factory(id=self.id)
         message.flags = self.flags
         if self.edns >= 0:
-            message.edns = self.edns
-        if self.ednsflags:
-            message.ednsflags = self.ednsflags
-        if self.payload:
-            message.payload = self.payload
+            message.use_edns(self.edns, self.ednsflags, self.payload)
         if self.rcode:
             message.set_rcode(self.rcode)
         if self.origin:
diff --git a/dns/rdtypes/ANY/OPT.py b/dns/rdtypes/ANY/OPT.py
new file mode 100644 (file)
index 0000000..0a0e7af
--- /dev/null
@@ -0,0 +1,74 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import struct
+
+import dns.edns
+import dns.exception
+import dns.rdata
+
+
+class OPT(dns.rdata.Rdata):
+
+    """OPT record"""
+
+    __slots__ = ['options']
+
+    def __init__(self, rdclass, rdtype, options):
+        """Initialize an OPT rdata.
+
+        *rdclass*, an ``int`` is the rdataclass of the Rdata,
+        which is also the payload size.
+
+        *rdtype*, an ``int`` is the rdatatype of the Rdata.
+
+        *options*, a tuple of ``bytes``
+        """
+
+        super().__init__(rdclass, rdtype)
+        object.__setattr__(self, 'options', dns.rdata._constify(options))
+
+    def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+        for opt in self.options:
+            owire = opt.to_wire()
+            file.write(struct.pack("!HH", opt.otype, len(owire)))
+            file.write(owire)
+
+    def to_text(self, origin=None, relativize=True, **kw):
+        return ' '.join(opt.to_text() for opt in self.options)
+
+    @classmethod
+    def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
+        options = []
+        while rdlen > 0:
+            if rdlen < 4:
+                raise dns.exception.FormError
+            (otype, olen) = struct.unpack('!HH', wire[current:current + 4])
+            current += 4
+            rdlen -= 4
+            if olen > rdlen:
+                raise dns.exception.FormError
+            opt = dns.edns.option_from_wire(otype, wire, current, olen)
+            current += olen
+            rdlen -= olen
+            options.append(opt)
+        return cls(rdclass, rdtype, options)
+
+    @property
+    def payload(self):
+        "payload size"
+        return self.rdclass
index 959b8bf28a0b22a572ed7db0c05afde527f73eb7..34ca12ab885b79988b765f6b8548d0707b1c040c 100644 (file)
@@ -170,21 +170,8 @@ class Renderer:
         # make sure the EDNS version in ednsflags agrees with edns
         ednsflags &= 0xFF00FFFF
         ednsflags |= (edns << 16)
-        self._set_section(ADDITIONAL)
-        with self._track_size():
-            self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT,
-                                          payload, ednsflags, 0))
-            if options is not None:
-                lstart = self.output.tell()
-                for opt in options:
-                    owire = opt.to_wire()
-                    self.output.write(struct.pack("!HH", opt.otype, len(owire)))
-                    self.output.write(owire)
-                lend = self.output.tell()
-                self.output.seek(lstart - 2)
-                self.output.write(struct.pack("!H", lend - lstart))
-                self.output.seek(0, io.SEEK_END)
-        self.counts[ADDITIONAL] += 1
+        opt = dns.message.Message._make_opt(ednsflags, payload, options)
+        self.add_rdataset(ADDITIONAL, dns.name.root, opt)
 
     def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
                  request_mac, algorithm=dns.tsig.default_algorithm):
index 9166ee5bd907c02e60e1b7b04fa144f4960a2543..2d4965d5621af43be04e17e2b22eeb3dfd3f5976 100644 (file)
@@ -300,7 +300,7 @@ class UpdateMessage(dns.message.Message):
         # Updates are always one_rr_per_rrset
         return True
 
-    def _parse_rr_header(self, section, rdclass, rdtype):
+    def _parse_rr_header(self, section, name, rdclass, rdtype):
         deleting = None
         empty = False
         if section == UpdateSection.ZONE: