]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Optimize policy with already sorted servers and add tag testing 15670/head
authorOliver Chen <oliver.chen@nokia-sbell.com>
Tue, 17 Jun 2025 03:01:02 +0000 (03:01 +0000)
committerOliver Chen <oliver.chen@nokia-sbell.com>
Tue, 17 Jun 2025 03:01:02 +0000 (03:01 +0000)
pdns/dnsdistdist/dnsdist-lbpolicies.cc
regression-tests.dnsdist/test_Routing.py

index 343cdf19334676ea65688de1efc842538207b81e..ac17eb5efab1ba558230e3a1c350c4e2b7ce5e74 100644 (file)
@@ -269,16 +269,15 @@ shared_ptr<DownstreamState> orderedWrandUntag(const ServerPolicy::NumberedServer
   candidates.reserve(servers.size());
 
   int curOrder = std::numeric_limits<int>::max();
-  unsigned int startIndex = 0;
   unsigned int curNumber = 1;
 
   for (const auto& svr : servers) {
-    if (svr.second->isUp() && svr.second->d_config.order <= curOrder && (!dnsq->ids.qTag || dnsq->ids.qTag->count(svr.second->getNameWithAddr()) == 0)) {
-      if (svr.second->d_config.order < curOrder) {
-          curOrder = svr.second->d_config.order;
-          startIndex = candidates.size();
-          curNumber = 1;
+    if (svr.second->isUp() && (!dnsq->ids.qTag || dnsq->ids.qTag->count(svr.second->getNameWithAddr()) == 0)) {
+      // the servers in a pool are already sorted in ascending order by its 'order', see ``ServerPool::addServer()``
+      if (svr.second->d_config.order > curOrder) {
+        break;
       }
+      curOrder = svr.second->d_config.order;
       candidates.push_back(ServerPolicy::NumberedServer(curNumber++, svr.second));
     }
   }
@@ -287,8 +286,7 @@ shared_ptr<DownstreamState> orderedWrandUntag(const ServerPolicy::NumberedServer
     return {};
   }
 
-  ServerPolicy::NumberedServerVector selected(candidates.begin() + startIndex, candidates.end());
-  return wrandom(selected, dnsq);
+  return wrandom(candidates, dnsq);
 }
 
 std::shared_ptr<const ServerPolicy::NumberedServerVector> getDownstreamCandidates(const std::string& poolName)
index c8e65cf5402fb9a41110a583fa8ee1b7e5560546..8da44193f3e5652b615286fd4664a9f0da28f199 100644 (file)
@@ -1009,6 +1009,7 @@ class QueryCounter:
 
     def __init__(self, name):
         self.name = name
+        self.refuse = False
         self.qcnt = 0
 
     def __call__(self):
@@ -1017,6 +1018,9 @@ class QueryCounter:
     def reset(self):
         self.qcnt = 0
 
+    def set_refuse(self, flag):
+        self.refuse = True if flag else False
+
     def create_cb(self):
         def callback(request):
             self.qcnt += 1
@@ -1026,7 +1030,7 @@ class QueryCounter:
                                 dns.rdataclass.IN,
                                 dns.rdatatype.A,
                                 '127.0.0.1')
-            response.answer.append(rrset)
+            response.set_rcode(dns.rcode.REFUSED) if self.refuse else response.answer.append(rrset)
             return response.to_wire()
         return callback
 
@@ -1054,13 +1058,21 @@ class TestRoutingOrderedWRandUntag(DNSDistTest):
     s21:setUp()
     s22 = newServer{name="s22", address="127.0.0.1:%s", order=2, weight=2}
     s22:setUp()
-    function setServerDown(name)
+    function setServerState(name, flag)
         for _, s in ipairs(getServers()) do
             if s.name == name then
-                s:setDown()
+                if flag then s:setUp() else s:setDown() end
             end
         end
     end
+    function makeQueryRestartable(dq) dq:setRestartable() return DNSAction.None end
+    addAction(AllRule(), LuaAction(makeQueryRestartable))
+    function restartQuery(dr)
+        dr:setTag(dr:getSelectedBackend():getNameWithAddr(), "1")
+        dr:restart()
+        return DNSResponseAction.None
+    end
+    addResponseAction(RCodeRule(DNSRCode.REFUSED), LuaResponseAction(restartQuery))
     """
 
     @classmethod
@@ -1074,10 +1086,13 @@ class TestRoutingOrderedWRandUntag(DNSDistTest):
             responder.daemon = True
             responder.start()
 
-    def setDown(self, name):
-        self.sendConsoleCommand("setServerDown('{}')".format(name))
+    def setServerUp(self, name):
+        self.sendConsoleCommand("setServerState('{}', true)".format(name))
 
-    def testDefault(self):
+    def setServerDown(self, name):
+        self.sendConsoleCommand("setServerState('{}', false)".format(name))
+
+    def testPolicy(self):
         """
         Routing: orderedWrandUntag
 
@@ -1095,6 +1110,8 @@ class TestRoutingOrderedWRandUntag(DNSDistTest):
                                     '127.0.0.1')
         expectedResponse.answer.append(rrset)
 
+        ### test normal first ordered then random weighted routing ###
+
         # send 100 queries
         for _ in range(numberOfQueries):
             (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
@@ -1107,11 +1124,38 @@ class TestRoutingOrderedWRandUntag(DNSDistTest):
         self.assertEqual(self._queryCounts['s21'](),  0)
         self.assertEqual(self._queryCounts['s22'](),  0)
 
+        ### test tagged servers for restart
+
+        # reset counters
+        for name in ['s11', 's12', 's21', 's22']:
+            self._queryCounts[name].reset()
+
+        self._queryCounts['s11'].set_refuse(True)
+        self.setServerDown('s12')
+
+        # send 100 queries
+        for _ in range(numberOfQueries):
+            (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+        # s11 receives all 100 initial queries and always refuse to trigger restart
+        # s12 is not selected for both initial and restarted queries
+        # s21+s22 shall receive all the 100 restarted queries
+        self.assertEqual(self._queryCounts['s11'](),  numberOfQueries)
+        self.assertEqual(self._queryCounts['s12'](),  0)
+        self.assertEqual(self._queryCounts['s21']()+self._queryCounts['s22'](), numberOfQueries)
+
+        self._queryCounts['s11'].set_refuse(False)
+        self.setServerUp('s12')
+
+        ### further test server down conditions ###
+
         # reset counters
         for name in ['s11', 's12', 's21', 's22']:
             self._queryCounts[name].reset()
 
-        self.setDown('s11')
+        self.setServerDown('s11')
 
         # send 100 queries
         for _ in range(numberOfQueries):
@@ -1129,7 +1173,7 @@ class TestRoutingOrderedWRandUntag(DNSDistTest):
         for name in ['s11', 's12', 's21', 's22']:
             self._queryCounts[name].reset()
 
-        self.setDown('s12')
+        self.setServerDown('s12')
 
         # send 100 queries
         for _ in range(numberOfQueries):