From 86e601ef18df9d4fb37f296083c44de5cfa4f802 Mon Sep 17 00:00:00 2001 From: Oliver Chen Date: Sat, 5 Apr 2025 05:05:17 +0000 Subject: [PATCH] Add regression test case for timeout response action --- regression-tests.dnsdist/dnsdisttests.py | 12 ++ .../test_TimeoutResponse.py | 152 ++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 regression-tests.dnsdist/test_TimeoutResponse.py diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index cc9bd25b9f..a365498525 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -63,6 +63,12 @@ def pickAvailablePort(): workerPorts[workerID] = port return port +class DropAction(object): + """ + An object to indicate a drop action shall be taken + """ + pass + class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): """ Set up a dnsdist instance and responder threads. @@ -351,6 +357,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): if not wire: continue + elif isinstance(wire, DropAction): + continue sock.sendto(wire, addr) @@ -392,6 +400,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): if not wire: conn.close() return + elif isinstance(wire, DropAction): + return wireLen = struct.pack("!H", len(wire)) if partialWrite: @@ -556,6 +566,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): conn.close() conn = None break + elif isinstance(wire, DropAction): + break headers = [ (':status', str(status)), diff --git a/regression-tests.dnsdist/test_TimeoutResponse.py b/regression-tests.dnsdist/test_TimeoutResponse.py new file mode 100644 index 0000000000..5cf8a347a5 --- /dev/null +++ b/regression-tests.dnsdist/test_TimeoutResponse.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +import ssl +import threading +import dns +from dnsdisttests import DNSDistTest, pickAvailablePort, DropAction + +_common_config = """ + addDOHLocal("127.0.0.1:%d", "server.chain", "server.key", {'/dns-query'}, {library='nghttp2'}) + addDOQLocal("127.0.0.1:%d", "server.chain", "server.key") + addDOH3Local("127.0.0.1:%d", "server.chain", "server.key") + addTLSLocal("127.0.0.1:%d", "server.chain", "server.key") + + function makeQueryRestartable(dq) + dq:setRestartable() + return DNSAction.None + end + + function restartQuery(dr) + if dr.pool ~= 'restarted' then + dr.pool = 'restarted' + dr:restart() + end + return DNSResponseAction.None + end + + addAction(AllRule(), LuaAction(makeQueryRestartable)) + addTimeoutResponseAction(AllRule(), LuaResponseAction(restartQuery)) +""" + +def timeoutResponseCallback(request): + return DropAction() + +def normalResponseCallback(request): + response = dns.message.make_response(request) + rrset = dns.rrset.from_text(request.question[0].name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + return response.to_wire() + +def dohTimeoutResponseCallback(request, headers, fromQueue, toQueue): + return 200, timeoutResponseCallback(request) + +def dohNormalResponseCallback(request, headers, fromQueue, toQueue): + return 200, normalResponseCallback(request) + +class TestTimeoutBackendUdpTcp(DNSDistTest): + + # this test suite uses different responder ports + _testNormalServerPort = pickAvailablePort() + _testTimeoutServerPort = pickAvailablePort() + _dohWithNGHTTP2ServerPort = pickAvailablePort() + _doqServerPort = pickAvailablePort() + _doh3ServerPort = pickAvailablePort() + _tlsServerPort = pickAvailablePort() + + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _dohWithNGHTTP2BaseURL = ("https://%s:%d/dns-query" % ("127.0.0.1", _dohWithNGHTTP2ServerPort)) + _dohBaseURL = ("https://%s:%d/" % (_serverName, _doh3ServerPort)) + + _config_template = """ + newServer{address="127.0.0.1:%d",pool='restarted',udpTimeout=2,tcpRecvTimeout=2}:setUp() + newServer{address="127.0.0.1:%d",pool='',udpTimeout=2,tcpRecvTimeout=2}:setUp() + """ + _common_config + _config_params = ['_testNormalServerPort', '_testTimeoutServerPort', '_dohWithNGHTTP2ServerPort', '_doqServerPort', '_doh3ServerPort', '_tlsServerPort'] + _verboseMode = True + + @classmethod + def startResponders(cls): + print("Launching responders..") + + # timeout + cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testTimeoutServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, timeoutResponseCallback]) + cls._UDPResponder.daemon = True + cls._UDPResponder.start() + cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testTimeoutServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, timeoutResponseCallback]) + cls._TCPResponder.daemon = True + cls._TCPResponder.start() + cls._UDPResponderNormal = threading.Thread(name='UDP ResponderNormal', target=cls.UDPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, normalResponseCallback]) + cls._UDPResponderNormal.daemon = True + cls._UDPResponderNormal.start() + cls._TCPResponderNormal = threading.Thread(name='TCP ResponderNormal', target=cls.TCPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, normalResponseCallback]) + cls._TCPResponderNormal.daemon = True + cls._TCPResponderNormal.start() + + def testTimeoutRestartQuery(self): + """ + Restart: Timeout then restarted to a second pool + """ + name = 'timeout.restart.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery", "sendDOQQueryWrapper", "sendDOH3QueryWrapper", "sendDOTQueryWrapper", "sendDOHWithNGHTTP2QueryWrapper"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False, timeout=6) + self.assertTrue(receivedResponse) + self.assertEqual(receivedResponse, expectedResponse) + +class TestTimeoutBackendDOH(TestTimeoutBackendUdpTcp): + + _config_template = """ + newServer{address="127.0.0.1:%d",pool='restarted',udpTimeout=2,tcpRecvTimeout=2,tls='openssl',validateCertificates=true,caStore='ca.pem',subjectName='powerdns.com',dohPath='/dns-query'}:setUp() + newServer{address="127.0.0.1:%d",pool='',udpTimeout=2,tcpRecvTimeout=2,tls='openssl',validateCertificates=true,caStore='ca.pem',subjectName='powerdns.com',dohPath='/dns-query'}:setUp() + """ + _common_config + + @classmethod + def startResponders(cls): + + # timeout + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching DOH responder..") + cls._DOHResponder = threading.Thread(name='DOH Responder', target=cls.DOHResponder, args=[cls._testTimeoutServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, dohTimeoutResponseCallback, tlsContext]) + cls._DOHResponder.daemon = True + cls._DOHResponder.start() + + cls._DOHResponder = threading.Thread(name='DOH ResponderNormal', target=cls.DOHResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, dohNormalResponseCallback, tlsContext]) + cls._DOHResponder.daemon = True + cls._DOHResponder.start() + +class TestTimeoutBackendDOT(TestTimeoutBackendUdpTcp): + + _config_template = """ + newServer{address="127.0.0.1:%d",pool='restarted',udpTimeout=2,tcpRecvTimeout=2,tls='openssl',validateCertificates=true,caStore='ca.pem',subjectName='powerdns.com'}:setUp() + newServer{address="127.0.0.1:%d",pool='',udpTimeout=2,tcpRecvTimeout=2,tls='openssl',validateCertificates=true,caStore='ca.pem',subjectName='powerdns.com'}:setUp() + """ + _common_config + + @classmethod + def startResponders(cls): + + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.load_cert_chain('server.chain', 'server.key') + + print("Launching TLS responder..") + cls._TLSResponder = threading.Thread(name='TLS Responder', target=cls.TCPResponder, args=[cls._testTimeoutServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, timeoutResponseCallback, tlsContext]) + cls._TLSResponder.daemon = True + cls._TLSResponder.start() + + cls._TLSResponder = threading.Thread(name='TLS ResponderNormal', target=cls.TCPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, normalResponseCallback, tlsContext]) + cls._TLSResponder.daemon = True + cls._TLSResponder.start() -- 2.47.2