pass
+def _maybe_import_update():
+ # We avoid circular imports by doing this here. We do it in another
+ # function as doing it in _message_factory_from_opcode() makes "dns"
+ # a local symbol, and the first line fails :)
+ import dns.update # noqa: F401
+
+
+def _message_factory_from_opcode(opcode):
+ if opcode == dns.opcode.QUERY:
+ return QueryMessage
+ elif opcode == dns.opcode.UPDATE:
+ _maybe_import_update()
+ return dns.update.UpdateMessage
+ else:
+ return Message
+
+
class _WireReader:
"""Wire format reader.
current: When building a message object from wire format, this
variable contains the offset from the beginning of wire of the next octet
to be read.
+ initialize_message: Callback to set message parsing options
+ question_only: Are we only reading the question?
one_rr_per_rrset: Put each RR into its own RRset?
ignore_trailing: Ignore trailing junk at end of request?
zone_rdclass: The class of the zone in messages which are
DNS dynamic updates.
"""
- def __init__(self, wire, message, question_only, one_rr_per_rrset,
- ignore_trailing):
+ def __init__(self, wire, initialize_message, question_only=False,
+ one_rr_per_rrset=False, ignore_trailing=False):
self.wire = dns.wiredata.maybe_wrap(wire)
- self.message = message
+ self.message = None
self.current = 0
- self.zone_rdclass = dns.rdataclass.IN
+ self.initialize_message = initialize_message
self.question_only = question_only
- self.one_rr_per_rrset = message._get_one_rr_per_rrset(one_rr_per_rrset)
+ self.one_rr_per_rrset = one_rr_per_rrset
self.ignore_trailing = ignore_trailing
+ self.zone_rdclass = dns.rdataclass.IN
def _get_question(self, qcount):
"""Read the next *qcount* records from the wire data and add them to
l = len(self.wire)
if l < 12:
raise ShortHeader
- (self.message.id, self.message.flags, qcount, ancount,
- aucount, adcount) = struct.unpack('!HHHHHH', self.wire[:12])
+ (id, flags, qcount, ancount, aucount, adcount) = \
+ struct.unpack('!HHHHHH', self.wire[:12])
self.current = 12
- self.message.original_id = self.message.id
+ factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
+ self.message = factory(id=id)
+ self.message.flags = flags
+ self.initialize_message(self.message)
+ self.one_rr_per_rrset = \
+ self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
self._get_question(qcount)
if self.question_only:
return
if self.message.multi and self.message.tsig_ctx and \
not self.message.had_tsig:
self.message.tsig_ctx.update(self.wire)
-
-
-def _maybe_import_update():
- # We avoid circular imports by doing this here. We do it in another
- # function as doing it in _message_factory_from_opcode() makes "dns"
- # a local symbol, and the first line fails :)
- import dns.update # noqa: F401
-
-
-def _message_factory_from_opcode(opcode):
- if opcode == dns.opcode.QUERY:
- return QueryMessage
- elif opcode == dns.opcode.UPDATE:
- _maybe_import_update()
- return dns.update.UpdateMessage
- else:
- return Message
-
-
-def _message_factory_from_wire(wire):
- if len(wire) < 12:
- raise ShortHeader
- (flags,) = struct.unpack('!H', wire[2:4])
- opcode = dns.opcode.from_flags(flags)
- return _message_factory_from_opcode(opcode)
+ return self.message
def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
Returns a ``dns.message.Message``.
"""
- m = _message_factory_from_wire(wire)(id=0)
- m.keyring = keyring
- m.request_mac = request_mac
- m.xfr = xfr
- m.origin = origin
- m.tsig_ctx = tsig_ctx
- m.multi = multi
- m.first = first
-
- reader = _WireReader(wire, m, question_only, one_rr_per_rrset,
- ignore_trailing)
+ def initialize_message(message):
+ message.keyring = keyring
+ message.request_mac = request_mac
+ message.xfr = xfr
+ message.origin = origin
+ message.tsig_ctx = tsig_ctx
+ message.multi = multi
+ message.first = first
+
+ reader = _WireReader(wire, initialize_message, question_only,
+ one_rr_per_rrset, ignore_trailing)
try:
- reader.read()
+ m = reader.read()
except dns.exception.FormError:
- if m.flags & dns.flags.TC and raise_on_truncation:
- raise Truncated(message=m)
+ if reader.message and (reader.message.flags & dns.flags.TC) and \
+ raise_on_truncation:
+ raise Truncated(message=reader.message)
else:
raise
# Reading a truncated message might not have any errors, so we