From: Bob Halley Date: Fri, 26 Jun 2020 03:13:18 +0000 (-0700) Subject: new message class hierarchy and conversion of wire and text readers X-Git-Tag: v2.0.0rc2~61^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fd055443b6a796c2932a6e302f4af0bf1c1954de;p=thirdparty%2Fdnspython.git new message class hierarchy and conversion of wire and text readers --- diff --git a/dns/message.py b/dns/message.py index 3d92c11f..6272bbd4 100644 --- a/dns/message.py +++ b/dns/message.py @@ -78,6 +78,7 @@ class Truncated(dns.exception.DNSException): """ return self.kwargs['message'] + class MessageSection(dns.enum.IntEnum): """Message sections""" QUESTION = 0 @@ -403,7 +404,8 @@ class Message: method. *origin*, a ``dns.name.Name`` or ``None``, the origin to be appended - to any relative names. + to any relative names. If ``None``, and the message has an origin + attribute that is not ``None``, then it will be used. *max_size*, an ``int``, the maximum size of the wire format output; default is 0, which means "the message's request @@ -420,6 +422,8 @@ class Message: Returns a ``bytes``. """ + if origin is None and self.origin is not None: + origin = self.origin if max_size == 0: if self.request_payload != 0: max_size = self.request_payload @@ -601,6 +605,22 @@ class Message: self.flags &= 0x87FF self.flags |= dns.opcode.to_flags(opcode) + def _get_one_rr_per_rrset(self, value): + # What the caller picked is fine. + return value + + def _validate_rrset(self, section, rrset): + if rrset.rdclass == dns.rdataclass.ANY or \ + rrset.rdclass == dns.rdataclass.NONE: + raise dns.exception.FormError + + def _finish_section(self, section): + pass + + +class QueryMessage(Message): + pass + class _WireReader: @@ -611,22 +631,20 @@ class _WireReader: current: When building a message object from wire format, this variable contains the offset from the beginning of wire of the next octet to be read. - updating: Is the message a dynamic update? one_rr_per_rrset: Put each RR into its own RRset? ignore_trailing: Ignore trailing junk at end of request? zone_rdclass: The class of the zone in messages which are DNS dynamic updates. """ - def __init__(self, wire, message, question_only=False, - one_rr_per_rrset=False, ignore_trailing=False): + def __init__(self, wire, message, question_only, one_rr_per_rrset, + ignore_trailing): self.wire = dns.wiredata.maybe_wrap(wire) self.message = message self.current = 0 - self.updating = False self.zone_rdclass = dns.rdataclass.IN self.question_only = question_only - self.one_rr_per_rrset = one_rr_per_rrset + self.one_rr_per_rrset = message._get_one_rr_per_rrset(one_rr_per_rrset) self.ignore_trailing = ignore_trailing def _get_question(self, qcount): @@ -634,9 +652,6 @@ class _WireReader: the question section. """ - if self.updating and qcount > 1: - raise dns.exception.FormError - for i in range(qcount): (qname, used) = dns.name.from_wire(self.wire, self.current) if self.message.origin is not None: @@ -646,13 +661,13 @@ class _WireReader: struct.unpack('!HH', self.wire[self.current:self.current + 4]) self.current += 4 - self.message.find_rrset(self.message.question, qname, - rdclass, rdtype, create=True, - force_unique=True) - if self.updating: - self.zone_rdclass = rdclass + rrset = self.message.find_rrset(self.message.question, qname, + rdclass, rdtype, create=True, + force_unique=True) + self.message._validate_rrset(MessageSection.QUESTION, rrset) + self.message._finish_section(MessageSection.QUESTION) - def _get_section(self, section, count): + def _get_section(self, section_number, count): """Read the next I{count} records from the wire data and add them to the specified section. @@ -660,10 +675,8 @@ class _WireReader: count: the number of records to read """ - if self.updating or self.one_rr_per_rrset: - force_unique = True - else: - force_unique = False + section = self.message.sections[section_number] + force_unique = self.one_rr_per_rrset seen_opt = False for i in range(count): rr_start = self.current @@ -725,8 +738,7 @@ class _WireReader: else: if ttl > 0x7fffffff: ttl = 0 - if self.updating and \ - rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE): + if rdclass in (dns.rdataclass.ANY, dns.rdataclass.NONE): deleting = rdclass rdclass = self.zone_rdclass else: @@ -748,7 +760,14 @@ class _WireReader: deleting, True, force_unique) if rd is not None: rrset.add(rd, ttl) + # Note this validates the rrset every RR, but this is + # simpler than adding logic to remember a "last rrset" + # and validate it when the current rrset changes (or when + # we get to the end of the section. If it becomes a + # performance problem, we can do that. + self.message._validate_rrset(section_number, rrset) self.current += rdlen + self.message._finish_section(section_number) def read(self): """Read a wire format DNS message and build a dns.message.Message @@ -761,14 +780,12 @@ class _WireReader: aucount, adcount) = struct.unpack('!HHHHHH', self.wire[:12]) self.current = 12 self.message.original_id = self.message.id - if dns.opcode.is_update(self.message.flags): - self.updating = True self._get_question(qcount) if self.question_only: return - self._get_section(self.message.answer, ancount) - self._get_section(self.message.authority, aucount) - self._get_section(self.message.additional, adcount) + self._get_section(MessageSection.ANSWER, ancount) + self._get_section(MessageSection.AUTHORITY, aucount) + self._get_section(MessageSection.ADDITIONAL, adcount) if not self.ignore_trailing and self.current != l: raise TrailingJunk if self.message.multi and self.message.tsig_ctx and \ @@ -776,6 +793,31 @@ class _WireReader: self.message.tsig_ctx.update(self.wire) +def _maybe_import_update(): + # We avoid circular imports by doing this here. We do it in another + # function as doing it in _message_factory_from_opcode() makes "dns" + # a local symbol, and the first line fails :) + import dns.update + + +def _message_factory_from_opcode(opcode): + if opcode == dns.opcode.QUERY: + return QueryMessage + elif opcode == dns.opcode.UPDATE: + _maybe_import_update() + return dns.update.UpdateMessage + else: + return Message + + +def _message_factory_from_wire(wire): + if len(wire) < 12: + raise ShortHeader + (flags,) = struct.unpack('!H', wire[2:4]) + opcode = dns.opcode.from_flags(flags) + return _message_factory_from_opcode(opcode) + + def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, tsig_ctx=None, multi=False, first=True, question_only=False, one_rr_per_rrset=False, @@ -835,7 +877,7 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, Returns a ``dns.message.Message``. """ - m = Message(id=0) + m = _message_factory_from_wire(wire)(id=0) m.keyring = keyring m.request_mac = request_mac m.xfr = xfr @@ -867,20 +909,25 @@ class _TextReader: tok: the tokenizer. message: The message object being built. - updating: Is the message a dynamic update? zone_rdclass: The class of the zone in messages which are DNS dynamic updates. last_name: The most recently read name when building a message object. one_rr_per_rrset: Put each RR into its own RRset? """ - def __init__(self, text, message, idna_codec, one_rr_per_rrset=False): - self.message = message + def __init__(self, text, idna_codec, one_rr_per_rrset=False): + self.message = None self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec) self.last_name = None self.zone_rdclass = dns.rdataclass.IN - self.updating = False self.one_rr_per_rrset = one_rr_per_rrset + self.id = None + self.edns = -1 + self.ednsflags = 0 + self.payload = None + self.rcode = None + self.opcode = dns.opcode.QUERY + self.flags = 0 def _header_line(self, section): """Process one line from the text format header section.""" @@ -888,49 +935,46 @@ class _TextReader: token = self.tok.get() what = token.value if what == 'id': - self.message.id = self.tok.get_int() + self.id = self.tok.get_int() elif what == 'flags': while True: token = self.tok.get() if not token.is_identifier(): self.tok.unget(token) break - self.message.flags = self.message.flags | \ - dns.flags.from_text(token.value) - if dns.opcode.is_update(self.message.flags): - self.updating = True + self.flags = self.flags | dns.flags.from_text(token.value) elif what == 'edns': - self.message.edns = self.tok.get_int() - self.message.ednsflags = self.message.ednsflags | \ - (self.message.edns << 16) + self.edns = self.tok.get_int() + self.ednsflags = self.ednsflags | (self.edns << 16) elif what == 'eflags': - if self.message.edns < 0: - self.message.edns = 0 + if self.edns < 0: + self.edns = 0 while True: token = self.tok.get() if not token.is_identifier(): self.tok.unget(token) break - self.message.ednsflags = self.message.ednsflags | \ + self.ednsflags = self.ednsflags | \ dns.flags.edns_from_text(token.value) elif what == 'payload': - self.message.payload = self.tok.get_int() - if self.message.edns < 0: - self.message.edns = 0 + self.payload = self.tok.get_int() + if self.edns < 0: + self.edns = 0 elif what == 'opcode': text = self.tok.get_string() - self.message.flags = self.message.flags | \ - dns.opcode.to_flags(dns.opcode.from_text(text)) + self.opcode = dns.opcode.from_text(text) + self.flags = self.flags | dns.opcode.to_flags(self.opcode) elif what == 'rcode': text = self.tok.get_string() - self.message.set_rcode(dns.rcode.from_text(text)) + self.rcode = dns.rcode.from_text(text) else: raise UnknownHeaderField self.tok.get_eol() - def _question_line(self, section): + def _question_line(self, section_number): """Process one line from the text format question section.""" + section = self.message.sections[section_number] token = self.tok.get(want_leading=True) if not token.is_whitespace(): self.last_name = self.tok.as_name(token, None) @@ -950,18 +994,19 @@ class _TextReader: rdclass = dns.rdataclass.IN # Type rdtype = dns.rdatatype.from_text(token.value) - self.message.find_rrset(self.message.question, name, - rdclass, rdtype, create=True, - force_unique=True) - if self.updating: - self.zone_rdclass = rdclass + rrset = self.message.find_rrset(self.message.question, name, + rdclass, rdtype, create=True, + force_unique=True) + self.message._validate_rrset(section_number, rrset) + self.zone_rdclass = rdclass self.tok.get_eol() - def _rr_line(self, section): + def _rr_line(self, section_number): """Process one line from the text format answer, authority, or additional data sections. """ + section = self.message.sections[section_number] deleting = None # Name token = self.tok.get(want_leading=True) @@ -1004,19 +1049,36 @@ class _TextReader: else: rd = None covers = dns.rdatatype.NONE - force_unique = self.updating or self.one_rr_per_rrset rrset = self.message.find_rrset(section, name, rdclass, rdtype, covers, - deleting, True, force_unique) + deleting, True, self.one_rr_per_rrset) if rd is not None: rrset.add(rd, ttl) + self.message._validate_rrset(section_number, rrset) + + def _maybe_instantiate_message(self): + if self.message is None: + # Time to instantiate the message! + factory = _message_factory_from_opcode(self.opcode) + self.message = factory(id=self.id) + self.message.flags = self.flags + if self.edns >= 0: + self.message.edns = self.edns + if self.ednsflags: + self.message.ednsflags = self.ednsflags + if self.payload: + self.message.payload = self.payload + if self.rcode: + self.message.set_rcode(self.rcode) + self.one_rr_per_rrset = \ + self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) def read(self): """Read a text format DNS message and build a dns.message.Message object.""" line_method = self._header_line - section = None + section_number = None while 1: token = self.tok.get(True, True) if token.is_eol_or_eof(): @@ -1025,22 +1087,24 @@ class _TextReader: u = token.value.upper() if u == 'HEADER': line_method = self._header_line - elif u == 'QUESTION' or u == 'ZONE': - line_method = self._question_line - section = self.message.question - elif u == 'ANSWER' or u == 'PREREQ': - line_method = self._rr_line - section = self.message.answer - elif u == 'AUTHORITY' or u == 'UPDATE': - line_method = self._rr_line - section = self.message.authority - elif u == 'ADDITIONAL': - line_method = self._rr_line - section = self.message.additional + elif u in {'QUESTION', 'ANSWER', 'AUTHORITY', 'ADDITIONAL', + 'ZONE', 'PREREQ', 'UPDATE'}: + # It's ugly, but we have to do the check above because + # if the token is JUST a comment, we want to ignore it, + # and not prehaps prematurely instantiate the message! + self._maybe_instantiate_message() + section_number = self.message._section_enum.make(u) + if section_number == 0: + line_method = self._question_line + else: + line_method = self._rr_line self.tok.get_eol() continue self.tok.unget(token) - line_method(section) + line_method(section_number) + for i in range(4): + self.message._finish_section(i) + return self.message def from_text(text, idna_codec=None, one_rr_per_rrset=False): @@ -1070,12 +1134,8 @@ def from_text(text, idna_codec=None, one_rr_per_rrset=False): # since it's an implementation detail. The official file # interface is from_file(). - m = Message() - - reader = _TextReader(text, m, idna_codec, one_rr_per_rrset) - reader.read() - - return m + reader = _TextReader(text, idna_codec, one_rr_per_rrset) + return reader.read() def from_file(f, idna_codec=None, one_rr_per_rrset=False): diff --git a/dns/update.py b/dns/update.py index 2aa8ca56..d57d0cd4 100644 --- a/dns/update.py +++ b/dns/update.py @@ -41,19 +41,22 @@ class UpdateSection(dns.enum.IntEnum): globals().update(UpdateSection.__members__) -class Update(dns.message.Message): +class UpdateMessage(dns.message.Message): _section_enum = UpdateSection - def __init__(self, zone, rdclass=dns.rdataclass.IN, keyring=None, - keyname=None, keyalgorithm=dns.tsig.default_algorithm): + def __init__(self, zone=None, rdclass=dns.rdataclass.IN, keyring=None, + keyname=None, keyalgorithm=dns.tsig.default_algorithm, + id=None): """Initialize a new DNS Update object. See the documentation of the Message class for a complete description of the keyring dictionary. - *zone*, a ``dns.name.Name`` or ``str``, the zone which is being - updated. + *zone*, a ``dns.name.Name``, ``str``, or ``None``, the zone + which is being updated. ``None`` should only be used by dnspython's + message constructors, as a zone is required for the convenience + methods like ``add()``, ``replace()``, etc. *rdclass*, an ``int`` or ``str``, the class of the zone. @@ -68,16 +71,18 @@ class Update(dns.message.Message): to use; defaults to ``None``. The key must be defined in the keyring. *keyalgorithm*, a ``dns.name.Name``, the TSIG algorithm to use. + """ - super().__init__() + super().__init__(id=id) self.flags |= dns.opcode.to_flags(dns.opcode.UPDATE) if isinstance(zone, str): zone = dns.name.from_text(zone) self.origin = zone rdclass = dns.rdataclass.RdataClass.make(rdclass) self.zone_rdclass = rdclass - self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA, - create=True, force_unique=True) + if self.origin: + self.find_rrset(self.zone, self.origin, rdclass, dns.rdatatype.SOA, + create=True, force_unique=True) if keyring is not None: self.use_tsig(keyring, keyname, algorithm=keyalgorithm) @@ -288,23 +293,21 @@ class Update(dns.message.Message): dns.rdatatype.NONE, None, True, True) - def to_wire(self, origin=None, max_size=65535): - """Return a string containing the update in DNS compressed wire - format. - - *origin*, a ``dns.name.Name`` or ``None``, the origin to be - appended to any relative names. If *origin* is ``None``, then - the origin of the ``dns.update.Update`` message object is used - (i.e. the *zone* parameter passed when the Update object was - created). + def _get_one_rr_per_rrset(self, value): + # Updates are always one_rr_per_rrset + return True - *max_size*, an ``int``, the maximum size of the wire format - output; default is 0, which means "the message's request - payload, if nonzero, or 65535". + def _validate_rrset(self, section, rrset): + if section == UpdateSection.ZONE: + if rrset.rdtype != dns.rdatatype.SOA: + raise dns.exception.FormError - Returns a ``bytes``. - """ + def _finish_section(self, section): + if section == UpdateSection.ZONE and len(self.zone) != 1: + raise dns.exception.FormError + self.zone_rdclass = self.zone[0].rdclass + # We do NOT want to set origin here, as that would cause + # from_wire() relativization. - if origin is None: - origin = self.origin - return super().to_wire(origin, max_size) +# backwards compatibility +Update = UpdateMessage