]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Minor _WireReader refactoring.
authorBrian Wellington <bwelling@xbill.org>
Fri, 26 Jun 2020 18:30:29 +0000 (11:30 -0700)
committerBrian Wellington <bwelling@xbill.org>
Fri, 26 Jun 2020 18:30:29 +0000 (11:30 -0700)
Instead of parsing the header to determine which Message subclass to
create and passing that message to _WireReader, make _WireReader create
the Message subclass itself.

dns/message.py

index a08b0517aefa2f8a9f999c8106d43a60632806d8..00166caa5b4fb774938ef66ccd5b8cf9b0c999c8 100644 (file)
@@ -623,6 +623,23 @@ class QueryMessage(Message):
     pass
 
 
+def _maybe_import_update():
+    # We avoid circular imports by doing this here.  We do it in another
+    # function as doing it in _message_factory_from_opcode() makes "dns"
+    # a local symbol, and the first line fails :)
+    import dns.update  # noqa: F401
+
+
+def _message_factory_from_opcode(opcode):
+    if opcode == dns.opcode.QUERY:
+        return QueryMessage
+    elif opcode == dns.opcode.UPDATE:
+        _maybe_import_update()
+        return dns.update.UpdateMessage
+    else:
+        return Message
+
+
 class _WireReader:
 
     """Wire format reader.
@@ -632,21 +649,24 @@ class _WireReader:
     current: When building a message object from wire format, this
     variable contains the offset from the beginning of wire of the next octet
     to be read.
+    initialize_message: Callback to set message parsing options
+    question_only: Are we only reading the question?
     one_rr_per_rrset: Put each RR into its own RRset?
     ignore_trailing: Ignore trailing junk at end of request?
     zone_rdclass: The class of the zone in messages which are
     DNS dynamic updates.
     """
 
-    def __init__(self, wire, message, question_only, one_rr_per_rrset,
-                 ignore_trailing):
+    def __init__(self, wire, initialize_message, question_only=False,
+                 one_rr_per_rrset=False, ignore_trailing=False):
         self.wire = dns.wiredata.maybe_wrap(wire)
-        self.message = message
+        self.message = None
         self.current = 0
-        self.zone_rdclass = dns.rdataclass.IN
+        self.initialize_message = initialize_message
         self.question_only = question_only
-        self.one_rr_per_rrset = message._get_one_rr_per_rrset(one_rr_per_rrset)
+        self.one_rr_per_rrset = one_rr_per_rrset
         self.ignore_trailing = ignore_trailing
+        self.zone_rdclass = dns.rdataclass.IN
 
     def _get_question(self, qcount):
         """Read the next *qcount* records from the wire data and add them to
@@ -774,10 +794,15 @@ class _WireReader:
         l = len(self.wire)
         if l < 12:
             raise ShortHeader
-        (self.message.id, self.message.flags, qcount, ancount,
-         aucount, adcount) = struct.unpack('!HHHHHH', self.wire[:12])
+        (id, flags, qcount, ancount, aucount, adcount) = \
+            struct.unpack('!HHHHHH', self.wire[:12])
         self.current = 12
-        self.message.original_id = self.message.id
+        factory = _message_factory_from_opcode(dns.opcode.from_flags(flags))
+        self.message = factory(id=id)
+        self.message.flags = flags
+        self.initialize_message(self.message)
+        self.one_rr_per_rrset = \
+            self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
         self._get_question(qcount)
         if self.question_only:
             return
@@ -789,31 +814,7 @@ class _WireReader:
         if self.message.multi and self.message.tsig_ctx and \
                 not self.message.had_tsig:
             self.message.tsig_ctx.update(self.wire)
-
-
-def _maybe_import_update():
-    # We avoid circular imports by doing this here.  We do it in another
-    # function as doing it in _message_factory_from_opcode() makes "dns"
-    # a local symbol, and the first line fails :)
-    import dns.update  # noqa: F401
-
-
-def _message_factory_from_opcode(opcode):
-    if opcode == dns.opcode.QUERY:
-        return QueryMessage
-    elif opcode == dns.opcode.UPDATE:
-        _maybe_import_update()
-        return dns.update.UpdateMessage
-    else:
-        return Message
-
-
-def _message_factory_from_wire(wire):
-    if len(wire) < 12:
-        raise ShortHeader
-    (flags,) = struct.unpack('!H', wire[2:4])
-    opcode = dns.opcode.from_flags(flags)
-    return _message_factory_from_opcode(opcode)
+        return self.message
 
 
 def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
@@ -875,22 +876,23 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     Returns a ``dns.message.Message``.
     """
 
-    m = _message_factory_from_wire(wire)(id=0)
-    m.keyring = keyring
-    m.request_mac = request_mac
-    m.xfr = xfr
-    m.origin = origin
-    m.tsig_ctx = tsig_ctx
-    m.multi = multi
-    m.first = first
-
-    reader = _WireReader(wire, m, question_only, one_rr_per_rrset,
-                         ignore_trailing)
+    def initialize_message(message):
+        message.keyring = keyring
+        message.request_mac = request_mac
+        message.xfr = xfr
+        message.origin = origin
+        message.tsig_ctx = tsig_ctx
+        message.multi = multi
+        message.first = first
+
+    reader = _WireReader(wire, initialize_message, question_only,
+                         one_rr_per_rrset, ignore_trailing)
     try:
-        reader.read()
+        m = reader.read()
     except dns.exception.FormError:
-        if m.flags & dns.flags.TC and raise_on_truncation:
-            raise Truncated(message=m)
+        if reader.message and (reader.message.flags & dns.flags.TC) and \
+           raise_on_truncation:
+            raise Truncated(message=reader.message)
         else:
             raise
     # Reading a truncated message might not have any errors, so we