]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add helpers to reduce seeking boilerplate when rendering. (#980)
authorBob Halley <halley@dnspython.org>
Fri, 11 Aug 2023 00:32:36 +0000 (17:32 -0700)
committerGitHub <noreply@github.com>
Fri, 11 Aug 2023 00:32:36 +0000 (17:32 -0700)
dns/rdataset.py
dns/rdtypes/svcbbase.py
dns/rdtypes/txtbase.py
dns/renderer.py

index 31124afcc46f013cdaa8ac1bebda62813dab6b14..5921a48cd6b8f234abb33082318bea1fd27690c7 100644 (file)
@@ -28,6 +28,7 @@ import dns.name
 import dns.rdata
 import dns.rdataclass
 import dns.rdatatype
+import dns.renderer
 import dns.set
 import dns.ttl
 
@@ -316,11 +317,9 @@ class Rdataset(dns.set.Set):
             want_shuffle = False
         else:
             rdclass = self.rdclass
-        file.seek(0, io.SEEK_END)
         if len(self) == 0:
             name.to_wire(file, compress, origin)
-            stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)
-            file.write(stuff)
+            file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0))
             return 1
         else:
             l: Union[Rdataset, List[dns.rdata.Rdata]]
@@ -331,16 +330,9 @@ class Rdataset(dns.set.Set):
                 l = self
             for rd in l:
                 name.to_wire(file, compress, origin)
-                stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0)
-                file.write(stuff)
-                start = file.tell()
-                rd.to_wire(file, compress, origin)
-                end = file.tell()
-                assert end - start < 65536
-                file.seek(start - 2)
-                stuff = struct.pack("!H", end - start)
-                file.write(stuff)
-                file.seek(0, io.SEEK_END)
+                file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl))
+                with dns.renderer.prefixed_length(file, 2):
+                    rd.to_wire(file, compress, origin)
             return len(self)
 
     def match(
index ba5b53d2cb7d0e25d0437b2193f0aae4a3324f26..7544d58f91b319d47ea08761f6a7271b9af3a935 100644 (file)
@@ -13,6 +13,7 @@ import dns.ipv6
 import dns.name
 import dns.rdata
 import dns.rdtypes.util
+import dns.renderer
 import dns.tokenizer
 import dns.wire
 
@@ -521,19 +522,10 @@ class SVCBBase(dns.rdata.Rdata):
         for key in sorted(self.params):
             file.write(struct.pack("!H", key))
             value = self.params[key]
-            # placeholder for length (or actual length of empty values)
-            file.write(struct.pack("!H", 0))
-            if value is None:
-                continue
-            else:
-                start = file.tell()
-                value.to_wire(file, origin)
-                end = file.tell()
-                assert end - start < 65536
-                file.seek(start - 2)
-                stuff = struct.pack("!H", end - start)
-                file.write(stuff)
-                file.seek(0, io.SEEK_END)
+            with dns.renderer.prefixed_length(file, 2):
+                # Note that we're still writing a length of zero if the value is None
+                if value is not None:
+                    value.to_wire(file, origin)
 
     @classmethod
     def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
index fdbfb6465f81188512b47b7ed9a58a9ddca89e74..690a1997a870c68a8205c81581410bf508208d67 100644 (file)
@@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import dns.exception
 import dns.immutable
 import dns.rdata
+import dns.renderer
 import dns.tokenizer
 
 
@@ -93,10 +94,8 @@ class TXTBase(dns.rdata.Rdata):
 
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
         for s in self.strings:
-            l = len(s)
-            assert l < 256
-            file.write(struct.pack("!B", l))
-            file.write(s)
+            with dns.renderer.prefixed_length(file, 1):
+                file.write(s)
 
     @classmethod
     def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
index 53e7c0f6faf98477594b8ce65162160ce2ec0f98..a77481f67c8cdab5205ce8453550eb6f02ec566c 100644 (file)
@@ -32,6 +32,24 @@ AUTHORITY = 2
 ADDITIONAL = 3
 
 
+@contextlib.contextmanager
+def prefixed_length(output, length_length):
+    output.write(b"\00" * length_length)
+    start = output.tell()
+    yield
+    end = output.tell()
+    length = end - start
+    if length > 0:
+        try:
+            output.seek(start - length_length)
+            try:
+                output.write(length.to_bytes(length_length, "big"))
+            except OverflowError:
+                raise dns.exception.FormError
+        finally:
+            output.seek(end)
+
+
 class Renderer:
     """Helper class for building DNS wire-format messages.
 
@@ -134,6 +152,15 @@ class Renderer:
             self._rollback(start)
             raise dns.exception.TooBig
 
+    @contextlib.contextmanager
+    def _temporarily_seek_to(self, where):
+        current = self.output.tell()
+        try:
+            self.output.seek(where)
+            yield
+        finally:
+            self.output.seek(current)
+
     def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
         """Add a question to the message."""
 
@@ -269,18 +296,14 @@ class Renderer:
         with self._track_size():
             keyname.to_wire(self.output, compress, self.origin)
             self.output.write(
-                struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0)
+                struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0)
             )
-            rdata_start = self.output.tell()
-            tsig.to_wire(self.output)
+            with prefixed_length(self.output, 2):
+                tsig.to_wire(self.output)
 
-        after = self.output.tell()
-        self.output.seek(rdata_start - 2)
-        self.output.write(struct.pack("!H", after - rdata_start))
         self.counts[ADDITIONAL] += 1
-        self.output.seek(10)
-        self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
-        self.output.seek(0, io.SEEK_END)
+        with self._temporarily_seek_to(10):
+            self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
 
     def write_header(self):
         """Write the DNS message header.
@@ -290,19 +313,18 @@ class Renderer:
         is added.
         """
 
-        self.output.seek(0)
-        self.output.write(
-            struct.pack(
-                "!HHHHHH",
-                self.id,
-                self.flags,
-                self.counts[0],
-                self.counts[1],
-                self.counts[2],
-                self.counts[3],
+        with self._temporarily_seek_to(0):
+            self.output.write(
+                struct.pack(
+                    "!HHHHHH",
+                    self.id,
+                    self.flags,
+                    self.counts[0],
+                    self.counts[1],
+                    self.counts[2],
+                    self.counts[3],
+                )
             )
-        )
-        self.output.seek(0, io.SEEK_END)
 
     def get_wire(self):
         """Return the wire format message."""