From 36b8ad33236a4ee118577b574f375f9d07b73be5 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 31 May 2020 09:54:00 -0700 Subject: [PATCH] add nanonamserver, a handy testing tool --- nanonameserver.py | 194 +++++++++++++++++++++++++++++++++++++++++ tests/test_resolver.py | 53 ++++++++++- 2 files changed, 243 insertions(+), 4 deletions(-) create mode 100644 nanonameserver.py diff --git a/nanonameserver.py b/nanonameserver.py new file mode 100644 index 00000000..aaec009f --- /dev/null +++ b/nanonameserver.py @@ -0,0 +1,194 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import contextlib +import functools +import socket +import struct +import threading +import trio + +import dns.message +import dns.rcode +import dns.trio.query + + +class Server(threading.Thread): + + """The nanoserver is a nameserver skeleton suitable for faking a DNS + server for various testing purposes. It executes with a trio run + loop in a dedicated thread, and is a context manager. Exiting the + context manager will ensure the server shuts down. + + If a port is not specified, random ports will be chosen. + + Applications should subclass the server and override the handle() + method to determine how the server responds to queries. The + default behavior is to refuse everything. + + If use_thread is set to False in the constructor, then the + server's main() method can be used directly in a trio nursery, + allowing the server's cancellation to be managed in the Trio way. + In this case, no thread creation ever happens even though Server + is a subclass of thread, because the start() method is never + called. + """ + + def __init__(self, address='127.0.0.1', port=0, enable_udp=True, + enable_tcp=True, use_thread=True): + super().__init__() + self.address = address + self.port = port + self.enable_udp = enable_udp + self.enable_tcp = enable_tcp + self.use_thread = use_thread + self.left = None + self.right = None + self.udp = None + self.udp_address = None + self.tcp = None + self.tcp_address = None + + def __enter__(self): + (self.left, self.right) = socket.socketpair() + # We're making the UDP socket now so it can be sent to by the + # caller immediately (i.e. no race with the listener starting + # in the thread). + if self.enable_udp: + self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) + self.udp.bind((self.address, self.port)) + self.udp_address = self.udp.getsockname() + if self.enable_tcp: + self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.tcp.bind((self.address, self.port)) + self.tcp.listen() + self.tcp_address = self.udp.getsockname() + if self.use_thread: + self.start() + return self + + def __exit__(self, ex_ty, ex_va, ex_tr): + if self.left: + self.left.close() + if self.use_thread and self.is_alive(): + self.join() + if self.right: + self.right.close() + if self.udp: + self.udp.close() + if self.tcp: + self.tcp.close() + + async def wait_for_input_or_eof(self): + # + # This trio task just waits for input on the right half of the + # socketpair (the left half is owned by the context manager + # returned by launch). As soon as something is read, or the + # socket returns EOF, EOFError is raised, causing a the + # nursery to cancel all other nursery tasks, in particular the + # listeners. + # + try: + with trio.socket.from_stdlib_socket(self.right) as sock: + self.right = None # we own cleanup + await sock.recv(1) + finally: + raise EOFError + + def handle(self, message): + # + # Handle message 'message'. Override this method to change + # how the server behaves. + # + # The return value is either a dns.message.Message or a bytes. + # We allow a bytes to be returned for cases where handle wants + # to return an invalid DNS message for testing purposes. + # + r = dns.message.make_response(message) + r.set_rcode(dns.rcode.REFUSED) + return r + + async def serve_udp(self): + with trio.socket.from_stdlib_socket(self.udp) as sock: + self.udp = None # we own cleanup + while True: + try: + (wire, from_address) = await sock.recvfrom(65535) + q = dns.message.from_wire(wire) + r = self.handle(q) + if isinstance(r, dns.message.Message): + wire = r.to_wire() + else: + wire = r + await sock.sendto(wire, from_address) + except Exception: + pass + + async def serve_tcp(self, stream): + try: + while True: + ldata = await dns.trio.query.read_exactly(stream, 2) + (l,) = struct.unpack("!H", ldata) + wire = await dns.trio.query.read_exactly(stream, l) + q = dns.message.from_wire(wire) + r = self.handle(q) + if isinstance(r, dns.message.Message): + wire = r.to_wire() + else: + wire = r + l = len(wire) + stream_message = struct.pack("!H", l) + wire + await stream.send_all(stream_message) + except Exception: + pass + + async def orchestrate_tcp(self): + with trio.socket.from_stdlib_socket(self.tcp) as sock: + self.tcp = None # we own cleanup + listener = trio.SocketListener(sock) + async with trio.open_nursery() as nursery: + serve = functools.partial(trio.serve_listeners, self.serve_tcp, + [listener], handler_nursery=nursery) + nursery.start_soon(serve) + + async def main(self): + try: + async with trio.open_nursery() as nursery: + if self.use_thread: + nursery.start_soon(self.wait_for_input_or_eof) + if self.enable_udp: + nursery.start_soon(self.serve_udp) + if self.enable_tcp: + nursery.start_soon(self.orchestrate_tcp) + except Exception: + pass + + def run(self): + if not self.use_thread: + raise RuntimeError('start() called on a use_thread=False Server') + trio.run(self.main) + +if __name__ == "__main__": + import sys + import time + + async def trio_main(): + try: + with Server(port=5354, use_thread=False) as server: + print(f'Trio mode: listening on UDP: {server.udp_address}, ' + + f'TCP: {server.tcp_address}') + async with trio.open_nursery() as nursery: + nursery.start_soon(server.main) + except Exception: + pass + + def threaded_main(): + with Server(port=5354) as server: + print(f'Thread Mode: listening on UDP: {server.udp_address}, ' + + f'TCP: {server.tcp_address}') + time.sleep(300) + + if len(sys.argv) > 1 and sys.argv[1] == 'trio': + trio.run(trio_main) + else: + threaded_main() diff --git a/tests/test_resolver.py b/tests/test_resolver.py index c5814da6..202c272a 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -22,6 +22,7 @@ import socket import time import unittest +import dns.e164 import dns.message import dns.name import dns.rdataclass @@ -36,6 +37,17 @@ try: except socket.gaierror: _network_available = False +# Some tests use a "nano nameserver" for testing. It requires trio +# and threading, so try to import it and if it doesn't work, skip +# those tests. +try: + from nanonameserver import Server + _nanonameserver_available = True +except ImportError: + _nanonameserver_available = False + class Server(object): + pass + resolv_conf = u""" /t/t # comment 1 @@ -580,7 +592,40 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase): with self.assertRaises(ValueError): resolver.nameservers = invalid_nameserver -if __name__ == '__main__': - from IPython.core.debugger import set_trace - set_trace() - unittest.main() + +class NaptrNanoNameserver(Server): + + def handle(self, message): + response = dns.message.make_response(message) + response.set_rcode(dns.rcode.REFUSED) + response.flags |= dns.flags.RA + try: + if message.question[0].rdtype == dns.rdatatype.NAPTR and \ + message.question[0].rdclass == dns.rdataclass.IN: + rrs = dns.rrset.from_text(message.question[0].name, 300, + 'IN', 'NAPTR', + '0 0 "" "" "" .') + response.answer.append(rrs) + response.set_rcode(dns.rcode.NOERROR) + response.flags |= dns.flags.AA + except Exception: + pass + return response + + +@unittest.skipIf(not (_network_available and _nanonameserver_available), + "Internet and NanoAuth required") +class NanoTests(unittest.TestCase): + + def testE164Query(self): + with NaptrNanoNameserver() as na: + res = dns.resolver.Resolver() + res.port = na.udp_address[1] + res.nameservers = [ na.udp_address[0] ] + answer = dns.e164.query('1650551212', ['e164.arpa'], res) + self.assertEqual(answer[0].order, 0) + self.assertEqual(answer[0].preference, 0) + self.assertEqual(answer[0].flags, b'') + self.assertEqual(answer[0].service, b'') + self.assertEqual(answer[0].regexp, b'') + self.assertEqual(answer[0].replacement, dns.name.root) -- 2.47.3