]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Refactor common code between UDP and TCP; add basic exception handling
authorBob Halley <halley@dnspython.org>
Tue, 2 Jun 2020 14:05:46 +0000 (07:05 -0700)
committerBob Halley <halley@dnspython.org>
Tue, 2 Jun 2020 14:05:46 +0000 (07:05 -0700)
tests/nanonameserver.py

index aaec009fd68c7781175ab382caec10638709a0d7..1b4a43497fbf7924d8f9c9ac36f021d39082f40c 100644 (file)
@@ -100,13 +100,58 @@ class Server(threading.Thread):
         # Handle message 'message'.  Override this method to change
         # how the server behaves.
         #
-        # The return value is either a dns.message.Message or a bytes.
-        # We allow a bytes to be returned for cases where handle wants
-        # to return an invalid DNS message for testing purposes.
+        # The return value is either a dns.message.Message, a bytes,
+        # or None.  We allow a bytes to be returned for cases where
+        # handle wants to return an invalid DNS message for testing
+        # purposes.  We allow None to be returned to indicate there is
+        # no response.
         #
-        r = dns.message.make_response(message)
-        r.set_rcode(dns.rcode.REFUSED)
-        return r
+        try:
+            r = dns.message.make_response(message)
+            r.set_rcode(dns.rcode.REFUSED)
+            return r
+        except Exception:
+            return None
+
+    def handle_wire(self, wire):
+        #
+        # This is the common code to parse wire format, call handle() on
+        # the message, and then generate resposne wire format (if handle()
+        # didn't do it).
+        #
+        # It also handles any exceptions from handle()
+        #
+        # Returns a wire format message to send, or None indicating there
+        # is nothing to do.
+        r = None
+        try:
+            q = dns.message.from_wire(wire)
+        except Exception:
+            # Try to make a FORMERR using just the question section.
+            try:
+                q = dns.message.from_wire(wire, question_only=True)
+                r = dns.message.make_response(q)
+                r.set_rcode(dns.rcode.FORMERR)
+            except Exception:
+                # We could try to make a response from only the header
+                # if dnspython had a header_only option to
+                # from_wire(), or if we truncated wire outselves, but
+                # for now we just drop.
+                return None
+        try:
+            # r might have been set above, so skip handle() if we
+            # already have a response.
+            if r is None:
+                r = self.handle(q)
+        except Exception:
+            # Exceptions from handle get a SERVFAIL response.
+            r = dns.message.make_response(q)
+            r.set_rcode(dns.rcode.SERVFAIL)
+        if isinstance(r, dns.message.Message):
+            wire = r.to_wire()
+        else:
+            wire = r
+        return wire
 
     async def serve_udp(self):
         with trio.socket.from_stdlib_socket(self.udp) as sock:
@@ -114,13 +159,9 @@ class Server(threading.Thread):
             while True:
                 try:
                     (wire, from_address) = await sock.recvfrom(65535)
-                    q = dns.message.from_wire(wire)
-                    r = self.handle(q)
-                    if isinstance(r, dns.message.Message):
-                        wire = r.to_wire()
-                    else:
-                        wire = r
-                    await sock.sendto(wire, from_address)
+                    wire = self.handle_wire(wire)
+                    if wire is not None:
+                        await sock.sendto(wire, from_address)
                 except Exception:
                     pass
 
@@ -130,15 +171,11 @@ class Server(threading.Thread):
                 ldata = await dns.trio.query.read_exactly(stream, 2)
                 (l,) = struct.unpack("!H", ldata)
                 wire = await dns.trio.query.read_exactly(stream, l)
-                q = dns.message.from_wire(wire)
-                r = self.handle(q)
-                if isinstance(r, dns.message.Message):
-                    wire = r.to_wire()
-                else:
-                    wire = r
-                l = len(wire)
-                stream_message = struct.pack("!H", l) + wire
-                await stream.send_all(stream_message)
+                wire = self.handle_wire(wire)
+                if wire is not None:
+                    l = len(wire)
+                    stream_message = struct.pack("!H", l) + wire
+                    await stream.send_all(stream_message)
         except Exception:
             pass