From fe4b1c5ab511386b7fc7df7bdcbdf94fd7da5c9a Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 26 Jun 2020 11:30:29 -0700 Subject: [PATCH] Minor _WireReader refactoring. Instead of parsing the header to determine which Message subclass to create and passing that message to _WireReader, make _WireReader create the Message subclass itself. --- dns/message.py | 96 ++++++++++++++++++++++++++------------------------ 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/dns/message.py b/dns/message.py index a08b0517..00166caa 100644 --- a/dns/message.py +++ b/dns/message.py @@ -623,6 +623,23 @@ class QueryMessage(Message): pass +def _maybe_import_update(): + # We avoid circular imports by doing this here. We do it in another + # function as doing it in _message_factory_from_opcode() makes "dns" + # a local symbol, and the first line fails :) + import dns.update # noqa: F401 + + +def _message_factory_from_opcode(opcode): + if opcode == dns.opcode.QUERY: + return QueryMessage + elif opcode == dns.opcode.UPDATE: + _maybe_import_update() + return dns.update.UpdateMessage + else: + return Message + + class _WireReader: """Wire format reader. @@ -632,21 +649,24 @@ class _WireReader: current: When building a message object from wire format, this variable contains the offset from the beginning of wire of the next octet to be read. + initialize_message: Callback to set message parsing options + question_only: Are we only reading the question? one_rr_per_rrset: Put each RR into its own RRset? ignore_trailing: Ignore trailing junk at end of request? zone_rdclass: The class of the zone in messages which are DNS dynamic updates. """ - def __init__(self, wire, message, question_only, one_rr_per_rrset, - ignore_trailing): + def __init__(self, wire, initialize_message, question_only=False, + one_rr_per_rrset=False, ignore_trailing=False): self.wire = dns.wiredata.maybe_wrap(wire) - self.message = message + self.message = None self.current = 0 - self.zone_rdclass = dns.rdataclass.IN + self.initialize_message = initialize_message self.question_only = question_only - self.one_rr_per_rrset = message._get_one_rr_per_rrset(one_rr_per_rrset) + self.one_rr_per_rrset = one_rr_per_rrset self.ignore_trailing = ignore_trailing + self.zone_rdclass = dns.rdataclass.IN def _get_question(self, qcount): """Read the next *qcount* records from the wire data and add them to @@ -774,10 +794,15 @@ class _WireReader: l = len(self.wire) if l < 12: raise ShortHeader - (self.message.id, self.message.flags, qcount, ancount, - aucount, adcount) = struct.unpack('!HHHHHH', self.wire[:12]) + (id, flags, qcount, ancount, aucount, adcount) = \ + struct.unpack('!HHHHHH', self.wire[:12]) self.current = 12 - self.message.original_id = self.message.id + factory = _message_factory_from_opcode(dns.opcode.from_flags(flags)) + self.message = factory(id=id) + self.message.flags = flags + self.initialize_message(self.message) + self.one_rr_per_rrset = \ + self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) self._get_question(qcount) if self.question_only: return @@ -789,31 +814,7 @@ class _WireReader: if self.message.multi and self.message.tsig_ctx and \ not self.message.had_tsig: self.message.tsig_ctx.update(self.wire) - - -def _maybe_import_update(): - # We avoid circular imports by doing this here. We do it in another - # function as doing it in _message_factory_from_opcode() makes "dns" - # a local symbol, and the first line fails :) - import dns.update # noqa: F401 - - -def _message_factory_from_opcode(opcode): - if opcode == dns.opcode.QUERY: - return QueryMessage - elif opcode == dns.opcode.UPDATE: - _maybe_import_update() - return dns.update.UpdateMessage - else: - return Message - - -def _message_factory_from_wire(wire): - if len(wire) < 12: - raise ShortHeader - (flags,) = struct.unpack('!H', wire[2:4]) - opcode = dns.opcode.from_flags(flags) - return _message_factory_from_opcode(opcode) + return self.message def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, @@ -875,22 +876,23 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, Returns a ``dns.message.Message``. """ - m = _message_factory_from_wire(wire)(id=0) - m.keyring = keyring - m.request_mac = request_mac - m.xfr = xfr - m.origin = origin - m.tsig_ctx = tsig_ctx - m.multi = multi - m.first = first - - reader = _WireReader(wire, m, question_only, one_rr_per_rrset, - ignore_trailing) + def initialize_message(message): + message.keyring = keyring + message.request_mac = request_mac + message.xfr = xfr + message.origin = origin + message.tsig_ctx = tsig_ctx + message.multi = multi + message.first = first + + reader = _WireReader(wire, initialize_message, question_only, + one_rr_per_rrset, ignore_trailing) try: - reader.read() + m = reader.read() except dns.exception.FormError: - if m.flags & dns.flags.TC and raise_on_truncation: - raise Truncated(message=m) + if reader.message and (reader.message.flags & dns.flags.TC) and \ + raise_on_truncation: + raise Truncated(message=reader.message) else: raise # Reading a truncated message might not have any errors, so we -- 2.47.3