From: Bob Halley Date: Tue, 2 Jun 2020 14:22:42 +0000 (-0700) Subject: pass peer and connection type to nanoserver handle() X-Git-Tag: v2.0.0rc1~126 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=842dc3a0c1466126d547a05b32abd29150ecc1d2;p=thirdparty%2Fdnspython.git pass peer and connection type to nanoserver handle() --- diff --git a/tests/nanonameserver.py b/tests/nanonameserver.py index 1b4a4349..a14d9259 100644 --- a/tests/nanonameserver.py +++ b/tests/nanonameserver.py @@ -1,6 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license import contextlib +import enum import functools import socket import struct @@ -11,6 +12,9 @@ import dns.message import dns.rcode import dns.trio.query +class ConnectionType(enum.IntEnum): + UDP = 1 + TCP = 2 class Server(threading.Thread): @@ -95,7 +99,7 @@ class Server(threading.Thread): finally: raise EOFError - def handle(self, message): + def handle(self, message, peer, connection_type): # # Handle message 'message'. Override this method to change # how the server behaves. @@ -113,7 +117,7 @@ class Server(threading.Thread): except Exception: return None - def handle_wire(self, wire): + def handle_wire(self, wire, peer, connection_type): # # This is the common code to parse wire format, call handle() on # the message, and then generate resposne wire format (if handle() @@ -123,6 +127,12 @@ class Server(threading.Thread): # # Returns a wire format message to send, or None indicating there # is nothing to do. + # + # XXXRTH It might be nice to have a "debug mode" in the server + # where we'd print something in all the places we're eating + # exceptions. That way bugs in handle() would be easier to + # find. + # r = None try: q = dns.message.from_wire(wire) @@ -142,7 +152,7 @@ class Server(threading.Thread): # r might have been set above, so skip handle() if we # already have a response. if r is None: - r = self.handle(q) + r = self.handle(q, peer, connection_type) except Exception: # Exceptions from handle get a SERVFAIL response. r = dns.message.make_response(q) @@ -158,20 +168,21 @@ class Server(threading.Thread): self.udp = None # we own cleanup while True: try: - (wire, from_address) = await sock.recvfrom(65535) - wire = self.handle_wire(wire) + (wire, peer) = await sock.recvfrom(65535) + wire = self.handle_wire(wire, peer, ConnectionType.UDP) if wire is not None: - await sock.sendto(wire, from_address) + await sock.sendto(wire, peer) except Exception: pass async def serve_tcp(self, stream): try: + peer = stream.socket.getpeername() while True: ldata = await dns.trio.query.read_exactly(stream, 2) (l,) = struct.unpack("!H", ldata) wire = await dns.trio.query.read_exactly(stream, l) - wire = self.handle_wire(wire) + wire = self.handle_wire(wire, peer, ConnectionType.TCP) if wire is not None: l = len(wire) stream_message = struct.pack("!H", l) + wire diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 87aebaac..309a89d7 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -595,7 +595,7 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase): class NaptrNanoNameserver(Server): - def handle(self, message): + def handle(self, message, peer, connection_type): response = dns.message.make_response(message) response.set_rcode(dns.rcode.REFUSED) response.flags |= dns.flags.RA