git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Prevent race while creating new TCP worker threads
author: Remi Gacogne <remi.gacogne@powerdns.com>
Mon, 12 Dec 2016 15:28:17 +0000 (16:28 +0100)
committer: Remi Gacogne <remi.gacogne@powerdns.com>
Mon, 12 Dec 2016 15:28:17 +0000 (16:28 +0100)
We try very hard to avoid using locks, but we need to prevent two
threads inserting into the TCP workers vector concurrently. While
this can't happen at runtime since the healthcheck thread is the
only one calling `g_tcpclientthreads->addTCPClientThread()`, this
might happen at startup time because we start the TCP acceptor
threads one after another and they all call it once.
This might result, for example, in one vector entry being overwritten
and another one remaining value-initialized to zero.

pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh

index 3fc64c9b501804ff3bc6726f61c1a3076a96063b..512d222d6bf4060120c66efb58cc019916e522ca 100644 (file)
@@ -1524,7 +1524,7 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       setLuaNoSideEffect();
       boost::format fmt("%-10d %-10d %-10d %-10d\n");
       g_outputBuffer += (fmt % "Clients" % "MaxClients" % "Queued" % "MaxQueued").str();
-      g_outputBuffer += (fmt % g_tcpclientthreads->d_numthreads % g_maxTCPClientThreads % g_tcpclientthreads->d_queued % g_maxTCPQueuedConnections).str();
+      g_outputBuffer += (fmt % g_tcpclientthreads->getThreadsCount() % g_maxTCPClientThreads % g_tcpclientthreads->getQueuedCount() % g_maxTCPQueuedConnections).str();
     });
 
   g_lua.writeFunction("setCacheCleaningDelay", [](uint32_t delay) { g_cacheCleaningDelay = delay; });
index 13d2c8d3319f5757a66fc73a84a79a96dbe15579..8e30f7dad2f6a0bbeab588bdf4677bf213c25518 100644 (file)
@@ -77,14 +77,8 @@ struct ConnectionInfo
 uint64_t g_maxTCPQueuedConnections{1000};
 void* tcpClientThread(int pipefd);
 
-// Should not be called simultaneously!
 void TCPClientCollection::addTCPClientThread()
 {
-  if (d_numthreads >= d_tcpclientthreads.capacity()) {
-    warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
-    return;
-  }
-
   vinfolog("Adding TCP Client thread");
 
   int pipefds[2] = { -1, -1};
@@ -112,7 +106,19 @@ void TCPClientCollection::addTCPClientThread()
     return;
   }
 
-  d_tcpclientthreads.push_back(pipefds[1]);
+  {
+    std::lock_guard<std::mutex> lock(d_mutex);
+
+    if (d_numthreads >= d_tcpclientthreads.capacity()) {
+      warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
+      close(pipefds[0]);
+      close(pipefds[1]);
+      return;
+    }
+
+    d_tcpclientthreads.push_back(pipefds[1]);
+  }
+
   ++d_numthreads;
 }
 
@@ -202,7 +208,7 @@ void* tcpClientThread(int pipefd)
       throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
     }
 
-    --g_tcpclientthreads->d_queued;
+    g_tcpclientthreads->decrementQueuedCount();
     ci=*citmp;
     delete citmp;    
 
@@ -575,7 +581,7 @@ void* tcpAcceptorThread(void* p)
        continue;
       }
 
-      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->d_queued >= g_maxTCPQueuedConnections) {
+      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
         close(ci->fd);
         delete ci;
         ci=nullptr;
@@ -592,7 +598,7 @@ void* tcpAcceptorThread(void* p)
         writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
       }
       else {
-        --g_tcpclientthreads->d_queued;
+        g_tcpclientthreads->decrementQueuedCount();
         queuedCounterIncremented = false;
         close(ci->fd);
         delete ci;
@@ -606,7 +612,7 @@ void* tcpAcceptorThread(void* p)
       delete ci;
       ci = nullptr;
       if (queuedCounterIncremented) {
-        --g_tcpclientthreads->d_queued;
+        g_tcpclientthreads->decrementQueuedCount();
       }
     }
     catch(...){}
index b0eaacfd9fc9bf5315279ebe9b14950c70d06c62..01b88cca8372b2fd2da5b044a2abb738d75c33b5 100644 (file)
@@ -1371,7 +1371,7 @@ void* healthChecksThread()
   for(;;) {
     sleep(interval);
 
-    if(g_tcpclientthreads->d_queued > 1 && g_tcpclientthreads->d_numthreads < g_tcpclientthreads->d_maxthreads)
+    if(g_tcpclientthreads->getQueuedCount() > 1 && !g_tcpclientthreads->hasReachedMaxThreads())
       g_tcpclientthreads->addTCPClientThread();
 
     for(auto& dss : g_dstates.getCopy()) { // this points to the actual shared_ptrs!
index a8d877fbab9b66fbdcec012a4082cc5438bab2f5..895a677d543328ec5b27ac2a7e79e581bdacc444 100644 (file)
@@ -363,23 +363,40 @@ struct ClientState
 
 class TCPClientCollection {
   std::vector<int> d_tcpclientthreads;
+  std::atomic<uint64_t> d_numthreads{0};
   std::atomic<uint64_t> d_pos{0};
-public:
-  std::atomic<uint64_t> d_queued{0}, d_numthreads{0};
+  std::atomic<uint64_t> d_queued{0};
   uint64_t d_maxthreads{0};
+  std::mutex d_mutex;
+public:
 
   TCPClientCollection(size_t maxThreads)
   {
     d_maxthreads = maxThreads;
     d_tcpclientthreads.reserve(maxThreads);
   }
-
   int getThread()
   {
     uint64_t pos = d_pos++;
     ++d_queued;
     return d_tcpclientthreads[pos % d_numthreads];
   }
+  bool hasReachedMaxThreads() const
+  {
+    return d_numthreads >= d_maxthreads;
+  }
+  uint64_t getThreadsCount() const
+  {
+    return d_numthreads;
+  }
+  uint64_t getQueuedCount() const
+  {
+    return d_queued;
+  }
+  void decrementQueuedCount()
+  {
+    --d_queued;
+  }
   void addTCPClientThread();
 };