]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
set tcp_address correctly; allow handle() to return a list
authorBob Halley <halley@dnspython.org>
Sun, 14 Jun 2020 19:43:56 +0000 (12:43 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 14 Jun 2020 19:43:56 +0000 (12:43 -0700)
tests/nanonameserver.py

index b1bb6f2ed9d4acaec51aa46ff9da0189daf809bf..554c78b7918b5ba2982bc5bd0dddd4afed550aa4 100644 (file)
@@ -51,13 +51,14 @@ class Server(threading.Thread):
     """
 
     def __init__(self, address='127.0.0.1', port=0, enable_udp=True,
-                 enable_tcp=True, use_thread=True):
+                 enable_tcp=True, use_thread=True, origin=None):
         super().__init__()
         self.address = address
         self.port = port
         self.enable_udp = enable_udp
         self.enable_tcp = enable_tcp
         self.use_thread = use_thread
+        self.origin = origin
         self.left = None
         self.right = None
         self.udp = None
@@ -79,7 +80,7 @@ class Server(threading.Thread):
             self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             self.tcp.bind((self.address, self.port))
             self.tcp.listen()
-            self.tcp_address = self.udp.getsockname()
+            self.tcp_address = self.tcp.getsockname()
         if self.use_thread:
             self.start()
         return self
@@ -118,10 +119,12 @@ class Server(threading.Thread):
         # how the server behaves.
         #
         # The return value is either a dns.message.Message, a bytes,
-        # or None.  We allow a bytes to be returned for cases where
-        # handle wants to return an invalid DNS message for testing
-        # purposes.  We allow None to be returned to indicate there is
-        # no response.
+        # None, or a list of one of those.  We allow a bytes to be
+        # returned for cases where handle wants to return an invalid
+        # DNS message for testing purposes.  We allow None to be
+        # returned to indicate there is no response.  If a list is
+        # returned, then the output code will run for each returned
+        # item.
         #
         try:
             r = dns.message.make_response(message)
@@ -130,54 +133,65 @@ class Server(threading.Thread):
         except Exception:
             return None
 
+    def maybe_listify(self, thing):
+        if isinstance(thing, list):
+            return thing
+        else:
+            return [thing]
+
+    def maybe_render(self, thing):
+        if isinstance(thing, dns.message.Message):
+            return thing.to_wire(self.origin)
+        else:
+            return thing
+
     def handle_wire(self, wire, peer, connection_type):
         #
         # This is the common code to parse wire format, call handle() on
-        # the message, and then generate resposne wire format (if handle()
+        # the message, and then generate response wire format (if handle()
         # didn't do it).
         #
         # It also handles any exceptions from handle()
         #
-        # Returns a wire format message to send, or None indicating there
-        # is nothing to do.
+        # Returns a (possibly empty) list of wire format message to send.
         #
         # XXXRTH It might be nice to have a "debug mode" in the server
         # where we'd print something in all the places we're eating
         # exceptions.  That way bugs in handle() would be easier to
         # find.
         #
+        items = []
         r = None
         try:
             q = dns.message.from_wire(wire)
         except dns.message.ShortHeader:
             # There is no hope of answering this one!
-            return None
+            return []
         except Exception:
             # Try to make a FORMERR using just the question section.
             try:
                 q = dns.message.from_wire(wire, question_only=True)
                 r = dns.message.make_response(q)
                 r.set_rcode(dns.rcode.FORMERR)
+                items.append(r)
             except Exception:
                 # We could try to make a response from only the header
                 # if dnspython had a header_only option to
                 # from_wire(), or if we truncated wire outselves, but
                 # for now we just drop.
-                return None
+                return []
         try:
-            # r might have been set above, so skip handle() if we
-            # already have a response.
-            if r is None:
-                r = self.handle(q, peer, connection_type)
+            # items might have been appended to above, so skip
+            # handle() if we already have a response.
+            if not items:
+                items = self.maybe_listify(self.handle(q, peer,
+                                                       connection_type))
         except Exception:
             # Exceptions from handle get a SERVFAIL response.
             r = dns.message.make_response(q)
             r.set_rcode(dns.rcode.SERVFAIL)
-        if isinstance(r, dns.message.Message):
-            wire = r.to_wire()
-        else:
-            wire = r
-        return wire
+            items = [r]
+        return [self.maybe_render(x) for x in items]
 
     async def serve_udp(self):
         with trio.socket.from_stdlib_socket(self.udp) as sock:
@@ -185,8 +199,8 @@ class Server(threading.Thread):
             while True:
                 try:
                     (wire, peer) = await sock.recvfrom(65535)
-                    wire = self.handle_wire(wire, peer, ConnectionType.UDP)
-                    if wire is not None:
+                    for wire in self.handle_wire(wire, peer,
+                                                 ConnectionType.UDP):
                         await sock.sendto(wire, peer)
                 except Exception:
                     pass
@@ -198,8 +212,7 @@ class Server(threading.Thread):
                 ldata = await read_exactly(stream, 2)
                 (l,) = struct.unpack("!H", ldata)
                 wire = await read_exactly(stream, l)
-                wire = self.handle_wire(wire, peer, ConnectionType.TCP)
-                if wire is not None:
+                for wire in self.handle_wire(wire, peer, ConnectionType.TCP):
                     l = len(wire)
                     stream_message = struct.pack("!H", l) + wire
                     await stream.send_all(stream_message)