git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Use atomic types to avoid a potential read/write race condition
author Oliver Chen <oliver.chen@nokia-sbell.com>
Wed, 30 Apr 2025 03:40:22 +0000 (03:40 +0000)
committer Oliver Chen <oliver.chen@nokia-sbell.com>
Wed, 30 Apr 2025 03:44:20 +0000 (03:44 +0000)
Only a few numerical health check parameters are selected, and those
parameters are changed to use atomic types so as to avoid potential
read/write race conditions.

pdns/dnsdistdist/dnsdist-backend.cc
pdns/dnsdistdist/dnsdist-lua-bindings.cc
pdns/dnsdistdist/dnsdist.hh
pdns/dnsdistdist/docs/reference/config.rst
regression-tests.dnsdist/test_HealthChecks.py

index d1f144ee2267897cda82025337174968cea94c4e..a9dfe89a8e30cf1c967999c012ce4a5e848d60c2 100644 (file)
@@ -751,7 +751,7 @@ void DownstreamState::updateNextLazyHealthCheck(LazyHealthCheckStats& stats, boo
       /* we are still in the "up" state, we need to send the next query quickly to
          determine if the backend is really down */
       stats.d_nextCheck = now + d_config.checkInterval;
-      vinfolog("Backend %s is in potential failure state, next check in %d seconds", getNameWithAddr(), d_config.checkInterval);
+      vinfolog("Backend %s is in potential failure state, next check in %d seconds", getNameWithAddr(), d_config.checkInterval.load());
     }
     else if (consecutiveSuccessfulChecks > 0) {
       /* we are in 'Failed' state, but just had one (or more) successful check,
index c0460da562bfa8752c50f94864c3a1006d849d98..a2ea3e69ef58fa1246829cbd3dd164f64c877ef9 100644 (file)
@@ -137,16 +137,23 @@ void setupLuaBindings(LuaContext& luaCtx, bool client, bool configCheck)
     }
     state.setLazyAuto();
   });
-  luaCtx.registerFunction<void (DownstreamState::*)(boost::optional<LuaAssociativeTable<boost::variant<bool,size_t,std::string>>>)>("setHealthCheckParams", [](DownstreamState& state, boost::optional<LuaAssociativeTable<boost::variant<bool,size_t,std::string>>> vars) {
-    std::string valueStr;
-    getOptionalValue<size_t>(vars, "maxCheckFailures", state.d_config.maxCheckFailures);
-    getOptionalValue<size_t>(vars, "rise", state.d_config.minRiseSuccesses);
-    getOptionalValue<size_t>(vars, "checkTimeout", state.d_config.checkTimeout);
-    getOptionalValue<size_t>(vars, "checkInterval", state.d_config.checkInterval);
-    getOptionalValue<std::string>(vars, "checkType", state.d_config.checkType);
-    getOptionalValue<bool>(vars, "checkTCP", state.d_config.d_tcpCheck);
-    if (getOptionalValue<std::string>(vars, "checkName", valueStr) > 0) {
-      state.d_config.checkName = DNSName(valueStr);
+  luaCtx.registerFunction<void (DownstreamState::*)(boost::optional<LuaAssociativeTable<boost::variant<size_t>>>)>("setHealthCheckParams", [](DownstreamState& state, boost::optional<LuaAssociativeTable<boost::variant<size_t>>> vars) {
+    size_t value = 0;
+    getOptionalValue<size_t>(vars, "maxCheckFailures", value);
+    if (value > 0) {
+      state.d_config.maxCheckFailures.store(value);
+    }
+    getOptionalValue<size_t>(vars, "rise", value);
+    if (value > 0) {
+      state.d_config.minRiseSuccesses.store(value);
+    }
+    getOptionalValue<size_t>(vars, "checkTimeout", value);
+    if (value > 0) {
+      state.d_config.checkTimeout.store(value);
+    }
+    getOptionalValue<size_t>(vars, "checkInterval", value);
+    if (value > 0) {
+      state.d_config.checkInterval.store(value);
     }
   });
   luaCtx.registerFunction<std::string (DownstreamState::*)() const>("getName", [](const DownstreamState& state) -> const std::string& { return state.getName(); });
index ddfaf0d8abb8cd4638909b32cf5433a88f6f74e3..0ae1fbdc878052287d7a0503bfc6096b6b7dad5a 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "config.h"
 
+#include <atomic>
 #include <condition_variable>
 #include <memory>
 #include <mutex>
@@ -543,12 +544,12 @@ struct DownstreamState : public std::enable_shared_from_this<DownstreamState>
     TimeoutOrServFail
   };
 
-  struct Config
+  struct BaseConfig
   {
-    Config()
+    BaseConfig()
     {
     }
-    Config(const ComboAddress& remote_) :
+    BaseConfig(const ComboAddress& remote_) :
       remote(remote_)
     {
     }
@@ -579,20 +580,16 @@ struct DownstreamState : public std::enable_shared_from_this<DownstreamState>
     int tcpRecvTimeout{30};
     int tcpSendTimeout{30};
     int d_qpsLimit{0};
-    unsigned int checkInterval{1};
     unsigned int sourceItf{0};
     QType checkType{QType::A};
     uint16_t checkClass{QClass::IN};
     uint16_t d_retries{5};
-    uint16_t checkTimeout{1000}; /* in milliseconds */
     uint16_t d_lazyHealthCheckSampleSize{100};
     uint16_t d_lazyHealthCheckMinSampleCount{1};
     uint16_t d_lazyHealthCheckFailedInterval{30};
     uint16_t d_lazyHealthCheckMaxBackOff{3600};
     uint8_t d_lazyHealthCheckThreshold{20};
     LazyHealthCheckMode d_lazyHealthCheckMode{LazyHealthCheckMode::TimeoutOrServFail};
-    uint8_t maxCheckFailures{1};
-    uint8_t minRiseSuccesses{1};
     uint8_t udpTimeout{0};
     uint8_t dscp{0};
     Availability availability{Availability::Auto};
@@ -613,6 +610,28 @@ struct DownstreamState : public std::enable_shared_from_this<DownstreamState>
     bool d_upgradeToLazyHealthChecks{false};
   };
 
+  struct Config : public BaseConfig {
+    Config(): BaseConfig()
+    {
+    }
+    Config(const ComboAddress& remote_) :
+      BaseConfig(remote_)
+    {
+    }
+    Config(const Config& c) : BaseConfig(c)
+    {
+      checkInterval.store(c.checkInterval.load());
+      checkTimeout.store(c.checkTimeout.load());
+      maxCheckFailures.store(c.maxCheckFailures.load());
+      minRiseSuccesses.store(c.minRiseSuccesses.load());
+    }
+
+    std::atomic<unsigned int> checkInterval{1};
+    std::atomic<uint16_t> checkTimeout{1000}; /* in milliseconds */
+    std::atomic<uint8_t> maxCheckFailures{1};
+    std::atomic<uint8_t> minRiseSuccesses{1};
+  };
+
   struct HealthCheckMetrics
   {
     stat_t d_failures{0};
index e3007e3d483f8133158d497ca0b3d1dda795eab5..bf577e79a5e62185ef05b67c3d68b53d4c8abec2 100644 (file)
@@ -898,13 +898,10 @@ A server object returned by :func:`getServer` can be manipulated with these func
     :header: Keyword, Type
     :widths: auto
 
-    ``checkName``                            ``string``
-    ``checkType``                            ``string``
     ``checkTimeout``                         ``number``
     ``checkInterval``                        ``number``
     ``maxCheckFailures``                     ``number``
     ``rise``                                 ``number``
-    ``checkTCP``                             ``bool``
 
   Apart from the functions, a :class:`Server` object has these attributes:
 
index d3acbcf832ae48e3e3c276794870bcd5975fc348..12347074e4e6417ab61cae79cd7295ce08b699bd 100644 (file)
@@ -5,7 +5,7 @@ import ssl
 import threading
 import time
 import dns
-from queue import Queue
+import queue
 from dnsdisttests import DNSDistTest, pickAvailablePort, ResponderDropAction
 
 class HealthCheckTest(DNSDistTest):
@@ -402,45 +402,32 @@ class TestLazyHealthChecks(HealthCheckTest):
 
 class HealthCheckUpdateParams(HealthCheckTest):
 
-    _healthQueue = Queue()
+    _healthQueue = queue.Queue()
+    _dropHealthCheck = False
 
     @classmethod
     def startResponders(cls):
         print("Launching responders..")
-        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, cls.healthCallbackUdp])
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, cls.healthCallback])
         cls._UDPResponder.daemon = True
         cls._UDPResponder.start()
-        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, cls.healthCallbackTcp])
-        cls._TCPResponder.daemon = True
-        cls._TCPResponder.start()
 
     @classmethod
-    def healthCallbackUdp(cls, request):
-        qn, qt= str(request.question[0].name), request.question[0].rdtype
+    def healthCallback(cls, request):
+        if cls._dropHealthCheck:
+          cls._healthQueue.put(False)
+          return ResponderDropAction()
         response = dns.message.make_response(request)
-        if qn.endswith("drop.hc.dnsdist.org.") or qn.endswith("tcponly.hc.dnsdist.org."):
-            response = None
-        if response is None:
-            cls._healthQueue.put((False, qn, qt))
-            return ResponderDropAction()
-        cls._healthQueue.put((True, qn, qt))
+        cls._healthQueue.put(True)
         return response.to_wire()
 
     @classmethod
-    def healthCallbackTcp(cls, request):
-        qn, qt= str(request.question[0].name), request.question[0].rdtype
-        response = dns.message.make_response(request)
-        if qn.endswith("drop.hc.dnsdist.org."):
-            response = None
-        if response is None:
-            cls._healthQueue.put((False, qn, qt))
-            return ResponderDropAction()
-        cls._healthQueue.put((True, qn, qt))
-        return response.to_wire()
+    def wait1(cls, block=True):
+        return cls._healthQueue.get(block)
 
     @classmethod
-    def wait1(cls):
-        return cls._healthQueue.get()
+    def setDrop(cls, flag=True):
+        cls._dropHealthCheck = flag
 
 class TestUpdateHCParamsCombo1(HealthCheckUpdateParams):
 
@@ -449,42 +436,43 @@ class TestUpdateHCParamsCombo1(HealthCheckUpdateParams):
 
     def testCombo1(self):
         """
-        HealthChecks: Update checkName, maxCheckFailures, rise, checkTCP
+        HealthChecks: Update maxCheckFailures, rise
         """
         # consume health checks upon sys init
-        for _ in [1, 2]: rc, qn, qt = self.wait1()
-        self.assertEqual(rc, True)
+        try:
+          while self.wait1(False): pass
+        except queue.Empty: pass
+
+        self.assertEqual(self.wait1(), True)
         time.sleep(0.1)
         self.assertEqual(self.getBackendMetric(0, 'healthCheckFailures'), 0)
         self.assertEqual(self.getBackendStatus(), 'up')
 
-        self.sendConsoleCommand("getServer(0):setHealthCheckParams({checkName='drop.hc.dnsdist.org',maxCheckFailures=2,rise=2})")
+        self.sendConsoleCommand("getServer(0):setHealthCheckParams({maxCheckFailures=2,rise=2})")
+        self.setDrop()
 
         # wait for 1st failure
         for i in [1,2,3]:
-            rc, qn, qt = self.wait1()
+            rc = self.wait1()
             if rc is False: break
         self.assertGreater(3, i)
-        self.assertEqual(qn, 'drop.hc.dnsdist.org.')
         time.sleep(1.1)
         # should have failures but still up
         self.assertGreater(self.getBackendMetric(0, 'healthCheckFailures'), 0)
         self.assertEqual(self.getBackendStatus(), 'up')
 
         # wait for 2nd failure
-        rc, qn, qt = self.wait1()
-        self.assertEqual(rc, False)
-        self.assertEqual(qn, 'drop.hc.dnsdist.org.')
+        self.assertEqual(self.wait1(), False)
         time.sleep(1.1)
         # should have more failures and down
         self.assertGreater(self.getBackendMetric(0, 'healthCheckFailures'), 1)
         self.assertEqual(self.getBackendStatus(), 'down')
 
-        self.sendConsoleCommand("getServer(0):setHealthCheckParams({checkName='tcponly.hc.powerdns.com',checkTCP=true})")
+        self.setDrop(False)
 
         # wait for 1st success
         for i in [1,2,3]:
-            rc, qn, qt = self.wait1()
+            rc = self.wait1()
             if rc is True: break
         self.assertGreater(3, i)
         time.sleep(0.1)
@@ -494,8 +482,7 @@ class TestUpdateHCParamsCombo1(HealthCheckUpdateParams):
         beforeFailure = self.getBackendMetric(0, 'healthCheckFailures')
 
         # wati for 2nd success
-        rc, qn, qt = self.wait1()
-        self.assertEqual(rc, True)
+        self.assertEqual(self.wait1(), True)
         time.sleep(0.1)
         # should have no more failures, back to up
         self.assertEqual(self.getBackendMetric(0, 'healthCheckFailures'), beforeFailure)
@@ -508,34 +495,34 @@ class TestUpdateHCParamsCombo2(HealthCheckUpdateParams):
 
     def testCombo2(self):
         """
-        HealthChecks: Update checkType, checkTimeout, checkInterval
+        HealthChecks: Update checkTimeout, checkInterval
         """
         # consume health checks upon sys init
-        for _ in [1, 2]: rc, qn, qt = self.wait1()
-        self.assertEqual(rc, True)
+        try:
+          while self.wait1(False): pass
+        except queue.Empty: pass
+
+        self.assertEqual(self.wait1(), True)
         time.sleep(0.1)
         self.assertEqual(self.getBackendMetric(0, 'healthCheckFailures'), 0)
         self.assertEqual(self.getBackendStatus(), 'up')
 
-        self.sendConsoleCommand("getServer(0):setHealthCheckParams({checkType='TXT',checkInterval=2})")
+        self.sendConsoleCommand("getServer(0):setHealthCheckParams({checkInterval=2})")
 
         # start timing
-        rc, qn, qt = self.wait1()
+        self.assertEqual(self.wait1(), True)
         t1 = time.time()
-        self.assertEqual(rc, True)
-        self.assertEqual(qt, dns.rdatatype.TXT)
-        rc, qn, qt = self.wait1()
+        self.assertEqual(self.wait1(), True)
         t2 = time.time()
-        self.assertEqual(rc, True)
-        self.assertEqual(qt, dns.rdatatype.TXT)
         # intervals shall be greater than 1
         self.assertGreater(t2-t1, 1.5)
 
-        self.sendConsoleCommand("getServer(0):setHealthCheckParams({checkName='drop.hc.dnsdist.org',checkTimeout=2000})")
+        self.sendConsoleCommand("getServer(0):setHealthCheckParams({checkTimeout=2000})")
+        self.setDrop()
 
         # wait for 1st failure
         for i in [1,2,3]:
-            rc, qn, qt = self.wait1()
+            rc = self.wait1()
             if rc is False: break
         self.assertGreater(3, i)