]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Ask background threads to stop in regression tests
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 19 Nov 2021 12:17:12 +0000 (13:17 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 29 Jun 2022 16:56:29 +0000 (18:56 +0200)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_Protobuf.py
regression-tests.dnsdist/test_TCPFastOpen.py

index 4fd5694bcb4a3ccc8a0d649f8d9de9c224fd8a5c..8a05461dc582de99925ded0716f89a63f556187e 100644 (file)
@@ -68,6 +68,9 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
     _checkConfigExpectedOutput = None
     _verboseMode = False
     _skipListeningOnCL = False
+    _backgroundThreads = {}
+    _UDPResponder = None
+    _TCPResponder = None
 
     @classmethod
     def startResponders(cls):
@@ -163,6 +166,10 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                     cls._dnsdist.kill()
                 cls._dnsdist.wait()
 
+        # tell the background threads to stop, if any
+        for backgroundThread in cls._backgroundThreads:
+          cls._backgroundThreads[backgroundThread] = False
+
     @classmethod
     def _ResponderIncrementCounter(cls):
         if threading.currentThread().name in cls._responsesCounter:
@@ -202,6 +209,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
     @classmethod
     def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None):
+        cls._backgroundThreads[threading.get_native_id()] = True
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
@@ -211,8 +219,17 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         sock.bind(("127.0.0.1", port))
+        sock.settimeout(1.0)
         while True:
-            data, addr = sock.recvfrom(4096)
+            try:
+              data, addr = sock.recvfrom(4096)
+            except socket.timeout:
+              if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+                del cls._backgroundThreads[threading.get_native_id()]
+                break
+              else:
+                continue
+
             forceRcode = None
             try:
                 request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
@@ -235,9 +252,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             if not wire:
               continue
 
-            sock.settimeout(2.0)
             sock.sendto(wire, addr)
-            sock.settimeout(None)
+
         sock.close()
 
     @classmethod
@@ -299,6 +315,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
     @classmethod
     def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1'):
+        cls._backgroundThreads[threading.get_native_id()] = True
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
@@ -314,6 +331,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             sys.exit(1)
 
         sock.listen(100)
+        sock.settimeout(1.0)
         if tlsContext:
           sock = tlsContext.wrap_socket(sock, server_side=True)
 
@@ -324,6 +342,12 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
               continue
             except ConnectionResetError:
               continue
+            except socket.timeout:
+              if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+                 del cls._backgroundThreads[threading.get_native_id()]
+                 break
+              else:
+                continue
 
             conn.settimeout(5.0)
             if multipleConnections:
@@ -434,6 +458,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
     @classmethod
     def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False):
+        cls._backgroundThreads[threading.get_native_id()] = True
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
@@ -449,6 +474,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             sys.exit(1)
 
         sock.listen(100)
+        sock.settimeout(1.0)
         if tlsContext:
             sock = tlsContext.wrap_socket(sock, server_side=True)
 
@@ -461,6 +487,12 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
                 continue
             except ConnectionResetError:
               continue
+            except socket.timeout:
+                if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+                    del cls._backgroundThreads[threading.get_native_id()]
+                    break
+                else:
+                    continue
 
             conn.settimeout(5.0)
             thread = threading.Thread(name='DoH Connection Handler',
index b5384aaf6a05f1f28b4227d8f1ba44a62d450b5a..7f0e7ba4f0f0778dec8d04af5f72360e31f72f47 100644 (file)
@@ -17,6 +17,7 @@ class DNSDistProtobufTest(DNSDistTest):
 
     @classmethod
     def ProtobufListener(cls, port):
+        cls._backgroundThreads[threading.get_native_id()] = True
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
         try:
@@ -26,8 +27,17 @@ class DNSDistProtobufTest(DNSDistTest):
             sys.exit(1)
 
         sock.listen(100)
+        sock.settimeout(1.0)
         while True:
-            (conn, _) = sock.accept()
+            try:
+                (conn, _) = sock.accept()
+            except socket.timeout:
+                if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+                    del cls._backgroundThreads[threading.get_native_id()]
+                    break
+                else:
+                    continue
+
             data = None
             while True:
                 data = conn.recv(2)
index 73c01a8ba3f5f8c61c8975eed172b6b945be889d..96900ce68a98c8f1beac0720d6730a4ff073b4a8 100644 (file)
@@ -30,6 +30,7 @@ class TestBrokenTCPFastOpen(DNSDistTest):
 
     @classmethod
     def BrokenTCPResponder(cls, port):
+        cls._backgroundThreads[threading.get_native_id()] = True
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
@@ -40,8 +41,17 @@ class TestBrokenTCPFastOpen(DNSDistTest):
             sys.exit(1)
 
         sock.listen(100)
+        sock.settimeout(1.0)
         while True:
-            (conn, _) = sock.accept()
+            try:
+                (conn, _) = sock.accept()
+            except socket.timeout:
+                if cls._backgroundThreads.get(threading.get_native_id(), False) == False:
+                    del cls._backgroundThreads[threading.get_native_id()]
+                    break
+                else:
+                    continue
+
             conn.settimeout(5.0)
             data = conn.recv(2)
             if not data: