"""Help for building DNS wire format messages"""
+import contextlib
import io
import struct
import random
raise dns.exception.FormError
self.section = section
+ @contextlib.contextmanager
+ def _track_size(self):
+ start = self.output.tell()
+ yield start
+ if self.output.tell() > self.max_size:
+ self._rollback(start)
+ raise dns.exception.TooBig
+
def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
"""Add a question to the message."""
self._set_section(QUESTION)
- before = self.output.tell()
- qname.to_wire(self.output, self.compress, self.origin)
- self.output.write(struct.pack("!HH", rdtype, rdclass))
- after = self.output.tell()
- if after >= self.max_size:
- self._rollback(before)
- raise dns.exception.TooBig
+ with self._track_size():
+ qname.to_wire(self.output, self.compress, self.origin)
+ self.output.write(struct.pack("!HH", rdtype, rdclass))
self.counts[QUESTION] += 1
def add_rrset(self, section, rrset, **kw):
"""
self._set_section(section)
- before = self.output.tell()
- n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
- after = self.output.tell()
- if after >= self.max_size:
- self._rollback(before)
- raise dns.exception.TooBig
+ with self._track_size():
+ n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
self.counts[section] += n
def add_rdataset(self, section, name, rdataset, **kw):
"""
self._set_section(section)
- before = self.output.tell()
- n = rdataset.to_wire(name, self.output, self.compress, self.origin,
- **kw)
- after = self.output.tell()
- if after >= self.max_size:
- self._rollback(before)
- raise dns.exception.TooBig
+ with self._track_size():
+ n = rdataset.to_wire(name, self.output, self.compress, self.origin,
+ **kw)
self.counts[section] += n
def add_edns(self, edns, ednsflags, payload, options=None):
ednsflags &= 0xFF00FFFF
ednsflags |= (edns << 16)
self._set_section(ADDITIONAL)
- before = self.output.tell()
- self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload,
- ednsflags, 0))
- if options is not None:
- lstart = self.output.tell()
- for opt in options:
- stuff = struct.pack("!HH", opt.otype, 0)
- self.output.write(stuff)
- start = self.output.tell()
- opt.to_wire(self.output)
- end = self.output.tell()
- assert end - start < 65536
- self.output.seek(start - 2)
- stuff = struct.pack("!H", end - start)
- self.output.write(stuff)
- self.output.seek(0, 2)
- lend = self.output.tell()
- assert lend - lstart < 65536
- self.output.seek(lstart - 2)
- stuff = struct.pack("!H", lend - lstart)
- self.output.write(stuff)
- self.output.seek(0, 2)
- after = self.output.tell()
- if after >= self.max_size:
- self._rollback(before)
- raise dns.exception.TooBig
+ with self._track_size():
+ self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT,
+ payload, ednsflags, 0))
+ if options is not None:
+ lstart = self.output.tell()
+ for opt in options:
+ owire = opt.to_wire()
+ self.output.write(struct.pack("!HH", opt.otype, len(owire)))
+ self.output.write(owire)
+ lend = self.output.tell()
+ self.output.seek(lstart - 2)
+ self.output.write(struct.pack("!H", lend - lstart))
+ self.output.seek(0, io.SEEK_END)
self.counts[ADDITIONAL] += 1
def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
def _write_tsig(self, tsig_rdata, keyname):
self._set_section(ADDITIONAL)
- before = self.output.tell()
-
- keyname.to_wire(self.output, self.compress, self.origin)
- self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
- dns.rdataclass.ANY, 0, 0))
- rdata_start = self.output.tell()
- self.output.write(tsig_rdata)
+ with self._track_size():
+ keyname.to_wire(self.output, self.compress, self.origin)
+ self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
+ dns.rdataclass.ANY, 0, 0))
+ rdata_start = self.output.tell()
+ self.output.write(tsig_rdata)
after = self.output.tell()
- assert after - rdata_start < 65536
- if after >= self.max_size:
- self._rollback(before)
- raise dns.exception.TooBig
-
self.output.seek(rdata_start - 2)
self.output.write(struct.pack('!H', after - rdata_start))
self.counts[ADDITIONAL] += 1
self.output.seek(10)
self.output.write(struct.pack('!H', self.counts[ADDITIONAL]))
- self.output.seek(0, 2)
+ self.output.seek(0, io.SEEK_END)
def write_header(self):
"""Write the DNS message header.
self.output.write(struct.pack('!HHHHHH', self.id, self.flags,
self.counts[0], self.counts[1],
self.counts[2], self.counts[3]))
- self.output.seek(0, 2)
+ self.output.seek(0, io.SEEK_END)
def get_wire(self):
"""Return the wire format message."""