]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
selftest: allow dns_hub.py to listen on more than one address
authorStefan Metzmacher <metze@samba.org>
Wed, 11 Mar 2020 15:55:33 +0000 (16:55 +0100)
committerAndreas Schneider <asn@cryptomilk.org>
Fri, 27 Mar 2020 09:02:37 +0000 (09:02 +0000)
This makes it possible to serve ipv4 and ipv6 at the same time.

Signed-off-by: Stefan Metzmacher <metze@samba.org>
Reviewed-by: Andreas Schneider <asn@samba.org>
selftest/target/dns_hub.py

index 6b3c0f0708729434163f5746be012c9512b66a17..e9982c1341b9398ca87fbb8198c5b716fb560631 100755 (executable)
@@ -24,6 +24,7 @@ import threading
 import sys
 import select
 import socket
+import collections
 import time
 from samba.dcerpc import dns
 import samba.ndr as ndr
@@ -158,44 +159,81 @@ class DnsHandler(sserver.BaseRequestHandler):
 
 
 class server_thread(threading.Thread):
-    def __init__(self, server):
-        threading.Thread.__init__(self)
+    def __init__(self, server, name):
+        threading.Thread.__init__(self, name=name)
         self.server = server
 
     def run(self):
+        print("dns_hub[%s]: before serve_forever()" % self.name)
         self.server.serve_forever()
-        print("dns_hub: after serve_forever()")
+        print("dns_hub[%s]: after serve_forever()" % self.name)
 
+    def stop(self):
+        print("dns_hub[%s]: before shutdown()" % self.name)
+        self.server.shutdown()
+        print("dns_hub[%s]: after shutdown()" % self.name)
+
+class UDPV4Server(sserver.UDPServer):
+    address_family = socket.AF_INET
+
+class UDPV6Server(sserver.UDPServer):
+    address_family = socket.AF_INET6
 
 def main():
     if len(sys.argv) < 4:
-        print("Usage: dns_hub.py TIMEOUT HOST MAPPING")
+        print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
         sys.exit(1)
 
     timeout = int(sys.argv[1]) * 1000
     timeout = min(timeout, 2**31 - 1)  # poll with 32-bit int can't take more
-    host = sys.argv[2]
-
-    server = sserver.UDPServer((host, int(53)), DnsHandler)
-
+    # we pass in the listen addresses as a comma-separated string.
+    listenaddresses = sys.argv[2].split(',')
     # we pass in the realm-to-IP mappings as a comma-separated key=value
     # string. Convert this back into a dictionary that the DnsHandler can use
-    realm_mapping = dict(kv.split('=') for kv in sys.argv[3].split(','))
-    server.realm_to_ip_mappings = realm_mapping
+    realm_mappings = collections.OrderedDict(kv.split('=') for kv in sys.argv[3].split(','))
+
+    def prepare_server_thread(listenaddress, realm_mappings):
+
+        flags = socket.AddressInfo.AI_NUMERICHOST
+        flags |= socket.AddressInfo.AI_NUMERICSERV
+        flags |= socket.AddressInfo.AI_PASSIVE
+        addr_info = socket.getaddrinfo(listenaddress, int(53),
+                                       type=socket.SocketKind.SOCK_DGRAM,
+                                       flags=flags)
+        assert len(addr_info) == 1
+        if addr_info[0][0] == socket.AddressFamily.AF_INET6:
+            server = UDPV6Server(addr_info[0][4], DnsHandler)
+        else:
+            server = UDPV4Server(addr_info[0][4], DnsHandler)
+
+        # we pass in the realm-to-IP mappings as a comma-separated key=value
+        # string. Convert this back into a dictionary that the DnsHandler can use
+        server.realm_to_ip_mappings = realm_mappings
+        t = server_thread(server, name="UDP[%s]" % listenaddress)
+        return t
 
     print("dns_hub will proxy DNS requests for the following realms:")
-    for realm, ip in server.realm_to_ip_mappings.items():
+    for realm, ip in realm_mappings.items():
         print("  {0} ==> {1}".format(realm, ip))
 
-    t = server_thread(server)
-    t.start()
+    print("dns_hub will listen on the following UDP addresses:")
+    threads = []
+    for listenaddress in listenaddresses:
+        print("  %s" % listenaddress)
+        t = prepare_server_thread(listenaddress, realm_mappings)
+        threads.append(t)
+
+    for t in threads:
+        t.start()
     p = select.poll()
     stdin = sys.stdin.fileno()
     p.register(stdin, select.POLLIN)
     p.poll(timeout)
     print("dns_hub: after poll()")
-    server.shutdown()
-    t.join()
+    for t in threads:
+        t.stop()
+    for t in threads:
+        t.join()
     print("dns_hub: before exit()")
     sys.exit(0)