]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
pass peer and connection type to nanoserver handle()
authorBob Halley <halley@dnspython.org>
Tue, 2 Jun 2020 14:22:42 +0000 (07:22 -0700)
committerBob Halley <halley@dnspython.org>
Tue, 2 Jun 2020 14:22:42 +0000 (07:22 -0700)
tests/nanonameserver.py
tests/test_resolver.py

index 1b4a43497fbf7924d8f9c9ac36f021d39082f40c..a14d925967b742261c19054df6936630cc140cc4 100644 (file)
@@ -1,6 +1,7 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
 import contextlib
+import enum
 import functools
 import socket
 import struct
@@ -11,6 +12,9 @@ import dns.message
 import dns.rcode
 import dns.trio.query
 
+class ConnectionType(enum.IntEnum):
+    UDP = 1
+    TCP = 2
 
 class Server(threading.Thread):
 
@@ -95,7 +99,7 @@ class Server(threading.Thread):
         finally:
             raise EOFError
 
-    def handle(self, message):
+    def handle(self, message, peer, connection_type):
         #
         # Handle message 'message'.  Override this method to change
         # how the server behaves.
@@ -113,7 +117,7 @@ class Server(threading.Thread):
         except Exception:
             return None
 
-    def handle_wire(self, wire):
+    def handle_wire(self, wire, peer, connection_type):
         #
         # This is the common code to parse wire format, call handle() on
         # the message, and then generate resposne wire format (if handle()
@@ -123,6 +127,12 @@ class Server(threading.Thread):
         #
         # Returns a wire format message to send, or None indicating there
         # is nothing to do.
+        #
+        # XXXRTH It might be nice to have a "debug mode" in the server
+        # where we'd print something in all the places we're eating
+        # exceptions.  That way bugs in handle() would be easier to
+        # find.
+        #
         r = None
         try:
             q = dns.message.from_wire(wire)
@@ -142,7 +152,7 @@ class Server(threading.Thread):
             # r might have been set above, so skip handle() if we
             # already have a response.
             if r is None:
-                r = self.handle(q)
+                r = self.handle(q, peer, connection_type)
         except Exception:
             # Exceptions from handle get a SERVFAIL response.
             r = dns.message.make_response(q)
@@ -158,20 +168,21 @@ class Server(threading.Thread):
             self.udp = None  # we own cleanup
             while True:
                 try:
-                    (wire, from_address) = await sock.recvfrom(65535)
-                    wire = self.handle_wire(wire)
+                    (wire, peer) = await sock.recvfrom(65535)
+                    wire = self.handle_wire(wire, peer, ConnectionType.UDP)
                     if wire is not None:
-                        await sock.sendto(wire, from_address)
+                        await sock.sendto(wire, peer)
                 except Exception:
                     pass
 
     async def serve_tcp(self, stream):
         try:
+            peer = stream.socket.getpeername()
             while True:
                 ldata = await dns.trio.query.read_exactly(stream, 2)
                 (l,) = struct.unpack("!H", ldata)
                 wire = await dns.trio.query.read_exactly(stream, l)
-                wire = self.handle_wire(wire)
+                wire = self.handle_wire(wire, peer, ConnectionType.TCP)
                 if wire is not None:
                     l = len(wire)
                     stream_message = struct.pack("!H", l) + wire
index 87aebaac8de1218df3072891e11620efd7d9e815..309a89d70fe3a8e62c1e1a89f55286b2b103b899 100644 (file)
@@ -595,7 +595,7 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase):
 
 class NaptrNanoNameserver(Server):
 
-    def handle(self, message):
+    def handle(self, message, peer, connection_type):
         response = dns.message.make_response(message)
         response.set_rcode(dns.rcode.REFUSED)
         response.flags |= dns.flags.RA