From: Brian Wellington Date: Tue, 23 Jun 2020 23:46:48 +0000 (-0700) Subject: Add more functionality to nanonameserver. X-Git-Tag: v2.0.0rc2~67^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=167c89548218390cd1c28610bec86524a41617bc;p=thirdparty%2Fdnspython.git Add more functionality to nanonameserver. - When no port is specified, pick the same port for UDP and TCP, so that TCP fallback can be tested. - Change handlers to get a single Request object instead of individual parameters. The Request object contains the message, peer, and connection_type previously passed, and also adds the local address and wire format message. Additionally, it provides convenient properties for accessing the question. --- diff --git a/tests/nanonameserver.py b/tests/nanonameserver.py index 4293d8a0..2498d901 100644 --- a/tests/nanonameserver.py +++ b/tests/nanonameserver.py @@ -2,6 +2,7 @@ import contextlib import enum +import errno import functools import socket import struct @@ -29,6 +30,30 @@ class ConnectionType(enum.IntEnum): UDP = 1 TCP = 2 +class Request: + def __init__(self, message, wire, peer, local, connection_type): + self.message = message + self.wire = wire + self.peer = peer + self.local = local + self.connection_type = connection_type + + @property + def question(self): + return self.message.question[0] + + @property + def qname(self): + return self.question.name + + @property + def qclass(self): + return self.question.rdclass + + @property + def qtype(self): + return self.question.rdtype + class Server(threading.Thread): """The nanoserver is a nameserver skeleton suitable for faking a DNS @@ -68,11 +93,7 @@ class Server(threading.Thread): self.tcp = None self.tcp_address = None - def __enter__(self): - (self.left, self.right) = socket.socketpair() - # We're making the UDP socket now so it can be sent to by the - # caller immediately (i.e. no race with the listener starting - # in the thread). + def _open_sockets(self): if self.enable_udp: self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) self.udp.bind((self.address, self.port)) @@ -80,9 +101,42 @@ class Server(threading.Thread): if self.enable_tcp: self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.tcp.bind((self.address, self.port)) + port = self.port + if port is 0 and self.enable_udp: + port = self.udp_address[1] + port = 12347 + self.tcp.bind((self.address, port)) self.tcp.listen() self.tcp_address = self.tcp.getsockname() + + def __enter__(self): + (self.left, self.right) = socket.socketpair() + # We're making the sockets now so they can be sent to by the + # caller immediately (i.e. no race with the listener starting + # in the thread). + open_udp_sockets = [] + while True: + if self.enable_udp: + self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) + self.udp.bind((self.address, self.port)) + self.udp_address = self.udp.getsockname() + if self.enable_tcp: + self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if self.port is 0 and self.enable_udp: + try: + self.tcp.bind((self.address, self.udp_address[1])) + except OSError as e: + if e.errno == errno.EADDRINUSE and \ + len(open_udp_sockets) < 100: + open_udp_sockets.append(self.udp) + continue + raise + else: + self.tcp.bind((self.address, self.port)) + self.tcp.listen() + self.tcp_address = self.tcp.getsockname() + break if self.use_thread: self.start() return self @@ -141,7 +195,7 @@ class Server(threading.Thread): else: return [thing] - def handle_wire(self, wire, peer, connection_type): + def handle_wire(self, wire, peer, local, connection_type): # # This is the common code to parse wire format, call handle() on # the message, and then generate response wire format (if handle() @@ -180,8 +234,8 @@ class Server(threading.Thread): # items might have been appended to above, so skip # handle() if we already have a response. if not items: - items = self.maybe_listify(self.handle(q, peer, - connection_type)) + request = Request(q, wire, peer, local, connection_type) + items = self.maybe_listify(self.handle(request)) except Exception: # Exceptions from handle get a SERVFAIL response. r = dns.message.make_response(q) @@ -201,10 +255,11 @@ class Server(threading.Thread): async def serve_udp(self): with trio.socket.from_stdlib_socket(self.udp) as sock: self.udp = None # we own cleanup + local = self.udp_address while True: try: (wire, peer) = await sock.recvfrom(65535) - for wire in self.handle_wire(wire, peer, + for wire in self.handle_wire(wire, peer, local, ConnectionType.UDP): await sock.sendto(wire, peer) except Exception: @@ -213,11 +268,13 @@ class Server(threading.Thread): async def serve_tcp(self, stream): try: peer = stream.socket.getpeername() + local = stream.socket.getsockname() while True: ldata = await read_exactly(stream, 2) (l,) = struct.unpack("!H", ldata) wire = await read_exactly(stream, l) - for wire in self.handle_wire(wire, peer, ConnectionType.TCP): + for wire in self.handle_wire(wire, peer, local, + ConnectionType.TCP): l = len(wire) stream_message = struct.pack("!H", l) + wire await stream.send_all(stream_message) diff --git a/tests/test_query.py b/tests/test_query.py index b967a6e2..895dc0d7 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -248,16 +248,16 @@ ns2 A 10.0.0.1 class AXFRNanoNameserver(Server): - def handle(self, message, peer, connection_type): + def handle(self, request): self.zone = dns.zone.from_text(axfr_zone, origin='example') self.origin = self.zone.origin items = [] soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA) - response = dns.message.make_response(message) + response = dns.message.make_response(request.message) response.flags |= dns.flags.AA response.answer.append(soa) items.append(response) - response = dns.message.make_response(message) + response = dns.message.make_response(request.message) response.question = [] response.flags |= dns.flags.AA for (name, rdataset) in self.zone.iterate_rdatasets(): @@ -269,7 +269,7 @@ class AXFRNanoNameserver(Server): rrset.update(rdataset) response.answer.append(rrset) items.append(response) - response = dns.message.make_response(message) + response = dns.message.make_response(request.message) response.question = [] response.flags |= dns.flags.AA response.answer.append(soa) @@ -329,10 +329,10 @@ class IXFRNanoNameserver(Server): super().__init__() self.response_text = response_text - def handle(self, message, peer, connection_type): + def handle(self, request): try: r = dns.message.from_text(self.response_text, one_rr_per_rrset=True) - r.id = message.id + r.id = request.message.id return r except Exception: pass @@ -450,14 +450,14 @@ class XfrTests(unittest.TestCase): class TSIGNanoNameserver(Server): - def handle(self, message, peer, connection_type): - response = dns.message.make_response(message) + def handle(self, request): + response = dns.message.make_response(request.message) response.set_rcode(dns.rcode.REFUSED) response.flags |= dns.flags.RA try: - if message.question[0].rdtype == dns.rdatatype.A and \ - message.question[0].rdclass == dns.rdataclass.IN: - rrs = dns.rrset.from_text(message.question[0].name, 300, + if request.qtype == dns.rdatatype.A and \ + request.qclass == dns.rdataclass.IN: + rrs = dns.rrset.from_text(request.qname, 300, 'IN', 'A', '1.2.3.4') response.answer.append(rrs) response.set_rcode(dns.rcode.NOERROR) diff --git a/tests/test_resolver.py b/tests/test_resolver.py index fe7573eb..25c5c57b 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -614,19 +614,18 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase): class NaptrNanoNameserver(Server): - def handle(self, message, peer, connection_type): - response = dns.message.make_response(message) + def handle(self, request): + response = dns.message.make_response(request.message) response.set_rcode(dns.rcode.REFUSED) response.flags |= dns.flags.RA try: zero_subdomain = dns.e164.from_e164('0') - if message.question[0].name.is_subdomain(zero_subdomain): + if request.qname.is_subdomain(zero_subdomain): response.set_rcode(dns.rcode.NXDOMAIN) response.flags |= dns.flags.AA - elif message.question[0].rdtype == dns.rdatatype.NAPTR and \ - message.question[0].rdclass == dns.rdataclass.IN: - rrs = dns.rrset.from_text(message.question[0].name, 300, - 'IN', 'NAPTR', + elif request.qtype == dns.rdatatype.NAPTR and \ + request.qclass == dns.rdataclass.IN: + rrs = dns.rrset.from_text(request.qname, 300, 'IN', 'NAPTR', '0 0 "" "" "" .') response.answer.append(rrs) response.set_rcode(dns.rcode.NOERROR)