updateNextLazyHealthCheck(*stats, false);
}
}
+ handleServerStateChange(getNameWithAddr(), newResult);
return;
}
if (g_snmpAgent != nullptr && dnsdist::configuration::getImmutableConfiguration().d_snmpTrapsEnabled) {
g_snmpAgent->sendBackendStatusChangeTrap(*this);
}
+ handleServerStateChange(getNameWithAddr(), newResult);
}
}
using ExitCallback = std::function<void()>;
using MaintenanceCallback = std::function<void()>;
using TicketsKeyAddedHook = std::function<void(const std::string&, size_t)>;
+using ServerStateChangeCallback = std::function<void(const std::string&, bool)>;
static LockGuarded<std::vector<ExitCallback>> s_exitCallbacks;
static LockGuarded<std::vector<MaintenanceCallback>> s_maintenanceHooks;
+static LockGuarded<std::vector<ServerStateChangeCallback>> s_serverStateChangeHooks;
void runMaintenanceHooks(const LuaContext& context)
{
});
}
+void runServerStateChangeHooks(const LuaContext& context, const std::string& nameWithAddr, bool newState)
+{
+ (void)context;
+ for (const auto& callback : *(s_serverStateChangeHooks.lock())) {
+ callback(nameWithAddr, newState);
+ }
+}
+
+static void addServerStateChangeCallback(const LuaContext& context, ServerStateChangeCallback callback)
+{
+ (void)context;
+ s_serverStateChangeHooks.lock()->push_back(std::move(callback));
+}
+
+void clearServerStateChangeCallbacks()
+{
+ s_serverStateChangeHooks.lock()->clear();
+}
+
void setupLuaHooks(LuaContext& luaCtx)
{
luaCtx.writeFunction("addMaintenanceCallback", [&luaCtx](const MaintenanceCallback& callback) {
setLuaSideEffect();
setTicketsKeyAddedHook(luaCtx, hook);
});
+ luaCtx.writeFunction("addServerStateChangeCallback", [&luaCtx](const ServerStateChangeCallback& hook) {
+ setLuaSideEffect();
+ addServerStateChangeCallback(luaCtx, hook);
+ });
}
}
#pragma once
#include <functional>
+#include <string>
class LuaContext;
void clearMaintenanceHooks();
void runExitCallbacks(const LuaContext& context);
void clearExitCallbacks();
+void runServerStateChangeHooks(const LuaContext& context, const std::string& nameWithAddr, bool newState);
void setupLuaHooks(LuaContext& luaCtx);
}
return dnsResponse.isAsynchronous();
}
+void handleServerStateChange(const string& nameWithAddr, bool newResult)
+{
+ try {
+ auto lua = g_lua.lock();
+ dnsdist::lua::hooks::runServerStateChangeHooks(*lua, nameWithAddr, newResult);
+ }
+ catch (const std::exception& exp) {
+ warnlog("Error calling the Lua hook for Server State Change: %s", exp.what());
+ }
+}
+
class UDPTCPCrossQuerySender : public TCPQuerySender
{
public:
});
dnsdist::webserver::clearWebHandlers();
dnsdist::lua::hooks::clearMaintenanceHooks();
+ dnsdist::lua::hooks::clearServerStateChangeCallbacks();
}
#endif /* defined(COVERAGE) || (defined(__SANITIZE_ADDRESS__) && defined(HAVE_LEAK_SANITIZER_INTERFACE)) */
void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol, bool fromBackend);
void handleResponseSent(const InternalQueryState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, bool fromBackend);
bool handleTimeoutResponseRules(const std::vector<dnsdist::rules::ResponseRuleAction>& rules, InternalQueryState& ids, const std::shared_ptr<DownstreamState>& ds, const std::shared_ptr<TCPQuerySender>& sender);
+void handleServerStateChange(const std::string& nameWithAddr, bool newResult);
end
addMaintenanceCallback(myCallback)
+.. function:: addServerStateChangeCallback(callback)
+
+ .. versionadded:: 2.0.0
+
+ Register a Lua function to be called when a server state changed during the health check process.
+ The function should not block for a long period of time, as it would otherwise delay the execution of the other functions registered for this hook, as well as the execution of the health check process.
+
+ :param function callback: The function to be called. It returns no value and takes two parameters, first parameter is a string with format as the same as return from :func:`Server:getNameWithAddr()` to identify the server, second parameter is a bool value indicating server state is up if true else down.
+
+ .. code-block:: lua
+
+ function serverStateChanged(nameAddr, newState)
+ if newState then state = 'up' else state = 'down' end
+ print(string.format('Server State Changed: %s -> %s', nameAddr, state))
+ end
+ addServerStateChangeCallback(serverStateChanged)
+
.. function:: getAddressInfo(hostname, callback)
return false;
}
+void handleServerStateChange(const string& nameWithAddr, bool newResult)
+{
+ (void)nameWithAddr;
+ (void)newResult;
+ return;
+}
+
bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
{
(void)origFD;
time.sleep(1)
# now should timeout and failure increased
self.assertEqual(self.getBackendMetric(0, 'healthCheckFailures'), beforeFailure+1)
+
+class TestServerStateChange(HealthCheckTest):
+
+ _healthQueue = queue.Queue()
+ _dropHealthCheck = False
+ _config_template = """
+ setKey("%s")
+ controlSocket("127.0.0.1:%d")
+ webserver("127.0.0.1:%s")
+ setWebserverConfig({apiKey="%s"})
+ srv = newServer{address="127.0.0.1:%d",maxCheckFailures=1,checkTimeout=1000,checkInterval=1,rise=1}
+ srv:setAuto(false)
+ serverUpCount = {}
+ serverDownCount = {}
+ function ServerStateChange(nameAddr, newState)
+ if newState then
+ if not serverUpCount[nameAddr] then serverUpCount[nameAddr] = 0 end
+ serverUpCount[nameAddr] = serverUpCount[nameAddr] + 1
+ else
+ if not serverDownCount[nameAddr] then serverDownCount[nameAddr] = 0 end
+ serverDownCount[nameAddr] = serverDownCount[nameAddr] + 1
+ end
+ end
+ addServerStateChangeCallback(ServerStateChange)
+ function getCount(nameAddr, state)
+ if state then
+ if not serverUpCount[nameAddr] then serverUpCount[nameAddr] = 0 end
+ return serverUpCount[nameAddr]
+ else
+ if not serverDownCount[nameAddr] then serverDownCount[nameAddr] = 0 end
+ return serverDownCount[nameAddr]
+ end
+ end
+ """
+
+ @classmethod
+ def startResponders(cls):
+ print("Launching responders..")
+ cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, cls.healthCallback])
+ cls._UDPResponder.daemon = True
+ cls._UDPResponder.start()
+
+ @classmethod
+ def healthCallback(cls, request):
+ if cls._dropHealthCheck:
+ cls._healthQueue.put(False)
+ print("health check received drop")
+ return ResponderDropAction()
+ response = dns.message.make_response(request)
+ cls._healthQueue.put(True)
+ print("health check received return")
+ return response.to_wire()
+
+ @classmethod
+ def setDrop(cls, flag=True):
+ cls._dropHealthCheck = flag
+
+ def getCount(self, nameAddr, state):
+ if state:
+ return int(self.sendConsoleCommand("getCount('{}', true)".format(nameAddr)).strip("\n"))
+ return int(self.sendConsoleCommand("getCount('{}', false)".format(nameAddr)).strip("\n"))
+
+ def testServerStateChange(self):
+ """
+ HealthChecks: test Server State Change callback
+ """
+
+ nameAddr = self.sendConsoleCommand("getServer(0):getNameWithAddr()").strip("\n")
+ self.assertTrue(nameAddr)
+
+ time.sleep(1)
+ # server initial up shall have been hit
+ self.assertEqual(self.getBackendStatus(), 'up')
+ self.assertEqual(self.getCount(nameAddr, True), 1)
+ self.assertEqual(self.getCount(nameAddr, False), 0)
+
+ self.setDrop(True)
+ time.sleep(2.5)
+ # up count no change, down count increased by 1
+ self.assertEqual(self.getBackendStatus(), 'down')
+ self.assertEqual(self.getCount(nameAddr, True), 1)
+ self.assertEqual(self.getCount(nameAddr, False), 1)
+
+ self.setDrop(False)
+ time.sleep(1.5)
+ # up count increased again, down count no change
+ self.assertEqual(self.getBackendStatus(), 'up')
+ self.assertEqual(self.getCount(nameAddr, True), 2)
+ self.assertEqual(self.getCount(nameAddr, False), 1)