]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add more functionality to nanonameserver.
authorBrian Wellington <bwelling@xbill.org>
Tue, 23 Jun 2020 23:46:48 +0000 (16:46 -0700)
committerBrian Wellington <bwelling@xbill.org>
Tue, 23 Jun 2020 23:46:48 +0000 (16:46 -0700)
- When no port is specified, pick the same port for UDP and TCP, so that
TCP fallback can be tested.

- Change handlers to get a single Request object instead of individual
parameters.  The Request object contains the message, peer, and
connection_type previously passed, and also adds the local address and
wire format message.  Additionally, it provides convenient properties
for accessing the question.

tests/nanonameserver.py
tests/test_query.py
tests/test_resolver.py

index 4293d8a0856592fec92b4df71e5bff2a3bb0571e..2498d9011e793f3931a8619861da608c692f6c0d 100644 (file)
@@ -2,6 +2,7 @@
 
 import contextlib
 import enum
+import errno
 import functools
 import socket
 import struct
@@ -29,6 +30,30 @@ class ConnectionType(enum.IntEnum):
     UDP = 1
     TCP = 2
 
+class Request:
+    def __init__(self, message, wire, peer, local, connection_type):
+        self.message = message
+        self.wire = wire
+        self.peer = peer
+        self.local = local
+        self.connection_type = connection_type
+
+    @property
+    def question(self):
+        return self.message.question[0]
+
+    @property
+    def qname(self):
+        return self.question.name
+
+    @property
+    def qclass(self):
+        return self.question.rdclass
+
+    @property
+    def qtype(self):
+        return self.question.rdtype
+
 class Server(threading.Thread):
 
     """The nanoserver is a nameserver skeleton suitable for faking a DNS
@@ -68,11 +93,7 @@ class Server(threading.Thread):
         self.tcp = None
         self.tcp_address = None
 
-    def __enter__(self):
-        (self.left, self.right) = socket.socketpair()
-        # We're making the UDP socket now so it can be sent to by the
-        # caller immediately (i.e. no race with the listener starting
-        # in the thread).
+    def _open_sockets(self):
         if self.enable_udp:
             self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
             self.udp.bind((self.address, self.port))
@@ -80,9 +101,42 @@ class Server(threading.Thread):
         if self.enable_tcp:
             self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
             self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            self.tcp.bind((self.address, self.port))
+            port = self.port
+            if port is 0 and self.enable_udp:
+                port = self.udp_address[1]
+                port = 12347
+            self.tcp.bind((self.address, port))
             self.tcp.listen()
             self.tcp_address = self.tcp.getsockname()
+
+    def __enter__(self):
+        (self.left, self.right) = socket.socketpair()
+        # We're making the sockets now so they can be sent to by the
+        # caller immediately (i.e. no race with the listener starting
+        # in the thread).
+        open_udp_sockets = []
+        while True:
+            if self.enable_udp:
+                self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
+                self.udp.bind((self.address, self.port))
+                self.udp_address = self.udp.getsockname()
+            if self.enable_tcp:
+                self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+                self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                if self.port is 0 and self.enable_udp:
+                    try:
+                        self.tcp.bind((self.address, self.udp_address[1]))
+                    except OSError as e:
+                        if e.errno == errno.EADDRINUSE and \
+                           len(open_udp_sockets) < 100:
+                            open_udp_sockets.append(self.udp)
+                            continue
+                        raise
+                else:
+                    self.tcp.bind((self.address, self.port))
+                self.tcp.listen()
+                self.tcp_address = self.tcp.getsockname()
+            break
         if self.use_thread:
             self.start()
         return self
@@ -141,7 +195,7 @@ class Server(threading.Thread):
         else:
             return [thing]
 
-    def handle_wire(self, wire, peer, connection_type):
+    def handle_wire(self, wire, peer, local, connection_type):
         #
         # This is the common code to parse wire format, call handle() on
         # the message, and then generate response wire format (if handle()
@@ -180,8 +234,8 @@ class Server(threading.Thread):
             # items might have been appended to above, so skip
             # handle() if we already have a response.
             if not items:
-                items = self.maybe_listify(self.handle(q, peer,
-                                                       connection_type))
+                request = Request(q, wire, peer, local, connection_type)
+                items = self.maybe_listify(self.handle(request))
         except Exception:
             # Exceptions from handle get a SERVFAIL response.
             r = dns.message.make_response(q)
@@ -201,10 +255,11 @@ class Server(threading.Thread):
     async def serve_udp(self):
         with trio.socket.from_stdlib_socket(self.udp) as sock:
             self.udp = None  # we own cleanup
+            local = self.udp_address
             while True:
                 try:
                     (wire, peer) = await sock.recvfrom(65535)
-                    for wire in self.handle_wire(wire, peer,
+                    for wire in self.handle_wire(wire, peer, local,
                                                  ConnectionType.UDP):
                         await sock.sendto(wire, peer)
                 except Exception:
@@ -213,11 +268,13 @@ class Server(threading.Thread):
     async def serve_tcp(self, stream):
         try:
             peer = stream.socket.getpeername()
+            local = stream.socket.getsockname()
             while True:
                 ldata = await read_exactly(stream, 2)
                 (l,) = struct.unpack("!H", ldata)
                 wire = await read_exactly(stream, l)
-                for wire in self.handle_wire(wire, peer, ConnectionType.TCP):
+                for wire in self.handle_wire(wire, peer, local,
+                                             ConnectionType.TCP):
                     l = len(wire)
                     stream_message = struct.pack("!H", l) + wire
                     await stream.send_all(stream_message)
index b967a6e26737f591b4de7365defa9b0494e983db..895dc0d7477c2798b760a42d6538983833c41a7c 100644 (file)
@@ -248,16 +248,16 @@ ns2 A 10.0.0.1
 
 class AXFRNanoNameserver(Server):
 
-    def handle(self, message, peer, connection_type):
+    def handle(self, request):
         self.zone = dns.zone.from_text(axfr_zone, origin='example')
         self.origin = self.zone.origin
         items = []
         soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)
-        response = dns.message.make_response(message)
+        response = dns.message.make_response(request.message)
         response.flags |= dns.flags.AA
         response.answer.append(soa)
         items.append(response)
-        response = dns.message.make_response(message)
+        response = dns.message.make_response(request.message)
         response.question = []
         response.flags |= dns.flags.AA
         for (name, rdataset) in self.zone.iterate_rdatasets():
@@ -269,7 +269,7 @@ class AXFRNanoNameserver(Server):
             rrset.update(rdataset)
             response.answer.append(rrset)
         items.append(response)
-        response = dns.message.make_response(message)
+        response = dns.message.make_response(request.message)
         response.question = []
         response.flags |= dns.flags.AA
         response.answer.append(soa)
@@ -329,10 +329,10 @@ class IXFRNanoNameserver(Server):
         super().__init__()
         self.response_text = response_text
 
-    def handle(self, message, peer, connection_type):
+    def handle(self, request):
         try:
             r = dns.message.from_text(self.response_text, one_rr_per_rrset=True)
-            r.id = message.id
+            r.id = request.message.id
             return r
         except Exception:
             pass
@@ -450,14 +450,14 @@ class XfrTests(unittest.TestCase):
 
 class TSIGNanoNameserver(Server):
 
-    def handle(self, message, peer, connection_type):
-        response = dns.message.make_response(message)
+    def handle(self, request):
+        response = dns.message.make_response(request.message)
         response.set_rcode(dns.rcode.REFUSED)
         response.flags |= dns.flags.RA
         try:
-            if message.question[0].rdtype == dns.rdatatype.A and \
-               message.question[0].rdclass == dns.rdataclass.IN:
-                rrs = dns.rrset.from_text(message.question[0].name, 300,
+            if request.qtype == dns.rdatatype.A and \
+               request.qclass == dns.rdataclass.IN:
+                rrs = dns.rrset.from_text(request.qname, 300,
                                           'IN', 'A', '1.2.3.4')
                 response.answer.append(rrs)
                 response.set_rcode(dns.rcode.NOERROR)
index fe7573ebd6c15afef0fe902e54c8f53a9765d750..25c5c57b3c40b5d913e56022da1692737ae1aeb9 100644 (file)
@@ -614,19 +614,18 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase):
 
 class NaptrNanoNameserver(Server):
 
-    def handle(self, message, peer, connection_type):
-        response = dns.message.make_response(message)
+    def handle(self, request):
+        response = dns.message.make_response(request.message)
         response.set_rcode(dns.rcode.REFUSED)
         response.flags |= dns.flags.RA
         try:
             zero_subdomain = dns.e164.from_e164('0')
-            if message.question[0].name.is_subdomain(zero_subdomain):
+            if request.qname.is_subdomain(zero_subdomain):
                 response.set_rcode(dns.rcode.NXDOMAIN)
                 response.flags |= dns.flags.AA
-            elif message.question[0].rdtype == dns.rdatatype.NAPTR and \
-               message.question[0].rdclass == dns.rdataclass.IN:
-                rrs = dns.rrset.from_text(message.question[0].name, 300,
-                                          'IN', 'NAPTR',
+            elif request.qtype == dns.rdatatype.NAPTR and \
+                 request.qclass == dns.rdataclass.IN:
+                rrs = dns.rrset.from_text(request.qname, 300, 'IN', 'NAPTR',
                                           '0 0 "" "" "" .')
                 response.answer.append(rrs)
                 response.set_rcode(dns.rcode.NOERROR)