From 4b15361cb52762e8220feb91ad9bd957795bf164 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Wed, 17 Jun 2020 19:15:07 -0700 Subject: [PATCH] Simplify renderer code. Use context manager to avoid duplicating length checking in many places. Change the code dealing with EDNS options to avoid lots of seeking by not rendering directly into the file. --- dns/renderer.py | 98 +++++++++++++++++++------------------------------ 1 file changed, 38 insertions(+), 60 deletions(-) diff --git a/dns/renderer.py b/dns/renderer.py index 27d96a62..959b8bf2 100644 --- a/dns/renderer.py +++ b/dns/renderer.py @@ -17,6 +17,7 @@ """Help for building DNS wire format messages""" +import contextlib import io import struct import random @@ -120,17 +121,21 @@ class Renderer: raise dns.exception.FormError self.section = section + @contextlib.contextmanager + def _track_size(self): + start = self.output.tell() + yield start + if self.output.tell() > self.max_size: + self._rollback(start) + raise dns.exception.TooBig + def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN): """Add a question to the message.""" self._set_section(QUESTION) - before = self.output.tell() - qname.to_wire(self.output, self.compress, self.origin) - self.output.write(struct.pack("!HH", rdtype, rdclass)) - after = self.output.tell() - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig + with self._track_size(): + qname.to_wire(self.output, self.compress, self.origin) + self.output.write(struct.pack("!HH", rdtype, rdclass)) self.counts[QUESTION] += 1 def add_rrset(self, section, rrset, **kw): @@ -141,12 +146,8 @@ class Renderer: """ self._set_section(section) - before = self.output.tell() - n = rrset.to_wire(self.output, self.compress, self.origin, **kw) - after = self.output.tell() - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig + with self._track_size(): + n = rrset.to_wire(self.output, self.compress, self.origin, **kw) self.counts[section] += n def add_rdataset(self, section, name, rdataset, **kw): @@ -158,13 +159,9 @@ class Renderer: """ self._set_section(section) - before = self.output.tell() - n = rdataset.to_wire(name, self.output, self.compress, self.origin, - **kw) - after = self.output.tell() - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig + with self._track_size(): + n = rdataset.to_wire(name, self.output, self.compress, self.origin, + **kw) self.counts[section] += n def add_edns(self, edns, ednsflags, payload, options=None): @@ -174,32 +171,19 @@ class Renderer: ednsflags &= 0xFF00FFFF ednsflags |= (edns << 16) self._set_section(ADDITIONAL) - before = self.output.tell() - self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload, - ednsflags, 0)) - if options is not None: - lstart = self.output.tell() - for opt in options: - stuff = struct.pack("!HH", opt.otype, 0) - self.output.write(stuff) - start = self.output.tell() - opt.to_wire(self.output) - end = self.output.tell() - assert end - start < 65536 - self.output.seek(start - 2) - stuff = struct.pack("!H", end - start) - self.output.write(stuff) - self.output.seek(0, 2) - lend = self.output.tell() - assert lend - lstart < 65536 - self.output.seek(lstart - 2) - stuff = struct.pack("!H", lend - lstart) - self.output.write(stuff) - self.output.seek(0, 2) - after = self.output.tell() - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig + with self._track_size(): + self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, + payload, ednsflags, 0)) + if options is not None: + lstart = self.output.tell() + for opt in options: + owire = opt.to_wire() + self.output.write(struct.pack("!HH", opt.otype, len(owire))) + self.output.write(owire) + lend = self.output.tell() + self.output.seek(lstart - 2) + self.output.write(struct.pack("!H", lend - lstart)) + self.output.seek(0, io.SEEK_END) self.counts[ADDITIONAL] += 1 def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data, @@ -249,26 +233,20 @@ class Renderer: def _write_tsig(self, tsig_rdata, keyname): self._set_section(ADDITIONAL) - before = self.output.tell() - - keyname.to_wire(self.output, self.compress, self.origin) - self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG, - dns.rdataclass.ANY, 0, 0)) - rdata_start = self.output.tell() - self.output.write(tsig_rdata) + with self._track_size(): + keyname.to_wire(self.output, self.compress, self.origin) + self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG, + dns.rdataclass.ANY, 0, 0)) + rdata_start = self.output.tell() + self.output.write(tsig_rdata) after = self.output.tell() - assert after - rdata_start < 65536 - if after >= self.max_size: - self._rollback(before) - raise dns.exception.TooBig - self.output.seek(rdata_start - 2) self.output.write(struct.pack('!H', after - rdata_start)) self.counts[ADDITIONAL] += 1 self.output.seek(10) self.output.write(struct.pack('!H', self.counts[ADDITIONAL])) - self.output.seek(0, 2) + self.output.seek(0, io.SEEK_END) def write_header(self): """Write the DNS message header. @@ -282,7 +260,7 @@ class Renderer: self.output.write(struct.pack('!HHHHHH', self.id, self.flags, self.counts[0], self.counts[1], self.counts[2], self.counts[3])) - self.output.seek(0, 2) + self.output.seek(0, io.SEEK_END) def get_wire(self): """Return the wire format message.""" -- 2.47.3