import dns.rdata
import dns.rdataclass
import dns.rdatatype
+import dns.renderer
import dns.set
import dns.ttl
want_shuffle = False
else:
rdclass = self.rdclass
- file.seek(0, io.SEEK_END)
if len(self) == 0:
name.to_wire(file, compress, origin)
- stuff = struct.pack("!HHIH", self.rdtype, rdclass, 0, 0)
- file.write(stuff)
+ file.write(struct.pack("!HHIH", self.rdtype, rdclass, 0, 0))
return 1
else:
l: Union[Rdataset, List[dns.rdata.Rdata]]
l = self
for rd in l:
name.to_wire(file, compress, origin)
- stuff = struct.pack("!HHIH", self.rdtype, rdclass, self.ttl, 0)
- file.write(stuff)
- start = file.tell()
- rd.to_wire(file, compress, origin)
- end = file.tell()
- assert end - start < 65536
- file.seek(start - 2)
- stuff = struct.pack("!H", end - start)
- file.write(stuff)
- file.seek(0, io.SEEK_END)
+ file.write(struct.pack("!HHI", self.rdtype, rdclass, self.ttl))
+ with dns.renderer.prefixed_length(file, 2):
+ rd.to_wire(file, compress, origin)
return len(self)
def match(
import dns.name
import dns.rdata
import dns.rdtypes.util
+import dns.renderer
import dns.tokenizer
import dns.wire
for key in sorted(self.params):
file.write(struct.pack("!H", key))
value = self.params[key]
- # placeholder for length (or actual length of empty values)
- file.write(struct.pack("!H", 0))
- if value is None:
- continue
- else:
- start = file.tell()
- value.to_wire(file, origin)
- end = file.tell()
- assert end - start < 65536
- file.seek(start - 2)
- stuff = struct.pack("!H", end - start)
- file.write(stuff)
- file.seek(0, io.SEEK_END)
+ with dns.renderer.prefixed_length(file, 2):
+ # Note that we're still writing a length of zero if the value is None
+ if value is not None:
+ value.to_wire(file, origin)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
import dns.exception
import dns.immutable
import dns.rdata
+import dns.renderer
import dns.tokenizer
def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
for s in self.strings:
- l = len(s)
- assert l < 256
- file.write(struct.pack("!B", l))
- file.write(s)
+ with dns.renderer.prefixed_length(file, 1):
+ file.write(s)
@classmethod
def from_wire_parser(cls, rdclass, rdtype, parser, origin=None):
ADDITIONAL = 3
+@contextlib.contextmanager
+def prefixed_length(output, length_length):
+ output.write(b"\00" * length_length)
+ start = output.tell()
+ yield
+ end = output.tell()
+ length = end - start
+ if length > 0:
+ try:
+ output.seek(start - length_length)
+ try:
+ output.write(length.to_bytes(length_length, "big"))
+ except OverflowError:
+ raise dns.exception.FormError
+ finally:
+ output.seek(end)
+
+
class Renderer:
"""Helper class for building DNS wire-format messages.
self._rollback(start)
raise dns.exception.TooBig
+ @contextlib.contextmanager
+ def _temporarily_seek_to(self, where):
+ current = self.output.tell()
+ try:
+ self.output.seek(where)
+ yield
+ finally:
+ self.output.seek(current)
+
def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
"""Add a question to the message."""
with self._track_size():
keyname.to_wire(self.output, compress, self.origin)
self.output.write(
- struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0)
+ struct.pack("!HHI", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0)
)
- rdata_start = self.output.tell()
- tsig.to_wire(self.output)
+ with prefixed_length(self.output, 2):
+ tsig.to_wire(self.output)
- after = self.output.tell()
- self.output.seek(rdata_start - 2)
- self.output.write(struct.pack("!H", after - rdata_start))
self.counts[ADDITIONAL] += 1
- self.output.seek(10)
- self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
- self.output.seek(0, io.SEEK_END)
+ with self._temporarily_seek_to(10):
+ self.output.write(struct.pack("!H", self.counts[ADDITIONAL]))
def write_header(self):
"""Write the DNS message header.
is added.
"""
- self.output.seek(0)
- self.output.write(
- struct.pack(
- "!HHHHHH",
- self.id,
- self.flags,
- self.counts[0],
- self.counts[1],
- self.counts[2],
- self.counts[3],
+ with self._temporarily_seek_to(0):
+ self.output.write(
+ struct.pack(
+ "!HHHHHH",
+ self.id,
+ self.flags,
+ self.counts[0],
+ self.counts[1],
+ self.counts[2],
+ self.counts[3],
+ )
)
- )
- self.output.seek(0, io.SEEK_END)
def get_wire(self):
"""Return the wire format message."""