"""
return self.kwargs['message']
+
class MessageSection(dns.enum.IntEnum):
"""Message sections"""
QUESTION = 0
method.
*origin*, a ``dns.name.Name`` or ``None``, the origin to be appended
- to any relative names.
+ to any relative names. If ``None``, and the message has an origin
+ attribute that is not ``None``, then it will be used.
*max_size*, an ``int``, the maximum size of the wire format
output; default is 0, which means "the message's request
Returns a ``bytes``.
"""
+ if origin is None and self.origin is not None:
+ origin = self.origin
if max_size == 0:
if self.request_payload != 0:
max_size = self.request_payload
self.flags &= 0x87FF
self.flags |= dns.opcode.to_flags(opcode)
+ def _get_one_rr_per_rrset(self, value):
+ # What the caller picked is fine.
+ return value
+
+ def _validate_rrset(self, section, rrset):
+ if rrset.rdclass == dns.rdataclass.ANY or \
+ rrset.rdclass == dns.rdataclass.NONE:
+ raise dns.exception.FormError
+
+ def _finish_section(self, section):
+ pass
+
+
+class QueryMessage(Message):
+ pass
+
class _WireReader:
current: When building a message object from wire format, this
variable contains the offset from the beginning of wire of the next octet
to be read.
- updating: Is the message a dynamic update?
one_rr_per_rrset: Put each RR into its own RRset?
ignore_trailing: Ignore trailing junk at end of request?
zone_rdclass: The class of the zone in messages which are
DNS dynamic updates.
"""
- def __init__(self, wire, message, question_only=False,
- one_rr_per_rrset=False, ignore_trailing=False):
+ def __init__(self, wire, message, question_only, one_rr_per_rrset,
+ ignore_trailing):
self.wire = dns.wiredata.maybe_wrap(wire)
self.message = message
self.current = 0
- self.updating = False
self.zone_rdclass = dns.rdataclass.IN
self.question_only = question_only
- self.one_rr_per_rrset = one_rr_per_rrset
+ self.one_rr_per_rrset = message._get_one_rr_per_rrset(one_rr_per_rrset)
self.ignore_trailing = ignore_trailing
def _get_question(self, qcount):
the question section.
"""
- if self.updating and qcount > 1:
- raise dns.exception.FormError
-
for i in range(qcount):
(qname, used) = dns.name.from_wire(self.wire, self.current)
if self.message.origin is not None:
struct.unpack('!HH',
self.wire[self.current:self.current + 4])
self.current += 4
- self.message.find_rrset(self.message.question, qname,
- rdclass, rdtype, create=True,
- force_unique=True)
- if self.updating:
- self.zone_rdclass = rdclass
+ rrset = self.message.find_rrset(self.message.question, qname,
+ rdclass, rdtype, create=True,
+ force_unique=True)
+ self.message._validate_rrset(MessageSection.QUESTION, rrset)
+ self.message._finish_section(MessageSection.QUESTION)
- def _get_section(self, section, count):
+ def _get_section(self, section_number, count):
"""Read the next I{count} records from the wire data and add them to
the specified section.
count: the number of records to read
"""
- if self.updating or self.one_rr_per_rrset:
- force_unique = True
- else:
- force_unique = False
+ section = self.message.sections[section_number]
+ force_unique = self.one_rr_per_rrset
seen_opt = False
for i in range(count):
rr_start = self.current
else:
if ttl > 0x7fffffff:
ttl = 0
- if self.updating and \
- rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE):
+ if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE):
deleting = rdclass
rdclass = self.zone_rdclass
else:
deleting, True, force_unique)
if rd is not None:
rrset.add(rd, ttl)
+ # Note this validates the rrset every RR, but this is
+ # simpler than adding logic to remember a "last rrset"
+ # and validate it when the current rrset changes (or when
+ # we get to the end of the section. If it becomes a
+ # performance problem, we can do that.
+ self.message._validate_rrset(section_number, rrset)
self.current += rdlen
+ self.message._finish_section(section_number)
def read(self):
"""Read a wire format DNS message and build a dns.message.Message
aucount, adcount) = struct.unpack('!HHHHHH', self.wire[:12])
self.current = 12
self.message.original_id = self.message.id
- if dns.opcode.is_update(self.message.flags):
- self.updating = True
self._get_question(qcount)
if self.question_only:
return
- self._get_section(self.message.answer, ancount)
- self._get_section(self.message.authority, aucount)
- self._get_section(self.message.additional, adcount)
+ self._get_section(MessageSection.ANSWER, ancount)
+ self._get_section(MessageSection.AUTHORITY, aucount)
+ self._get_section(MessageSection.ADDITIONAL, adcount)
if not self.ignore_trailing and self.current != l:
raise TrailingJunk
if self.message.multi and self.message.tsig_ctx and \
self.message.tsig_ctx.update(self.wire)
+def _maybe_import_update():
+ # We avoid circular imports by doing this here. We do it in another
+ # function as doing it in _message_factory_from_opcode() makes "dns"
+ # a local symbol, and the first line fails :)
+ import dns.update
+
+
+def _message_factory_from_opcode(opcode):
+ if opcode == dns.opcode.QUERY:
+ return QueryMessage
+ elif opcode == dns.opcode.UPDATE:
+ _maybe_import_update()
+ return dns.update.UpdateMessage
+ else:
+ return Message
+
+
+def _message_factory_from_wire(wire):
+ if len(wire) < 12:
+ raise ShortHeader
+ (flags,) = struct.unpack('!H', wire[2:4])
+ opcode = dns.opcode.from_flags(flags)
+ return _message_factory_from_opcode(opcode)
+
+
def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
tsig_ctx=None, multi=False, first=True,
question_only=False, one_rr_per_rrset=False,
Returns a ``dns.message.Message``.
"""
- m = Message(id=0)
+ m = _message_factory_from_wire(wire)(id=0)
m.keyring = keyring
m.request_mac = request_mac
m.xfr = xfr
tok: the tokenizer.
message: The message object being built.
- updating: Is the message a dynamic update?
zone_rdclass: The class of the zone in messages which are
DNS dynamic updates.
last_name: The most recently read name when building a message object.
one_rr_per_rrset: Put each RR into its own RRset?
"""
- def __init__(self, text, message, idna_codec, one_rr_per_rrset=False):
- self.message = message
+ def __init__(self, text, idna_codec, one_rr_per_rrset=False):
+ self.message = None
self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec)
self.last_name = None
self.zone_rdclass = dns.rdataclass.IN
- self.updating = False
self.one_rr_per_rrset = one_rr_per_rrset
+ self.id = None
+ self.edns = -1
+ self.ednsflags = 0
+ self.payload = None
+ self.rcode = None
+ self.opcode = dns.opcode.QUERY
+ self.flags = 0
def _header_line(self, section):
"""Process one line from the text format header section."""
token = self.tok.get()
what = token.value
if what == 'id':
- self.message.id = self.tok.get_int()
+ self.id = self.tok.get_int()
elif what == 'flags':
while True:
token = self.tok.get()
if not token.is_identifier():
self.tok.unget(token)
break
- self.message.flags = self.message.flags | \
- dns.flags.from_text(token.value)
- if dns.opcode.is_update(self.message.flags):
- self.updating = True
+ self.flags = self.flags | dns.flags.from_text(token.value)
elif what == 'edns':
- self.message.edns = self.tok.get_int()
- self.message.ednsflags = self.message.ednsflags | \
- (self.message.edns << 16)
+ self.edns = self.tok.get_int()
+ self.ednsflags = self.ednsflags | (self.edns << 16)
elif what == 'eflags':
- if self.message.edns < 0:
- self.message.edns = 0
+ if self.edns < 0:
+ self.edns = 0
while True:
token = self.tok.get()
if not token.is_identifier():
self.tok.unget(token)
break
- self.message.ednsflags = self.message.ednsflags | \
+ self.ednsflags = self.ednsflags | \
dns.flags.edns_from_text(token.value)
elif what == 'payload':
- self.message.payload = self.tok.get_int()
- if self.message.edns < 0:
- self.message.edns = 0
+ self.payload = self.tok.get_int()
+ if self.edns < 0:
+ self.edns = 0
elif what == 'opcode':
text = self.tok.get_string()
- self.message.flags = self.message.flags | \
- dns.opcode.to_flags(dns.opcode.from_text(text))
+ self.opcode = dns.opcode.from_text(text)
+ self.flags = self.flags | dns.opcode.to_flags(self.opcode)
elif what == 'rcode':
text = self.tok.get_string()
- self.message.set_rcode(dns.rcode.from_text(text))
+ self.rcode = dns.rcode.from_text(text)
else:
raise UnknownHeaderField
self.tok.get_eol()
- def _question_line(self, section):
+ def _question_line(self, section_number):
"""Process one line from the text format question section."""
+ section = self.message.sections[section_number]
token = self.tok.get(want_leading=True)
if not token.is_whitespace():
self.last_name = self.tok.as_name(token, None)
rdclass = dns.rdataclass.IN
# Type
rdtype = dns.rdatatype.from_text(token.value)
- self.message.find_rrset(self.message.question, name,
- rdclass, rdtype, create=True,
- force_unique=True)
- if self.updating:
- self.zone_rdclass = rdclass
+ rrset = self.message.find_rrset(self.message.question, name,
+ rdclass, rdtype, create=True,
+ force_unique=True)
+ self.message._validate_rrset(section_number, rrset)
+ self.zone_rdclass = rdclass
self.tok.get_eol()
- def _rr_line(self, section):
+ def _rr_line(self, section_number):
"""Process one line from the text format answer, authority, or
additional data sections.
"""
+ section = self.message.sections[section_number]
deleting = None
# Name
token = self.tok.get(want_leading=True)
else:
rd = None
covers = dns.rdatatype.NONE
- force_unique = self.updating or self.one_rr_per_rrset
rrset = self.message.find_rrset(section, name,
rdclass, rdtype, covers,
- deleting, True, force_unique)
+ deleting, True, self.one_rr_per_rrset)
if rd is not None:
rrset.add(rd, ttl)
+ self.message._validate_rrset(section_number, rrset)
+
+ def _maybe_instantiate_message(self):
+ if self.message is None:
+ # Time to instantiate the message!
+ factory = _message_factory_from_opcode(self.opcode)
+ self.message = factory(id=self.id)
+ self.message.flags = self.flags
+ if self.edns >= 0:
+ self.message.edns = self.edns
+ if self.ednsflags:
+ self.message.ednsflags = self.ednsflags
+ if self.payload:
+ self.message.payload = self.payload
+ if self.rcode:
+ self.message.set_rcode(self.rcode)
+ self.one_rr_per_rrset = \
+ self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
def read(self):
"""Read a text format DNS message and build a dns.message.Message
object."""
line_method = self._header_line
- section = None
+ section_number = None
while 1:
token = self.tok.get(True, True)
if token.is_eol_or_eof():
u = token.value.upper()
if u == 'HEADER':
line_method = self._header_line
- elif u == 'QUESTION' or u == 'ZONE':
- line_method = self._question_line
- section = self.message.question
- elif u == 'ANSWER' or u == 'PREREQ':
- line_method = self._rr_line
- section = self.message.answer
- elif u == 'AUTHORITY' or u == 'UPDATE':
- line_method = self._rr_line
- section = self.message.authority
- elif u == 'ADDITIONAL':
- line_method = self._rr_line
- section = self.message.additional
+ elif u in {'QUESTION', 'ANSWER', 'AUTHORITY', 'ADDITIONAL',
+ 'ZONE', 'PREREQ', 'UPDATE'}:
+ # It's ugly, but we have to do the check above because
+ # if the token is JUST a comment, we want to ignore it,
+ # and not prehaps prematurely instantiate the message!
+ self._maybe_instantiate_message()
+ section_number = self.message._section_enum.make(u)
+ if section_number == 0:
+ line_method = self._question_line
+ else:
+ line_method = self._rr_line
self.tok.get_eol()
continue
self.tok.unget(token)
- line_method(section)
+ line_method(section_number)
+ for i in range(4):
+ self.message._finish_section(i)
+ return self.message
def from_text(text, idna_codec=None, one_rr_per_rrset=False):
# since it's an implementation detail. The official file
# interface is from_file().
- m = Message()
-
- reader = _TextReader(text, m, idna_codec, one_rr_per_rrset)
- reader.read()
-
- return m
+ reader = _TextReader(text, idna_codec, one_rr_per_rrset)
+ return reader.read()
def from_file(f, idna_codec=None, one_rr_per_rrset=False):
globals().update(UpdateSection.__members__)
-class Update(dns.message.Message):
+class UpdateMessage(dns.message.Message):
_section_enum = UpdateSection
- def __init__(self, zone, rdclass=dns.rdataclass.IN, keyring=None,
- keyname=None, keyalgorithm=dns.tsig.default_algorithm):
+ def __init__(self, zone=None, rdclass=dns.rdataclass.IN, keyring=None,
+ keyname=None, keyalgorithm=dns.tsig.default_algorithm,
+ id=None):
"""Initialize a new DNS Update object.
See the documentation of the Message class for a complete
description of the keyring dictionary.
- *zone*, a ``dns.name.Name`` or ``str``, the zone which is being
- updated.
+ *zone*, a ``dns.name.Name``, ``str``, or ``None``, the zone
+ which is being updated. ``None`` should only be used by dnspython's
+ message constructors, as a zone is required for the convenience
+ methods like ``add()``, ``replace()``, etc.
*rdclass*, an ``int`` or ``str``, the class of the zone.
to use; defaults to ``None``. The key must be defined in the keyring.
*keyalgorithm*, a ``dns.name.Name``, the TSIG algorithm to use.
+
"""
- super().__init__()
+ super().__init__(id=id)
self.flags |= dns.opcode.to_flags(dns.opcode.UPDATE)
if isinstance(zone, str):
zone = dns.name.from_text(zone)
self.origin = zone
rdclass = dns.rdataclass.RdataClass.make(rdclass)
self.zone_rdclass = rdclass
- self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA,
- create=True, force_unique=True)
+ if self.origin:
+ self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA,
+ create=True, force_unique=True)
if keyring is not None:
self.use_tsig(keyring, keyname, algorithm=keyalgorithm)
dns.rdatatype.NONE, None,
True, True)
- def to_wire(self, origin=None, max_size=65535):
- """Return a string containing the update in DNS compressed wire
- format.
-
- *origin*, a ``dns.name.Name`` or ``None``, the origin to be
- appended to any relative names. If *origin* is ``None``, then
- the origin of the ``dns.update.Update`` message object is used
- (i.e. the *zone* parameter passed when the Update object was
- created).
+ def _get_one_rr_per_rrset(self, value):
+ # Updates are always one_rr_per_rrset
+ return True
- *max_size*, an ``int``, the maximum size of the wire format
- output; default is 0, which means "the message's request
- payload, if nonzero, or 65535".
+ def _validate_rrset(self, section, rrset):
+ if section == UpdateSection.ZONE:
+ if rrset.rdtype != dns.rdatatype.SOA:
+ raise dns.exception.FormError
- Returns a ``bytes``.
- """
+ def _finish_section(self, section):
+ if section == UpdateSection.ZONE and len(self.zone) != 1:
+ raise dns.exception.FormError
+ self.zone_rdclass = self.zone[0].rdclass
+ # We do NOT want to set origin here, as that would cause
+ # from_wire() relativization.
- if origin is None:
- origin = self.origin
- return super().to_wire(origin, max_size)
+# backwards compatibility
+Update = UpdateMessage