From: Remi Gacogne Date: Fri, 19 Nov 2021 12:17:12 +0000 (+0100) Subject: dnsdist: Ask background threads to stop in regression tests X-Git-Tag: rec-4.9.0-alpha0~12^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7373e3a6869acb1956980b1d1d7f6cbd71613c09;p=thirdparty%2Fpdns.git dnsdist: Ask background threads to stop in regression tests --- diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 4fd5694bcb..8a05461dc5 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -68,6 +68,9 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): _checkConfigExpectedOutput = None _verboseMode = False _skipListeningOnCL = False + _backgroundThreads = {} + _UDPResponder = None + _TCPResponder = None @classmethod def startResponders(cls): @@ -163,6 +166,10 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): cls._dnsdist.kill() cls._dnsdist.wait() + # tell the background threads to stop, if any + for backgroundThread in cls._backgroundThreads: + cls._backgroundThreads[backgroundThread] = False + @classmethod def _ResponderIncrementCounter(cls): if threading.currentThread().name in cls._responsesCounter: @@ -202,6 +209,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): @classmethod def UDPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, callback=None): + cls._backgroundThreads[threading.get_native_id()] = True # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. @@ -211,8 +219,17 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) sock.bind(("127.0.0.1", port)) + sock.settimeout(1.0) while True: - data, addr = sock.recvfrom(4096) + try: + data, addr = sock.recvfrom(4096) + except socket.timeout: + if cls._backgroundThreads.get(threading.get_native_id(), False) == False: + del cls._backgroundThreads[threading.get_native_id()] + break + else: + continue + forceRcode = None try: request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) @@ -235,9 +252,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): if not wire: continue - sock.settimeout(2.0) sock.sendto(wire, addr) - sock.settimeout(None) + sock.close() @classmethod @@ -299,6 +315,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): @classmethod def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False, listeningAddr='127.0.0.1'): + cls._backgroundThreads[threading.get_native_id()] = True # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. @@ -314,6 +331,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): sys.exit(1) sock.listen(100) + sock.settimeout(1.0) if tlsContext: sock = tlsContext.wrap_socket(sock, server_side=True) @@ -324,6 +342,12 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): continue except ConnectionResetError: continue + except socket.timeout: + if cls._backgroundThreads.get(threading.get_native_id(), False) == False: + del cls._backgroundThreads[threading.get_native_id()] + break + else: + continue conn.settimeout(5.0) if multipleConnections: @@ -434,6 +458,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): @classmethod def DOHResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, useProxyProtocol=False): + cls._backgroundThreads[threading.get_native_id()] = True # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. @@ -449,6 +474,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): sys.exit(1) sock.listen(100) + sock.settimeout(1.0) if tlsContext: sock = tlsContext.wrap_socket(sock, server_side=True) @@ -461,6 +487,12 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): continue except ConnectionResetError: continue + except socket.timeout: + if cls._backgroundThreads.get(threading.get_native_id(), False) == False: + del cls._backgroundThreads[threading.get_native_id()] + break + else: + continue conn.settimeout(5.0) thread = threading.Thread(name='DoH Connection Handler', diff --git a/regression-tests.dnsdist/test_Protobuf.py b/regression-tests.dnsdist/test_Protobuf.py index b5384aaf6a..7f0e7ba4f0 100644 --- a/regression-tests.dnsdist/test_Protobuf.py +++ b/regression-tests.dnsdist/test_Protobuf.py @@ -17,6 +17,7 @@ class DNSDistProtobufTest(DNSDistTest): @classmethod def ProtobufListener(cls, port): + cls._backgroundThreads[threading.get_native_id()] = True sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) try: @@ -26,8 +27,17 @@ class DNSDistProtobufTest(DNSDistTest): sys.exit(1) sock.listen(100) + sock.settimeout(1.0) while True: - (conn, _) = sock.accept() + try: + (conn, _) = sock.accept() + except socket.timeout: + if cls._backgroundThreads.get(threading.get_native_id(), False) == False: + del cls._backgroundThreads[threading.get_native_id()] + break + else: + continue + data = None while True: data = conn.recv(2) diff --git a/regression-tests.dnsdist/test_TCPFastOpen.py b/regression-tests.dnsdist/test_TCPFastOpen.py index 73c01a8ba3..96900ce68a 100644 --- a/regression-tests.dnsdist/test_TCPFastOpen.py +++ b/regression-tests.dnsdist/test_TCPFastOpen.py @@ -30,6 +30,7 @@ class TestBrokenTCPFastOpen(DNSDistTest): @classmethod def BrokenTCPResponder(cls, port): + cls._backgroundThreads[threading.get_native_id()] = True sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) @@ -40,8 +41,17 @@ class TestBrokenTCPFastOpen(DNSDistTest): sys.exit(1) sock.listen(100) + sock.settimeout(1.0) while True: - (conn, _) = sock.accept() + try: + (conn, _) = sock.accept() + except socket.timeout: + if cls._backgroundThreads.get(threading.get_native_id(), False) == False: + del cls._backgroundThreads[threading.get_native_id()] + break + else: + continue + conn.settimeout(5.0) data = conn.recv(2) if not data: