From: Oliver Chen Date: Tue, 17 Jun 2025 03:49:48 +0000 (+0000) Subject: Add regression test for the restart counting X-Git-Tag: dnsdist-2.0.0-beta1~13^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e40f099cc70c09e625ab6d694afcf6810e02a3cd;p=thirdparty%2Fpdns.git Add regression test for the restart counting --- diff --git a/regression-tests.dnsdist/test_RestartQuery.py b/regression-tests.dnsdist/test_RestartQuery.py index 48a6f30c3e..27d9bf3e30 100644 --- a/regression-tests.dnsdist/test_RestartQuery.py +++ b/regression-tests.dnsdist/test_RestartQuery.py @@ -153,3 +153,89 @@ class TestRestartProxyProtocolThenNot(DNSDistTest): (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) self.assertTrue(receivedProxyPayload) self.assertTrue(receivedDNSData) + +class QueryCounter: + + def __init__(self, name): + self.name = name + self.qcnt = 0 + + def __call__(self): + return self.qcnt + + def create_cb(self): + def callback(request): + self.qcnt += 1 + response = dns.message.make_response(request) + response.set_rcode(dns.rcode.REFUSED) + return response.to_wire() + return callback + +class TestRestartCount(DNSDistTest): + + _queryCounts = {} + + _testServer1Port = pickAvailablePort() + _testServer2Port = pickAvailablePort() + _testServer3Port = pickAvailablePort() + _testServer4Port = pickAvailablePort() + _serverPorts = [_testServer1Port, _testServer2Port, _testServer3Port, _testServer4Port] + _config_params = ['_testServer1Port', '_testServer2Port', '_testServer3Port', '_testServer4Port'] + _config_template = """ + MaxRestart = 2 + s0 = newServer{name="s0", address="127.0.0.1:%s"} + s0:setUp() + s0:addPool("pool0") + s1 = newServer{name="s1", address="127.0.0.1:%s"} + s1:setUp() + s1:addPool("pool1") + s2 = newServer{name="s2", address="127.0.0.1:%s"} + s2:setUp() + s2:addPool("pool2") + s3 = newServer{name="s3", address="127.0.0.1:%s"} + s3:setUp() + s3:addPool("pool3") + function makeQueryRestartable(dq) dq:setRestartable() return DNSAction.None end + addAction(AllRule(), LuaAction(makeQueryRestartable)) + function restartQuery(dr) + if dr:getRestartCount() < MaxRestart then + dr.pool = "pool"..tostring(dr:getRestartCount() + 1) + dr:restart() + else + return DNSResponseAction.ServFail + end + return DNSResponseAction.None + end + addResponseAction(RCodeRule(DNSRCode.REFUSED), LuaResponseAction(restartQuery)) + addAction(AllRule(), PoolAction("pool0")) + """ + @classmethod + def startResponders(cls): + print("Launching responders..") + + for i, name in enumerate(['s0', 's1', 's2', 's3']): + cls._queryCounts[name] = QueryCounter(name) + cb = cls._queryCounts[name].create_cb() + responder = threading.Thread(name=name, target=cls.UDPResponder, args=[cls._serverPorts[i], cls._toResponderQueue, cls._fromResponderQueue, False, cb]) + responder.daemon = True + responder.start() + + def testDefault(self): + + numberOfQueries = 100 + name = 'restart.count.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + expectedResponse = dns.message.make_response(query) + expectedResponse.set_rcode(dns.rcode.SERVFAIL) + + # send 100 queries + for _ in range(numberOfQueries): + (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEqual(expectedResponse, receivedResponse) + + # if restart count is correct, s0/s1/s2 would get all the queies while s3 would get none + self.assertEqual(self._queryCounts['s0'](), numberOfQueries) + self.assertEqual(self._queryCounts['s1'](), numberOfQueries) + self.assertEqual(self._queryCounts['s2'](), numberOfQueries) + self.assertEqual(self._queryCounts['s3'](), 0)