From 444d448d5e136d45ae3884f0c91aa4ef40fd646f Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 14 Jun 2020 12:43:56 -0700 Subject: [PATCH] set tcp_address correctly; allow handle() to return a list --- tests/nanonameserver.py | 61 +++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/tests/nanonameserver.py b/tests/nanonameserver.py index b1bb6f2e..554c78b7 100644 --- a/tests/nanonameserver.py +++ b/tests/nanonameserver.py @@ -51,13 +51,14 @@ class Server(threading.Thread): """ def __init__(self, address='127.0.0.1', port=0, enable_udp=True, - enable_tcp=True, use_thread=True): + enable_tcp=True, use_thread=True, origin=None): super().__init__() self.address = address self.port = port self.enable_udp = enable_udp self.enable_tcp = enable_tcp self.use_thread = use_thread + self.origin = origin self.left = None self.right = None self.udp = None @@ -79,7 +80,7 @@ class Server(threading.Thread): self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.tcp.bind((self.address, self.port)) self.tcp.listen() - self.tcp_address = self.udp.getsockname() + self.tcp_address = self.tcp.getsockname() if self.use_thread: self.start() return self @@ -118,10 +119,12 @@ class Server(threading.Thread): # how the server behaves. # # The return value is either a dns.message.Message, a bytes, - # or None. We allow a bytes to be returned for cases where - # handle wants to return an invalid DNS message for testing - # purposes. We allow None to be returned to indicate there is - # no response. + # None, or a list of one of those. We allow a bytes to be + # returned for cases where handle wants to return an invalid + # DNS message for testing purposes. We allow None to be + # returned to indicate there is no response. If a list is + # returned, then the output code will run for each returned + # item. # try: r = dns.message.make_response(message) @@ -130,54 +133,65 @@ class Server(threading.Thread): except Exception: return None + def maybe_listify(self, thing): + if isinstance(thing, list): + return thing + else: + return [thing] + + def maybe_render(self, thing): + if isinstance(thing, dns.message.Message): + return thing.to_wire(self.origin) + else: + return thing + def handle_wire(self, wire, peer, connection_type): # # This is the common code to parse wire format, call handle() on - # the message, and then generate resposne wire format (if handle() + # the message, and then generate response wire format (if handle() # didn't do it). # # It also handles any exceptions from handle() # - # Returns a wire format message to send, or None indicating there - # is nothing to do. + # Returns a (possibly empty) list of wire format message to send. # # XXXRTH It might be nice to have a "debug mode" in the server # where we'd print something in all the places we're eating # exceptions. That way bugs in handle() would be easier to # find. # + items = [] r = None try: q = dns.message.from_wire(wire) except dns.message.ShortHeader: # There is no hope of answering this one! - return None + return [] except Exception: # Try to make a FORMERR using just the question section. try: q = dns.message.from_wire(wire, question_only=True) r = dns.message.make_response(q) r.set_rcode(dns.rcode.FORMERR) + items.append(r) except Exception: # We could try to make a response from only the header # if dnspython had a header_only option to # from_wire(), or if we truncated wire outselves, but # for now we just drop. - return None + return [] try: - # r might have been set above, so skip handle() if we - # already have a response. - if r is None: - r = self.handle(q, peer, connection_type) + # items might have been appended to above, so skip + # handle() if we already have a response. + if not items: + items = self.maybe_listify(self.handle(q, peer, + connection_type)) except Exception: # Exceptions from handle get a SERVFAIL response. r = dns.message.make_response(q) r.set_rcode(dns.rcode.SERVFAIL) - if isinstance(r, dns.message.Message): - wire = r.to_wire() - else: - wire = r - return wire + items = [r] + return [self.maybe_render(x) for x in items] async def serve_udp(self): with trio.socket.from_stdlib_socket(self.udp) as sock: @@ -185,8 +199,8 @@ class Server(threading.Thread): while True: try: (wire, peer) = await sock.recvfrom(65535) - wire = self.handle_wire(wire, peer, ConnectionType.UDP) - if wire is not None: + for wire in self.handle_wire(wire, peer, + ConnectionType.UDP): await sock.sendto(wire, peer) except Exception: pass @@ -198,8 +212,7 @@ class Server(threading.Thread): ldata = await read_exactly(stream, 2) (l,) = struct.unpack("!H", ldata) wire = await read_exactly(stream, l) - wire = self.handle_wire(wire, peer, ConnectionType.TCP) - if wire is not None: + for wire in self.handle_wire(wire, peer, ConnectionType.TCP): l = len(wire) stream_message = struct.pack("!H", l) + wire await stream.send_all(stream_message) -- 2.47.3