]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Add regression test for the restart counting
authorOliver Chen <oliver.chen@nokia-sbell.com>
Tue, 17 Jun 2025 03:49:48 +0000 (03:49 +0000)
committerOliver Chen <oliver.chen@nokia-sbell.com>
Tue, 17 Jun 2025 03:49:48 +0000 (03:49 +0000)
regression-tests.dnsdist/test_RestartQuery.py

index 48a6f30c3ef7a72bbfa8a45bd0cc45fe57b430b5..27d9bf3e30e3d3bd9bd2730a101de92c279bbf65 100644 (file)
@@ -153,3 +153,89 @@ class TestRestartProxyProtocolThenNot(DNSDistTest):
             (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
             self.assertTrue(receivedProxyPayload)
             self.assertTrue(receivedDNSData)
+
+class QueryCounter:
+
+    def __init__(self, name):
+        self.name = name
+        self.qcnt = 0
+
+    def __call__(self):
+        return self.qcnt
+
+    def create_cb(self):
+        def callback(request):
+            self.qcnt += 1
+            response = dns.message.make_response(request)
+            response.set_rcode(dns.rcode.REFUSED)
+            return response.to_wire()
+        return callback
+
+class TestRestartCount(DNSDistTest):
+
+    _queryCounts = {}
+
+    _testServer1Port = pickAvailablePort()
+    _testServer2Port = pickAvailablePort()
+    _testServer3Port = pickAvailablePort()
+    _testServer4Port = pickAvailablePort()
+    _serverPorts = [_testServer1Port, _testServer2Port, _testServer3Port, _testServer4Port]
+    _config_params = ['_testServer1Port', '_testServer2Port', '_testServer3Port', '_testServer4Port']
+    _config_template = """
+    MaxRestart = 2
+    s0 = newServer{name="s0", address="127.0.0.1:%s"}
+    s0:setUp()
+    s0:addPool("pool0")
+    s1 = newServer{name="s1", address="127.0.0.1:%s"}
+    s1:setUp()
+    s1:addPool("pool1")
+    s2 = newServer{name="s2", address="127.0.0.1:%s"}
+    s2:setUp()
+    s2:addPool("pool2")
+    s3 = newServer{name="s3", address="127.0.0.1:%s"}
+    s3:setUp()
+    s3:addPool("pool3")
+    function makeQueryRestartable(dq) dq:setRestartable() return DNSAction.None end
+    addAction(AllRule(), LuaAction(makeQueryRestartable))
+    function restartQuery(dr)
+        if dr:getRestartCount() < MaxRestart then
+            dr.pool = "pool"..tostring(dr:getRestartCount() + 1)
+            dr:restart()
+        else
+            return DNSResponseAction.ServFail
+        end
+        return DNSResponseAction.None
+    end
+    addResponseAction(RCodeRule(DNSRCode.REFUSED), LuaResponseAction(restartQuery))
+    addAction(AllRule(), PoolAction("pool0"))
+    """
+    @classmethod
+    def startResponders(cls):
+        print("Launching responders..")
+
+        for i, name in enumerate(['s0', 's1', 's2', 's3']):
+            cls._queryCounts[name] = QueryCounter(name)
+            cb = cls._queryCounts[name].create_cb()
+            responder = threading.Thread(name=name, target=cls.UDPResponder, args=[cls._serverPorts[i], cls._toResponderQueue, cls._fromResponderQueue, False, cb])
+            responder.daemon = True
+            responder.start()
+
+    def testDefault(self):
+
+        numberOfQueries = 100
+        name = 'restart.count.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.SERVFAIL)
+
+        # send 100 queries
+        for _ in range(numberOfQueries):
+            (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+        # if restart count is correct, s0/s1/s2 would get all the queies while s3 would get none
+        self.assertEqual(self._queryCounts['s0'](), numberOfQueries)
+        self.assertEqual(self._queryCounts['s1'](), numberOfQueries)
+        self.assertEqual(self._queryCounts['s2'](), numberOfQueries)
+        self.assertEqual(self._queryCounts['s3'](),  0)