]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Simplify renderer code. 510/head
authorBrian Wellington <bwelling@xbill.org>
Thu, 18 Jun 2020 02:15:07 +0000 (19:15 -0700)
committerBrian Wellington <bwelling@xbill.org>
Thu, 18 Jun 2020 02:15:07 +0000 (19:15 -0700)
Use context manager to avoid duplicating length checking in many places.
Change the code dealing with EDNS options to avoid lots of seeking by
not rendering directly into the file.

dns/renderer.py

index 27d96a6266378ee91aa5aa1ce658dc99fcc7af29..959b8bf28a0b22a572ed7db0c05afde527f73eb7 100644 (file)
@@ -17,6 +17,7 @@
 
 """Help for building DNS wire format messages"""
 
+import contextlib
 import io
 import struct
 import random
@@ -120,17 +121,21 @@ class Renderer:
                 raise dns.exception.FormError
             self.section = section
 
+    @contextlib.contextmanager
+    def _track_size(self):
+        start = self.output.tell()
+        yield start
+        if self.output.tell() > self.max_size:
+            self._rollback(start)
+            raise dns.exception.TooBig
+
     def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
         """Add a question to the message."""
 
         self._set_section(QUESTION)
-        before = self.output.tell()
-        qname.to_wire(self.output, self.compress, self.origin)
-        self.output.write(struct.pack("!HH", rdtype, rdclass))
-        after = self.output.tell()
-        if after >= self.max_size:
-            self._rollback(before)
-            raise dns.exception.TooBig
+        with self._track_size():
+            qname.to_wire(self.output, self.compress, self.origin)
+            self.output.write(struct.pack("!HH", rdtype, rdclass))
         self.counts[QUESTION] += 1
 
     def add_rrset(self, section, rrset, **kw):
@@ -141,12 +146,8 @@ class Renderer:
         """
 
         self._set_section(section)
-        before = self.output.tell()
-        n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
-        after = self.output.tell()
-        if after >= self.max_size:
-            self._rollback(before)
-            raise dns.exception.TooBig
+        with self._track_size():
+            n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
         self.counts[section] += n
 
     def add_rdataset(self, section, name, rdataset, **kw):
@@ -158,13 +159,9 @@ class Renderer:
         """
 
         self._set_section(section)
-        before = self.output.tell()
-        n = rdataset.to_wire(name, self.output, self.compress, self.origin,
-                             **kw)
-        after = self.output.tell()
-        if after >= self.max_size:
-            self._rollback(before)
-            raise dns.exception.TooBig
+        with self._track_size():
+            n = rdataset.to_wire(name, self.output, self.compress, self.origin,
+                                 **kw)
         self.counts[section] += n
 
     def add_edns(self, edns, ednsflags, payload, options=None):
@@ -174,32 +171,19 @@ class Renderer:
         ednsflags &= 0xFF00FFFF
         ednsflags |= (edns << 16)
         self._set_section(ADDITIONAL)
-        before = self.output.tell()
-        self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload,
-                                      ednsflags, 0))
-        if options is not None:
-            lstart = self.output.tell()
-            for opt in options:
-                stuff = struct.pack("!HH", opt.otype, 0)
-                self.output.write(stuff)
-                start = self.output.tell()
-                opt.to_wire(self.output)
-                end = self.output.tell()
-                assert end - start < 65536
-                self.output.seek(start - 2)
-                stuff = struct.pack("!H", end - start)
-                self.output.write(stuff)
-                self.output.seek(0, 2)
-            lend = self.output.tell()
-            assert lend - lstart < 65536
-            self.output.seek(lstart - 2)
-            stuff = struct.pack("!H", lend - lstart)
-            self.output.write(stuff)
-            self.output.seek(0, 2)
-        after = self.output.tell()
-        if after >= self.max_size:
-            self._rollback(before)
-            raise dns.exception.TooBig
+        with self._track_size():
+            self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT,
+                                          payload, ednsflags, 0))
+            if options is not None:
+                lstart = self.output.tell()
+                for opt in options:
+                    owire = opt.to_wire()
+                    self.output.write(struct.pack("!HH", opt.otype, len(owire)))
+                    self.output.write(owire)
+                lend = self.output.tell()
+                self.output.seek(lstart - 2)
+                self.output.write(struct.pack("!H", lend - lstart))
+                self.output.seek(0, io.SEEK_END)
         self.counts[ADDITIONAL] += 1
 
     def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
@@ -249,26 +233,20 @@ class Renderer:
 
     def _write_tsig(self, tsig_rdata, keyname):
         self._set_section(ADDITIONAL)
-        before = self.output.tell()
-
-        keyname.to_wire(self.output, self.compress, self.origin)
-        self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
-                                      dns.rdataclass.ANY, 0, 0))
-        rdata_start = self.output.tell()
-        self.output.write(tsig_rdata)
+        with self._track_size():
+            keyname.to_wire(self.output, self.compress, self.origin)
+            self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
+                                          dns.rdataclass.ANY, 0, 0))
+            rdata_start = self.output.tell()
+            self.output.write(tsig_rdata)
 
         after = self.output.tell()
-        assert after - rdata_start < 65536
-        if after >= self.max_size:
-            self._rollback(before)
-            raise dns.exception.TooBig
-
         self.output.seek(rdata_start - 2)
         self.output.write(struct.pack('!H', after - rdata_start))
         self.counts[ADDITIONAL] += 1
         self.output.seek(10)
         self.output.write(struct.pack('!H', self.counts[ADDITIONAL]))
-        self.output.seek(0, 2)
+        self.output.seek(0, io.SEEK_END)
 
     def write_header(self):
         """Write the DNS message header.
@@ -282,7 +260,7 @@ class Renderer:
         self.output.write(struct.pack('!HHHHHH', self.id, self.flags,
                                       self.counts[0], self.counts[1],
                                       self.counts[2], self.counts[3]))
-        self.output.seek(0, 2)
+        self.output.seek(0, io.SEEK_END)
 
     def get_wire(self):
         """Return the wire format message."""