]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
add nanonamserver, a handy testing tool
authorBob Halley <halley@dnspython.org>
Sun, 31 May 2020 16:54:00 +0000 (09:54 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 31 May 2020 16:54:00 +0000 (09:54 -0700)
nanonameserver.py [new file with mode: 0644]
tests/test_resolver.py

diff --git a/nanonameserver.py b/nanonameserver.py
new file mode 100644 (file)
index 0000000..aaec009
--- /dev/null
@@ -0,0 +1,194 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import contextlib
+import functools
+import socket
+import struct
+import threading
+import trio
+
+import dns.message
+import dns.rcode
+import dns.trio.query
+
+
+class Server(threading.Thread):
+
+    """The nanoserver is a nameserver skeleton suitable for faking a DNS
+    server for various testing purposes.  It executes with a trio run
+    loop in a dedicated thread, and is a context manager.  Exiting the
+    context manager will ensure the server shuts down.
+
+    If a port is not specified, random ports will be chosen.
+
+    Applications should subclass the server and override the handle()
+    method to determine how the server responds to queries.  The
+    default behavior is to refuse everything.
+
+    If use_thread is set to False in the constructor, then the
+    server's main() method can be used directly in a trio nursery,
+    allowing the server's cancellation to be managed in the Trio way.
+    In this case, no thread creation ever happens even though Server
+    is a subclass of thread, because the start() method is never
+    called.
+    """
+
+    def __init__(self, address='127.0.0.1', port=0, enable_udp=True,
+                 enable_tcp=True, use_thread=True):
+        super().__init__()
+        self.address = address
+        self.port = port
+        self.enable_udp = enable_udp
+        self.enable_tcp = enable_tcp
+        self.use_thread = use_thread
+        self.left = None
+        self.right = None
+        self.udp = None
+        self.udp_address = None
+        self.tcp = None
+        self.tcp_address = None
+
+    def __enter__(self):
+        (self.left, self.right) = socket.socketpair()
+        # We're making the UDP socket now so it can be sent to by the
+        # caller immediately (i.e. no race with the listener starting
+        # in the thread).
+        if self.enable_udp:
+            self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+            self.udp.bind((self.address, self.port))
+            self.udp_address = self.udp.getsockname()
+        if self.enable_tcp:
+            self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+            self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            self.tcp.bind((self.address, self.port))
+            self.tcp.listen()
+            self.tcp_address = self.udp.getsockname()
+        if self.use_thread:
+            self.start()
+        return self
+
+    def __exit__(self, ex_ty, ex_va, ex_tr):
+        if self.left:
+            self.left.close()
+        if self.use_thread and self.is_alive():
+            self.join()
+        if self.right:
+            self.right.close()
+        if self.udp:
+            self.udp.close()
+        if self.tcp:
+            self.tcp.close()
+
+    async def wait_for_input_or_eof(self):
+        #
+        # This trio task just waits for input on the right half of the
+        # socketpair (the left half is owned by the context manager
+        # returned by launch).  As soon as something is read, or the
+        # socket returns EOF, EOFError is raised, causing a the
+        # nursery to cancel all other nursery tasks, in particular the
+        # listeners.
+        #
+        try:
+            with trio.socket.from_stdlib_socket(self.right) as sock:
+                self.right = None  # we own cleanup
+                await sock.recv(1)
+        finally:
+            raise EOFError
+
+    def handle(self, message):
+        #
+        # Handle message 'message'.  Override this method to change
+        # how the server behaves.
+        #
+        # The return value is either a dns.message.Message or a bytes.
+        # We allow a bytes to be returned for cases where handle wants
+        # to return an invalid DNS message for testing purposes.
+        #
+        r = dns.message.make_response(message)
+        r.set_rcode(dns.rcode.REFUSED)
+        return r
+
+    async def serve_udp(self):
+        with trio.socket.from_stdlib_socket(self.udp) as sock:
+            self.udp = None  # we own cleanup
+            while True:
+                try:
+                    (wire, from_address) = await sock.recvfrom(65535)
+                    q = dns.message.from_wire(wire)
+                    r = self.handle(q)
+                    if isinstance(r, dns.message.Message):
+                        wire = r.to_wire()
+                    else:
+                        wire = r
+                    await sock.sendto(wire, from_address)
+                except Exception:
+                    pass
+
+    async def serve_tcp(self, stream):
+        try:
+            while True:
+                ldata = await dns.trio.query.read_exactly(stream, 2)
+                (l,) = struct.unpack("!H", ldata)
+                wire = await dns.trio.query.read_exactly(stream, l)
+                q = dns.message.from_wire(wire)
+                r = self.handle(q)
+                if isinstance(r, dns.message.Message):
+                    wire = r.to_wire()
+                else:
+                    wire = r
+                l = len(wire)
+                stream_message = struct.pack("!H", l) + wire
+                await stream.send_all(stream_message)
+        except Exception:
+            pass
+
+    async def orchestrate_tcp(self):
+        with trio.socket.from_stdlib_socket(self.tcp) as sock:
+            self.tcp = None  # we own cleanup
+            listener = trio.SocketListener(sock)
+            async with trio.open_nursery() as nursery:
+                serve = functools.partial(trio.serve_listeners, self.serve_tcp,
+                                          [listener], handler_nursery=nursery)
+                nursery.start_soon(serve)
+
+    async def main(self):
+        try:
+            async with trio.open_nursery() as nursery:
+                if self.use_thread:
+                    nursery.start_soon(self.wait_for_input_or_eof)
+                if self.enable_udp:
+                    nursery.start_soon(self.serve_udp)
+                if self.enable_tcp:
+                    nursery.start_soon(self.orchestrate_tcp)
+        except Exception:
+            pass
+
+    def run(self):
+        if not self.use_thread:
+            raise RuntimeError('start() called on a use_thread=False Server')
+        trio.run(self.main)
+
+if __name__ == "__main__":
+    import sys
+    import time
+
+    async def trio_main():
+        try:
+            with Server(port=5354, use_thread=False) as server:
+                print(f'Trio mode: listening on UDP: {server.udp_address}, ' +
+                      f'TCP: {server.tcp_address}')
+                async with trio.open_nursery() as nursery:
+                    nursery.start_soon(server.main)
+        except Exception:
+            pass
+
+    def threaded_main():
+        with Server(port=5354) as server:
+            print(f'Thread Mode: listening on UDP: {server.udp_address}, ' +
+                  f'TCP: {server.tcp_address}')
+            time.sleep(300)
+
+    if len(sys.argv) > 1 and sys.argv[1] == 'trio':
+        trio.run(trio_main)
+    else:
+        threaded_main()
index c5814da6e5517b0fe9543ca5cbfecaa0e41dbb05..202c272a9a4f7a2d242a8156caa4dfd8449c2a9f 100644 (file)
@@ -22,6 +22,7 @@ import socket
 import time
 import unittest
 
+import dns.e164
 import dns.message
 import dns.name
 import dns.rdataclass
@@ -36,6 +37,17 @@ try:
 except socket.gaierror:
     _network_available = False
 
+# Some tests use a "nano nameserver" for testing.  It requires trio
+# and threading, so try to import it and if it doesn't work, skip
+# those tests.
+try:
+    from nanonameserver import Server
+    _nanonameserver_available = True
+except ImportError:
+    _nanonameserver_available = False
+    class Server(object):
+        pass
+
 resolv_conf = u"""
     /t/t
 # comment 1
@@ -580,7 +592,40 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase):
             with self.assertRaises(ValueError):
                 resolver.nameservers = invalid_nameserver
 
-if __name__ == '__main__':
-    from IPython.core.debugger import set_trace
-    set_trace()
-    unittest.main()
+
+class NaptrNanoNameserver(Server):
+
+    def handle(self, message):
+        response = dns.message.make_response(message)
+        response.set_rcode(dns.rcode.REFUSED)
+        response.flags |= dns.flags.RA
+        try:
+            if message.question[0].rdtype == dns.rdatatype.NAPTR and \
+               message.question[0].rdclass == dns.rdataclass.IN:
+                rrs = dns.rrset.from_text(message.question[0].name, 300,
+                                          'IN', 'NAPTR',
+                                          '0 0 "" "" "" .')
+                response.answer.append(rrs)
+                response.set_rcode(dns.rcode.NOERROR)
+                response.flags |= dns.flags.AA
+        except Exception:
+            pass
+        return response
+
+
+@unittest.skipIf(not (_network_available and _nanonameserver_available),
+                 "Internet and NanoAuth required")
+class NanoTests(unittest.TestCase):
+
+    def testE164Query(self):
+        with NaptrNanoNameserver() as na:
+            res = dns.resolver.Resolver()
+            res.port = na.udp_address[1]
+            res.nameservers = [ na.udp_address[0] ]
+            answer = dns.e164.query('1650551212', ['e164.arpa'], res)
+            self.assertEqual(answer[0].order, 0)
+            self.assertEqual(answer[0].preference, 0)
+            self.assertEqual(answer[0].flags, b'')
+            self.assertEqual(answer[0].service, b'')
+            self.assertEqual(answer[0].regexp, b'')
+            self.assertEqual(answer[0].replacement, dns.name.root)