From c6ad78bde1ffa7f27f0d24d475812cf8f64123f9 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Thu, 10 Aug 2023 17:32:36 -0700 Subject: [PATCH] Add helpers to reduce seeking boilerplate when rendering. (#980) --- dns/rdataset.py | 18 ++++-------- dns/rdtypes/svcbbase.py | 18 ++++-------- dns/rdtypes/txtbase.py | 7 ++--- dns/renderer.py | 64 +++++++++++++++++++++++++++-------------- 4 files changed, 56 insertions(+), 51 deletions(-) diff --git a/dns/rdataset.py b/dns/rdataset.py index 31124afc..5921a48c 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -28,6 +28,7 @@ import dns.name import dns.rdata import dns.rdataclass import dns.rdatatype +import dns.renderer import dns.set import dns.ttl @@ -316,11 +317,9 @@ class Rdataset(dns.set.Set): want_shuffle = False else: rdclass = self.rdclass - file.seek(0, io.SEEK_END) if len(self) == 0: name.to_wire(file, compress, origin) - stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0) - file.write(stuff) + file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)) return 1 else: l: Union[Rdataset, List[dns.rdata.Rdata]] @@ -331,16 +330,9 @@ class Rdataset(dns.set.Set): l = self for rd in l: name.to_wire(file, compress, origin) - stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0) - file.write(stuff) - start = file.tell() - rd.to_wire(file, compress, origin) - end = file.tell() - assert end - start < 65536 - file.seek(start - 2) - stuff = struct.pack("!H", end - start) - file.write(stuff) - file.seek(0, io.SEEK_END) + file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl)) + with dns.renderer.prefixed_length(file, 2): + rd.to_wire(file, compress, origin) return len(self) def match( diff --git a/dns/rdtypes/svcbbase.py b/dns/rdtypes/svcbbase.py index ba5b53d2..7544d58f 100644 --- a/dns/rdtypes/svcbbase.py +++ b/dns/rdtypes/svcbbase.py @@ -13,6 +13,7 @@ import dns.ipv6 import dns.name import dns.rdata import dns.rdtypes.util +import dns.renderer import dns.tokenizer import dns.wire @@ -521,19 +522,10 @@ class SVCBBase(dns.rdata.Rdata): for key in sorted(self.params): file.write(struct.pack("!H", key)) value = self.params[key] - # placeholder for length (or actual length of empty values) - file.write(struct.pack("!H", 0)) - if value is None: - continue - else: - start = file.tell() - value.to_wire(file, origin) - end = file.tell() - assert end - start < 65536 - file.seek(start - 2) - stuff = struct.pack("!H", end - start) - file.write(stuff) - file.seek(0, io.SEEK_END) + with dns.renderer.prefixed_length(file, 2): + # Note that we're still writing a length of zero if the value is None + if value is not None: + value.to_wire(file, origin) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py index fdbfb646..690a1997 100644 --- a/dns/rdtypes/txtbase.py +++ b/dns/rdtypes/txtbase.py @@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union import dns.exception import dns.immutable import dns.rdata +import dns.renderer import dns.tokenizer @@ -93,10 +94,8 @@ class TXTBase(dns.rdata.Rdata): def _to_wire(self, file, compress=None, origin=None, canonicalize=False): for s in self.strings: - l = len(s) - assert l < 256 - file.write(struct.pack("!B", l)) - file.write(s) + with dns.renderer.prefixed_length(file, 1): + file.write(s) @classmethod def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): diff --git a/dns/renderer.py b/dns/renderer.py index 53e7c0f6..a77481f6 100644 --- a/dns/renderer.py +++ b/dns/renderer.py @@ -32,6 +32,24 @@ AUTHORITY = 2 ADDITIONAL = 3 +@contextlib.contextmanager +def prefixed_length(output, length_length): + output.write(b"\00" * length_length) + start = output.tell() + yield + end = output.tell() + length = end - start + if length > 0: + try: + output.seek(start - length_length) + try: + output.write(length.to_bytes(length_length, "big")) + except OverflowError: + raise dns.exception.FormError + finally: + output.seek(end) + + class Renderer: """Helper class for building DNS wire-format messages. @@ -134,6 +152,15 @@ class Renderer: self._rollback(start) raise dns.exception.TooBig + @contextlib.contextmanager + def _temporarily_seek_to(self, where): + current = self.output.tell() + try: + self.output.seek(where) + yield + finally: + self.output.seek(current) + def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN): """Add a question to the message.""" @@ -269,18 +296,14 @@ class Renderer: with self._track_size(): keyname.to_wire(self.output, compress, self.origin) self.output.write( - struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0) + struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0) ) - rdata_start = self.output.tell() - tsig.to_wire(self.output) + with prefixed_length(self.output, 2): + tsig.to_wire(self.output) - after = self.output.tell() - self.output.seek(rdata_start - 2) - self.output.write(struct.pack("!H", after - rdata_start)) self.counts[ADDITIONAL] += 1 - self.output.seek(10) - self.output.write(struct.pack("!H", self.counts[ADDITIONAL])) - self.output.seek(0, io.SEEK_END) + with self._temporarily_seek_to(10): + self.output.write(struct.pack("!H", self.counts[ADDITIONAL])) def write_header(self): """Write the DNS message header. @@ -290,19 +313,18 @@ class Renderer: is added. """ - self.output.seek(0) - self.output.write( - struct.pack( - "!HHHHHH", - self.id, - self.flags, - self.counts[0], - self.counts[1], - self.counts[2], - self.counts[3], + with self._temporarily_seek_to(0): + self.output.write( + struct.pack( + "!HHHHHH", + self.id, + self.flags, + self.counts[0], + self.counts[1], + self.counts[2], + self.counts[3], + ) ) - ) - self.output.seek(0, io.SEEK_END) def get_wire(self): """Return the wire format message.""" -- 2.47.3