]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: support server state change lua callback
authorOliver Chen <oliver.chen@nokia-sbell.com>
Mon, 23 Jun 2025 06:15:06 +0000 (06:15 +0000)
committerOliver Chen <oliver.chen@nokia-sbell.com>
Mon, 23 Jun 2025 11:49:39 +0000 (11:49 +0000)
pdns/dnsdistdist/dnsdist-backend.cc
pdns/dnsdistdist/dnsdist-lua-hooks.cc
pdns/dnsdistdist/dnsdist-lua-hooks.hh
pdns/dnsdistdist/dnsdist.cc
pdns/dnsdistdist/dnsdist.hh
pdns/dnsdistdist/docs/reference/config.rst
pdns/dnsdistdist/test-dnsdist_cc.cc
regression-tests.dnsdist/test_HealthChecks.py

index 6664084fd6145aeeb827486336ff29bca85b9a98..ab32818a2060f3ff328494fc6790d04f61c96238 100644 (file)
@@ -817,6 +817,7 @@ void DownstreamState::submitHealthCheckResult(bool initial, bool newResult)
         updateNextLazyHealthCheck(*stats, false);
       }
     }
+    handleServerStateChange(getNameWithAddr(), newResult);
     return;
   }
 
@@ -891,6 +892,7 @@ void DownstreamState::submitHealthCheckResult(bool initial, bool newResult)
     if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) {
       g_snmpAgent->sendBackendStatusChangeTrap(*this);
     }
+    handleServerStateChange(getNameWithAddr(), newResult);
   }
 }
 
index 45cc25e1e6fba730895ff1e9d2573dbc48518625..25718ab417bfc920228ce1bb5216592cef6899a3 100644 (file)
@@ -9,9 +9,11 @@ namespace dnsdist::lua::hooks
 using ExitCallback = std::function<void()>;
 using MaintenanceCallback = std::function<void()>;
 using TicketsKeyAddedHook = std::function<void(const std::string&, size_t)>;
+using ServerStateChangeCallback = std::function<void(const std::string&, bool)>;
 
 static LockGuarded<std::vector<ExitCallback>> s_exitCallbacks;
 static LockGuarded<std::vector<MaintenanceCallback>> s_maintenanceHooks;
+static LockGuarded<std::vector<ServerStateChangeCallback>> s_serverStateChangeHooks;
 
 void runMaintenanceHooks(const LuaContext& context)
 {
@@ -65,6 +67,25 @@ static void setTicketsKeyAddedHook(const LuaContext& context, const TicketsKeyAd
   });
 }
 
+void runServerStateChangeHooks(const LuaContext& context, const std::string& nameWithAddr, bool newState)
+{
+  (void)context;
+  for (const auto& callback : *(s_serverStateChangeHooks.lock())) {
+    callback(nameWithAddr, newState);
+  }
+}
+
+static void addServerStateChangeCallback(const LuaContext& context, ServerStateChangeCallback callback)
+{
+  (void)context;
+  s_serverStateChangeHooks.lock()->push_back(std::move(callback));
+}
+
+void clearServerStateChangeCallbacks()
+{
+  s_serverStateChangeHooks.lock()->clear();
+}
+
 void setupLuaHooks(LuaContext& luaCtx)
 {
   luaCtx.writeFunction("addMaintenanceCallback", [&luaCtx](const MaintenanceCallback& callback) {
@@ -79,6 +100,10 @@ void setupLuaHooks(LuaContext& luaCtx)
     setLuaSideEffect();
     setTicketsKeyAddedHook(luaCtx, hook);
   });
+  luaCtx.writeFunction("addServerStateChangeCallback", [&luaCtx](const ServerStateChangeCallback& hook) {
+    setLuaSideEffect();
+    addServerStateChangeCallback(luaCtx, hook);
+  });
 }
 
 }
index a1b8c3ca8f96b5cdff40d5524c19d50f884af1a9..abd0f29161b8196849f7a5800add3c1b2ca18db3 100644 (file)
@@ -22,6 +22,7 @@
 #pragma once
 
 #include <functional>
+#include <string>
 
 class LuaContext;
 
@@ -31,5 +32,6 @@ void runMaintenanceHooks(const LuaContext& context);
 void clearMaintenanceHooks();
 void runExitCallbacks(const LuaContext& context);
 void clearExitCallbacks();
+void runServerStateChangeHooks(const LuaContext& context, const std::string& nameWithAddr, bool newState);
 void setupLuaHooks(LuaContext& luaCtx);
 }
index d89c590ec3ed7dfde7b1bffc8261c81450847e76..9d760f842dfe4cf9fbedb509a802267fdf875465 100644 (file)
@@ -1596,6 +1596,17 @@ bool handleTimeoutResponseRules(const std::vector<dnsdist::rules::ResponseRuleAc
   return dnsResponse.isAsynchronous();
 }
 
+void handleServerStateChange(const string& nameWithAddr, bool newResult)
+{
+  try {
+    auto lua = g_lua.lock();
+    dnsdist::lua::hooks::runServerStateChangeHooks(*lua, nameWithAddr, newResult);
+  }
+  catch (const std::exception& exp) {
+    warnlog("Error calling the Lua hook for Server State Change: %s", exp.what());
+  }
+}
+
 class UDPTCPCrossQuerySender : public TCPQuerySender
 {
 public:
@@ -2821,6 +2832,7 @@ static void cleanupLuaObjects(LuaContext& /* luaCtx */)
   });
   dnsdist::webserver::clearWebHandlers();
   dnsdist::lua::hooks::clearMaintenanceHooks();
+  dnsdist::lua::hooks::clearServerStateChangeCallbacks();
 }
 #endif /* defined(COVERAGE) || (defined(__SANITIZE_ADDRESS__) && defined(HAVE_LEAK_SANITIZER_INTERFACE)) */
 
index d2a8095bbc9d66c4b89001bea2c9eb06d2dcef7b..3f21e76fec3343d652669ebc57630902ff063b87 100644 (file)
@@ -1031,3 +1031,4 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs
 void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol, bool fromBackend);
 void handleResponseSent(const InternalQueryState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, bool fromBackend);
 bool handleTimeoutResponseRules(const std::vector<dnsdist::rules::ResponseRuleAction>& rules, InternalQueryState& ids, const std::shared_ptr<DownstreamState>& ds, const std::shared_ptr<TCPQuerySender>& sender);
+void handleServerStateChange(const std::string& nameWithAddr, bool newResult);
index a3934abd2829cb68761a6c326db3343286aa1e3b..db7315cba9006a61ec1765f0699aa23c88cc8abe 100644 (file)
@@ -2236,6 +2236,23 @@ Other functions
     end
     addMaintenanceCallback(myCallback)
 
+.. function:: addServerStateChangeCallback(callback)
+
+  .. versionadded:: 2.0.0
+
+  Register a Lua function to be called when a server state changed during the health check process.
+  The function should not block for a long period of time, as it would otherwise delay the execution of the other functions registered for this hook, as well as the execution of the health check process.
+
+  :param function callback: The function to be called. It returns no value and takes two parameters, first parameter is a string with format as the same as return from :func:`Server:getNameWithAddr()` to identify the server, second parameter is a bool value indicating server state is up if true else down.
+
+  .. code-block:: lua
+
+    function serverStateChanged(nameAddr, newState)
+      if newState then state = 'up' else state = 'down' end
+      print(string.format('Server State Changed: %s -> %s', nameAddr, state))
+    end
+    addServerStateChangeCallback(serverStateChanged)
+
 
 .. function:: getAddressInfo(hostname, callback)
 
index ebe7142554934f1953cc13940d64df3a2c0b2f74..286f3b42aec72330bc1b8d90719bb002f6b1d288 100644 (file)
@@ -74,6 +74,13 @@ bool handleTimeoutResponseRules(const std::vector<dnsdist::rules::ResponseRuleAc
   return false;
 }
 
+void handleServerStateChange(const string& nameWithAddr, bool newResult)
+{
+  (void)nameWithAddr;
+  (void)newResult;
+  return;
+}
+
 bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
 {
   (void)origFD;
index 1c1557f2146a56579b792aa4bea78c56fcc57761..4348d76b35be76c947affd9add9efee7a72cc305 100644 (file)
@@ -547,3 +547,92 @@ class TestUpdateHCParamsCombo2(HealthCheckUpdateParams):
         time.sleep(1)
         # now should timeout and failure increased
         self.assertEqual(self.getBackendMetric(0, 'healthCheckFailures'), beforeFailure+1)
+
+class TestServerStateChange(HealthCheckTest):
+
+    _healthQueue = queue.Queue()
+    _dropHealthCheck = False
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%d")
+    webserver("127.0.0.1:%s")
+    setWebserverConfig({apiKey="%s"})
+    srv = newServer{address="127.0.0.1:%d",maxCheckFailures=1,checkTimeout=1000,checkInterval=1,rise=1}
+    srv:setAuto(false)
+    serverUpCount = {}
+    serverDownCount = {}
+    function ServerStateChange(nameAddr, newState)
+        if newState then
+            if not serverUpCount[nameAddr] then serverUpCount[nameAddr] = 0 end
+            serverUpCount[nameAddr] = serverUpCount[nameAddr] + 1
+        else
+            if not serverDownCount[nameAddr] then serverDownCount[nameAddr] = 0 end
+            serverDownCount[nameAddr] = serverDownCount[nameAddr] + 1
+        end
+    end
+    addServerStateChangeCallback(ServerStateChange)
+    function getCount(nameAddr, state)
+        if state then
+            if not serverUpCount[nameAddr] then serverUpCount[nameAddr] = 0 end
+            return serverUpCount[nameAddr]
+        else
+            if not serverDownCount[nameAddr] then serverDownCount[nameAddr] = 0 end
+            return serverDownCount[nameAddr]
+        end
+    end
+    """
+
+    @classmethod
+    def startResponders(cls):
+        print("Launching responders..")
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, cls.healthCallback])
+        cls._UDPResponder.daemon = True
+        cls._UDPResponder.start()
+
+    @classmethod
+    def healthCallback(cls, request):
+        if cls._dropHealthCheck:
+          cls._healthQueue.put(False)
+          print("health check received drop")
+          return ResponderDropAction()
+        response = dns.message.make_response(request)
+        cls._healthQueue.put(True)
+        print("health check received return")
+        return response.to_wire()
+
+    @classmethod
+    def setDrop(cls, flag=True):
+        cls._dropHealthCheck = flag
+
+    def getCount(self, nameAddr, state):
+        if state:
+            return int(self.sendConsoleCommand("getCount('{}', true)".format(nameAddr)).strip("\n"))
+        return int(self.sendConsoleCommand("getCount('{}', false)".format(nameAddr)).strip("\n"))
+
+    def testServerStateChange(self):
+        """
+        HealthChecks: test Server State Change callback
+        """
+
+        nameAddr = self.sendConsoleCommand("getServer(0):getNameWithAddr()").strip("\n")
+        self.assertTrue(nameAddr)
+
+        time.sleep(1)
+        # server initial up shall have been hit
+        self.assertEqual(self.getBackendStatus(), 'up')
+        self.assertEqual(self.getCount(nameAddr, True), 1)
+        self.assertEqual(self.getCount(nameAddr, False), 0)
+
+        self.setDrop(True)
+        time.sleep(2.5)
+        # up count no change, down count increased by 1
+        self.assertEqual(self.getBackendStatus(), 'down')
+        self.assertEqual(self.getCount(nameAddr, True), 1)
+        self.assertEqual(self.getCount(nameAddr, False), 1)
+
+        self.setDrop(False)
+        time.sleep(1.5)
+        # up count increased again, down count no change
+        self.assertEqual(self.getBackendStatus(), 'up')
+        self.assertEqual(self.getCount(nameAddr, True), 2)
+        self.assertEqual(self.getCount(nameAddr, False), 1)