From: Remi Gacogne Date: Mon, 12 Dec 2016 15:28:17 +0000 (+0100) Subject: dnsdist: Prevent race while creating new TCP worker threads X-Git-Tag: dnsdist-1.1.0-beta2~1^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ded1985a766447fbb444a6db1cc1f809f8c19e21;p=thirdparty%2Fpdns.git dnsdist: Prevent race while creating new TCP worker threads We try very hard to avoid using locks, but we need to prevent two threads inserting into the TCP workers vector concurrently. While this can't happen at runtime since the healthcheck thread is the only one calling `g_tcpclientthreads->addTCPClientThread()`, this might happen at startup time because we start the TCP acceptor threads one after another and they all call it once. This might result, for example, in one vector entry being overwritten and another one remaining value-initialized to zero. --- diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 3fc64c9b50..512d222d6b 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -1524,7 +1524,7 @@ vector> setupLua(bool client, const std::string& confi setLuaNoSideEffect(); boost::format fmt("%-10d %-10d %-10d %-10d\n"); g_outputBuffer += (fmt % "Clients" % "MaxClients" % "Queued" % "MaxQueued").str(); - g_outputBuffer += (fmt % g_tcpclientthreads->d_numthreads % g_maxTCPClientThreads % g_tcpclientthreads->d_queued % g_maxTCPQueuedConnections).str(); + g_outputBuffer += (fmt % g_tcpclientthreads->getThreadsCount() % g_maxTCPClientThreads % g_tcpclientthreads->getQueuedCount() % g_maxTCPQueuedConnections).str(); }); g_lua.writeFunction("setCacheCleaningDelay", [](uint32_t delay) { g_cacheCleaningDelay = delay; }); diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 13d2c8d331..8e30f7dad2 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -77,14 +77,8 @@ struct ConnectionInfo uint64_t g_maxTCPQueuedConnections{1000}; void* tcpClientThread(int pipefd); -// Should not be called simultaneously! 
void TCPClientCollection::addTCPClientThread() { - if (d_numthreads >= d_tcpclientthreads.capacity()) { - warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity()); - return; - } - vinfolog("Adding TCP Client thread"); int pipefds[2] = { -1, -1}; @@ -112,7 +106,19 @@ void TCPClientCollection::addTCPClientThread() return; } - d_tcpclientthreads.push_back(pipefds[1]); + { + std::lock_guard<std::mutex> lock(d_mutex); + + if (d_numthreads >= d_tcpclientthreads.capacity()) { + warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity()); + close(pipefds[0]); + close(pipefds[1]); + return; + } + + d_tcpclientthreads.push_back(pipefds[1]); + } + ++d_numthreads; } @@ -202,7 +208,7 @@ void* tcpClientThread(int pipefd) throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? 
"non-blocking" : "blocking") + " mode: " + e.what()); } - --g_tcpclientthreads->d_queued; + g_tcpclientthreads->decrementQueuedCount(); ci=*citmp; delete citmp; @@ -575,7 +581,7 @@ void* tcpAcceptorThread(void* p) continue; } - if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->d_queued >= g_maxTCPQueuedConnections) { + if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) { close(ci->fd); delete ci; ci=nullptr; @@ -592,7 +598,7 @@ void* tcpAcceptorThread(void* p) writen2WithTimeout(pipe, &ci, sizeof(ci), 0); } else { - --g_tcpclientthreads->d_queued; + g_tcpclientthreads->decrementQueuedCount(); queuedCounterIncremented = false; close(ci->fd); delete ci; @@ -606,7 +612,7 @@ void* tcpAcceptorThread(void* p) delete ci; ci = nullptr; if (queuedCounterIncremented) { - --g_tcpclientthreads->d_queued; + g_tcpclientthreads->decrementQueuedCount(); } } catch(...){} diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index b0eaacfd9f..01b88cca83 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1371,7 +1371,7 @@ void* healthChecksThread() for(;;) { sleep(interval); - if(g_tcpclientthreads->d_queued > 1 && g_tcpclientthreads->d_numthreads < g_tcpclientthreads->d_maxthreads) + if(g_tcpclientthreads->getQueuedCount() > 1 && !g_tcpclientthreads->hasReachedMaxThreads()) g_tcpclientthreads->addTCPClientThread(); for(auto& dss : g_dstates.getCopy()) { // this points to the actual shared_ptrs! 
diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index a8d877fbab..895a677d54 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -363,23 +363,40 @@ struct ClientState class TCPClientCollection { std::vector<int> d_tcpclientthreads; + std::atomic<uint64_t> d_numthreads{0}; std::atomic<uint64_t> d_pos{0}; -public: - std::atomic<uint64_t> d_queued{0}, d_numthreads{0}; + std::atomic<uint64_t> d_queued{0}; uint64_t d_maxthreads{0}; + std::mutex d_mutex; +public: TCPClientCollection(size_t maxThreads) { d_maxthreads = maxThreads; d_tcpclientthreads.reserve(maxThreads); } - int getThread() { uint64_t pos = d_pos++; ++d_queued; return d_tcpclientthreads[pos % d_numthreads]; } + bool hasReachedMaxThreads() const + { + return d_numthreads >= d_maxthreads; + } + uint64_t getThreadsCount() const + { + return d_numthreads; + } + uint64_t getQueuedCount() const + { + return d_queued; + } + void decrementQueuedCount() + { + --d_queued; + } void addTCPClientThread(); };