]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactoring of the TCP/TLS workers using channels
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 17 Mar 2022 15:26:08 +0000 (16:26 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 13 Jun 2023 07:59:33 +0000 (09:59 +0200)
pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/dnsdist-tcp.hh

index a5af69e2ced7d1f7caa7989f20fdb487a431b21c..17309a4f664277f0d71588dd36e723b1d0ed65ad 100644 (file)
@@ -118,7 +118,7 @@ std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstrea
   return downstream;
 }
 
-static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD, std::vector<ClientState*> tcpAcceptStates);
+static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates);
 
 TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientState*> tcpAcceptStates): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads)
 {
@@ -129,83 +129,37 @@ TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientSt
 
 void TCPClientCollection::addTCPClientThread(std::vector<ClientState*>& tcpAcceptStates)
 {
-  auto preparePipe = [](int fds[2], const std::string& type) -> bool {
-    if (pipe(fds) < 0) {
-      errlog("Error creating the TCP thread %s pipe: %s", type, stringerror());
-      return false;
-    }
-
-    if (!setNonBlocking(fds[0])) {
-      int err = errno;
-      close(fds[0]);
-      close(fds[1]);
-      errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err));
-      return false;
-    }
-
-    if (!setNonBlocking(fds[1])) {
-      int err = errno;
-      close(fds[0]);
-      close(fds[1]);
-      errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err));
-      return false;
-    }
-
-    if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) {
-      setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize);
-    }
-
-    return true;
-  };
+  try {
+    auto [queryChannelSender, queryChannelReceiver] = pdns::channel::createObjectQueue<ConnectionInfo>(true, g_tcpInternalPipeBufferSize);
 
-  int pipefds[2] = { -1, -1};
-  if (!preparePipe(pipefds, "communication")) {
-    return;
-  }
+    auto [crossProtocolQueryChannelSender, crossProtocolQueryChannelReceiver] = pdns::channel::createObjectQueue<CrossProtocolQuery>(true, g_tcpInternalPipeBufferSize);
 
-  int crossProtocolQueriesFDs[2] = { -1, -1};
-  if (!preparePipe(crossProtocolQueriesFDs, "cross-protocol queries")) {
-    return;
-  }
+    auto [crossProtocolResponseChannelSender, crossProtocolResponseChannelReceiver] = pdns::channel::createObjectQueue<TCPCrossProtocolResponse>(true, g_tcpInternalPipeBufferSize);
 
-  int crossProtocolResponsesFDs[2] = { -1, -1};
-  if (!preparePipe(crossProtocolResponsesFDs, "cross-protocol responses")) {
-    return;
-  }
+    vinfolog("Adding TCP Client thread");
 
-  vinfolog("Adding TCP Client thread");
-
-  {
     if (d_numthreads >= d_tcpclientthreads.size()) {
       vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
-      close(crossProtocolQueriesFDs[0]);
-      close(crossProtocolQueriesFDs[1]);
-      close(crossProtocolResponsesFDs[0]);
-      close(crossProtocolResponsesFDs[1]);
-      close(pipefds[0]);
-      close(pipefds[1]);
       return;
     }
 
-    /* from now on this side of the pipe will be managed by that object,
-       no need to worry about it */
-    TCPWorkerThread worker(pipefds[1], crossProtocolQueriesFDs[1], crossProtocolResponsesFDs[1]);
+    TCPWorkerThread worker(std::move(queryChannelSender), std::move(crossProtocolQueryChannelSender));
+
     try {
-      std::thread t1(tcpClientThread, pipefds[0], crossProtocolQueriesFDs[0], crossProtocolResponsesFDs[0], crossProtocolResponsesFDs[1], tcpAcceptStates);
+      std::thread t1(tcpClientThread, std::move(queryChannelReceiver), std::move(crossProtocolQueryChannelReceiver), std::move(crossProtocolResponseChannelReceiver), std::move(crossProtocolResponseChannelSender), tcpAcceptStates);
       t1.detach();
     }
     catch (const std::runtime_error& e) {
-      /* the thread creation failed, don't leak */
       errlog("Error creating a TCP thread: %s", e.what());
-      close(pipefds[0]);
-      close(crossProtocolQueriesFDs[0]);
-      close(crossProtocolResponsesFDs[0]);
       return;
     }
 
     d_tcpclientthreads.at(d_numthreads) = std::move(worker);
     ++d_numthreads;
   }
+  catch (const std::exception& e) {
+    errlog("Error creating TCP worker: %", e.what());
+  }
 }
 
 std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
@@ -620,23 +574,16 @@ std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion&
 
 void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
 {
-  if (d_threadData.crossProtocolResponsesPipe == -1) {
-    throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
-  }
-
   std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
-  auto ptr = new TCPCrossProtocolResponse(std::move(response), state, now);
-  static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
-  ssize_t sent = write(d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr));
-  if (sent != sizeof(ptr)) {
-    if (errno == EAGAIN || errno == EWOULDBLOCK) {
+  try {
+    auto ptr = std::make_unique<TCPCrossProtocolResponse>(std::move(response), state, now);
+    if (!state->d_threadData.crossProtocolResponseSender.send(std::move(ptr))) {
       ++g_stats.tcpCrossProtocolResponsePipeFull;
       vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
     }
-    else {
-      vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
-    }
-    delete ptr;
+  }
+  catch (const std::exception& e) {
+    vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
   }
 }
 
@@ -1151,111 +1098,82 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param
 {
   auto threadData = boost::any_cast<TCPClientThreadData*>(param);
 
-  ConnectionInfo* citmp{nullptr};
-
-  ssize_t got = read(pipefd, &citmp, sizeof(citmp));
-  if (got == 0) {
-    throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
-  }
-  else if (got == -1) {
-    if (errno == EAGAIN || errno == EINTR) {
+  std::unique_ptr<ConnectionInfo> citmp{nullptr};
+  try {
+    auto tmp = threadData->queryReceiver.receive();
+    if (!tmp) {
       return;
     }
-    throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
+    citmp = std::move(*tmp);
   }
-  else if (got != sizeof(citmp)) {
-    throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  catch (const std::exception& e) {
+    throw std::runtime_error("Error while reading from the TCP query channel: " + std::string(e.what()));
   }
 
-  try {
-    g_tcpclientthreads->decrementQueuedCount();
+  g_tcpclientthreads->decrementQueuedCount();
 
-    struct timeval now;
-    gettimeofday(&now, nullptr);
-    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
-    delete citmp;
-    citmp = nullptr;
-
-    IncomingTCPConnectionState::handleIO(state, now);
-  }
-  catch (...) {
-    delete citmp;
-    citmp = nullptr;
-    throw;
-  }
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+  auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
+  IncomingTCPConnectionState::handleIO(state, now);
 }
 
 static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
 {
   auto threadData = boost::any_cast<TCPClientThreadData*>(param);
-  CrossProtocolQuery* tmp{nullptr};
 
-  ssize_t got = read(pipefd, &tmp, sizeof(tmp));
-  if (got == 0) {
-    throw std::runtime_error("EOF while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
-  }
-  else if (got == -1) {
-    if (errno == EAGAIN || errno == EINTR) {
+  std::unique_ptr<CrossProtocolQuery> cpq{nullptr};
+  try {
+    auto tmp = threadData->crossProtocolQueryReceiver.receive();
+    if (!tmp) {
       return;
     }
-    throw std::runtime_error("Error while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
+    cpq = std::move(*tmp);
   }
-  else if (got != sizeof(tmp)) {
-    throw std::runtime_error("Partial read while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  catch (const std::exception& e) {
+    throw std::runtime_error("Error while reading from the TCP cross-protocol channel: " + std::string(e.what()));
   }
 
-  try {
-    struct timeval now;
-    gettimeofday(&now, nullptr);
+  struct timeval now;
+  gettimeofday(&now, nullptr);
 
-    std::shared_ptr<TCPQuerySender> tqs = tmp->getTCPQuerySender();
-    auto query = std::move(tmp->query);
-    auto downstreamServer = std::move(tmp->downstream);
-    auto proxyProtocolPayloadSize = tmp->proxyProtocolPayloadSize;
-    delete tmp;
-    tmp = nullptr;
+  std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender();
+  auto query = std::move(cpq->query);
+  auto downstreamServer = std::move(cpq->downstream);
+  auto proxyProtocolPayloadSize = cpq->proxyProtocolPayloadSize;
 
-    try {
-      auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
+  try {
+    auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
 
-      prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
-      query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize;
+    prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
+    query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize;
 
-      vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr());
+    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr());
 
-      downstream->queueQuery(tqs, std::move(query));
-    }
-    catch (...) {
-      tqs->notifyIOError(std::move(query.d_idstate), now);
-    }
+    downstream->queueQuery(tqs, std::move(query));
   }
   catch (...) {
-    delete tmp;
-    tmp = nullptr;
+    tqs->notifyIOError(std::move(query.d_idstate), now);
   }
 }
 
 static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param)
 {
-  TCPCrossProtocolResponse* tmp{nullptr};
+  auto threadData = boost::any_cast<TCPClientThreadData*>(param);
 
-  ssize_t got = read(pipefd, &tmp, sizeof(tmp));
-  if (got == 0) {
-    throw std::runtime_error("EOF while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
-  }
-  else if (got == -1) {
-    if (errno == EAGAIN || errno == EINTR) {
+  std::unique_ptr<TCPCrossProtocolResponse> cpr{nullptr};
+  try {
+    auto tmp = threadData->crossProtocolResponseReceiver.receive();
+    if (!tmp) {
       return;
     }
-    throw std::runtime_error("Error while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
+    cpr = std::move(*tmp);
   }
-  else if (got != sizeof(tmp)) {
-    throw std::runtime_error("Partial read while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  catch (const std::exception& e) {
+    throw std::runtime_error("Error while reading from the TCP cross-protocol response: " + std::string(e.what()));
   }
 
-  auto response = std::move(*tmp);
-  delete tmp;
-  tmp = nullptr;
+  auto response = std::move(*cpr);
 
   try {
     if (response.d_response.d_buffer.empty()) {
@@ -1283,7 +1201,7 @@ struct TCPAcceptorParam
 
 static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData);
 
-static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD, std::vector<ClientState*> tcpAcceptStates)
+static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates)
 {
   /* we get launched with a pipe on which we receive file descriptors from clients that we own
      from that point on */
@@ -1292,11 +1210,14 @@ static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int cros
 
   try {
     TCPClientThreadData data;
-    /* this is the writing end! */
-    data.crossProtocolResponsesPipe = crossProtocolResponsesWritePipeFD;
-    data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
-    data.mplexer->addReadFD(crossProtocolQueriesPipeFD, handleCrossProtocolQuery, &data);
-    data.mplexer->addReadFD(crossProtocolResponsesListenPipeFD, handleCrossProtocolResponse, &data);
+    data.crossProtocolResponseSender = std::move(crossProtocolResponseSender);
+    data.queryReceiver = std::move(queryReceiver);
+    data.crossProtocolQueryReceiver = std::move(crossProtocolQueryReceiver);
+    data.crossProtocolResponseReceiver = std::move(crossProtocolResponseReceiver);
+
+    data.mplexer->addReadFD(data.queryReceiver.getDescriptor(), handleIncomingTCPQuery, &data);
+    data.mplexer->addReadFD(data.crossProtocolQueryReceiver.getDescriptor(), handleCrossProtocolQuery, &data);
+    data.mplexer->addReadFD(data.crossProtocolResponseReceiver.getDescriptor(), handleCrossProtocolResponse, &data);
 
     /* only used in single acceptor mode for now */
     auto acl = g_ACL.getLocal();
index 59c4df410d241882a7b6438d42c3bf4feaf811a4..04a5ff4da68283f38a7e48abd150760b763a7980 100644 (file)
@@ -3,6 +3,8 @@
 #include "dolog.hh"
 #include "dnsdist-tcp.hh"
 
+class TCPCrossProtocolResponse;
+
 class TCPClientThreadData
 {
 public:
@@ -15,7 +17,10 @@ public:
   LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions;
   LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions;
   std::unique_ptr<FDMultiplexer> mplexer{nullptr};
-  int crossProtocolResponsesPipe{-1};
+  pdns::channel::Receiver<ConnectionInfo> queryReceiver;
+  pdns::channel::Receiver<CrossProtocolQuery> crossProtocolQueryReceiver;
+  pdns::channel::Receiver<TCPCrossProtocolResponse> crossProtocolResponseReceiver;
+  pdns::channel::Sender<TCPCrossProtocolResponse> crossProtocolResponseSender;
 };
 
 class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
index 3d11f1a4f4975fd26e4024fb909ccad173922645..04fb37737413d7aab64ac794ba364de9254a0cbf 100644 (file)
@@ -22,6 +22,7 @@
 #pragma once
 
 #include <unistd.h>
+#include "channel.hh"
 #include "iputils.hh"
 #include "dnsdist.hh"
 
@@ -213,20 +214,16 @@ public:
     }
 
     uint64_t pos = d_pos++;
-    auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe.getHandle();
-    auto tmp = conn.release();
-
     /* we need to increment this counter _before_ writing to the pipe,
        otherwise there is a very real possiblity that the other end
        decrement the counter before we can increment it, leading to an underflow */
     ++d_queued;
-    if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
+    if (!d_tcpclientthreads.at(pos % d_numthreads).d_querySender.send(std::move(conn))) {
       --d_queued;
       ++g_stats.tcpQueryPipeFull;
-      delete tmp;
-      tmp = nullptr;
       return false;
     }
+
     return true;
   }
 
@@ -237,13 +234,8 @@ public:
     }
 
     uint64_t pos = d_pos++;
-    auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueriesPipe.getHandle();
-    auto tmp = cpq.release();
-
-    if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
+    if (!d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQuerySender.send(std::move(cpq))) {
       ++g_stats.tcpCrossProtocolQueryPipeFull;
-      delete tmp;
-      tmp = nullptr;
       return false;
     }
 
@@ -279,8 +271,8 @@ private:
     {
     }
 
-    TCPWorkerThread(int newConnPipe, int crossProtocolQueriesPipe, int crossProtocolResponsesPipe) :
-      d_newConnectionPipe(newConnPipe), d_crossProtocolQueriesPipe(crossProtocolQueriesPipe), d_crossProtocolResponsesPipe(crossProtocolResponsesPipe)
+    TCPWorkerThread(pdns::channel::Sender<ConnectionInfo>&& querySender, pdns::channel::Sender<CrossProtocolQuery>&& crossProtocolQuerySender) :
+      d_querySender(std::move(querySender)), d_crossProtocolQuerySender(std::move(crossProtocolQuerySender))
     {
     }
 
@@ -289,9 +281,8 @@ private:
     TCPWorkerThread(const TCPWorkerThread& rhs) = delete;
     TCPWorkerThread& operator=(const TCPWorkerThread&) = delete;
 
-    FDWrapper d_newConnectionPipe;
-    FDWrapper d_crossProtocolQueriesPipe;
-    FDWrapper d_crossProtocolResponsesPipe;
+    pdns::channel::Sender<ConnectionInfo> d_querySender;
+    pdns::channel::Sender<CrossProtocolQuery> d_crossProtocolQuerySender;
   };
 
   std::vector<TCPWorkerThread> d_tcpclientthreads;