]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactoring of the outgoing DoH code with pdns::channel
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 17 Mar 2022 09:12:08 +0000 (10:12 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 13 Jun 2023 07:59:32 +0000 (09:59 +0200)
pdns/dnsdistdist/dnsdist-nghttp2.cc

index 5c745eac1c1381c4ceedf8a92bd21876bc9706ac..34e8abe5eb43dc7bfe300320795d8a783b5420e3 100644 (file)
@@ -32,6 +32,7 @@
 #include "dnsdist-downstream-connection.hh"
 
 #include "dolog.hh"
+#include "channel.hh"
 #include "iputils.hh"
 #include "libssl.hh"
 #include "noinitvector.hh"
@@ -368,12 +369,14 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender,
 class DoHClientThreadData
 {
 public:
-  DoHClientThreadData() :
-    mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+  DoHClientThreadData(pdns::channel::Receiver<CrossProtocolQuery>&& receiver) :
+    mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent())),
+    d_receiver(std::move(receiver))
   {
   }
 
   std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+  pdns::channel::Receiver<CrossProtocolQuery> d_receiver;
 };
 
 void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
@@ -856,53 +859,43 @@ DoHConnectionToBackend::DoHConnectionToBackend(const std::shared_ptr<DownstreamS
 static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
 {
   auto threadData = boost::any_cast<DoHClientThreadData*>(param);
-  CrossProtocolQuery* tmp{nullptr};
 
-  ssize_t got = read(pipefd, &tmp, sizeof(tmp));
-  if (got == 0) {
-    throw std::runtime_error("EOF while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
-  }
-  else if (got == -1) {
-    if (errno == EAGAIN || errno == EINTR) {
+  std::unique_ptr<CrossProtocolQuery> cpq{nullptr};
+  try {
+    auto tmp = threadData->d_receiver.receive();
+    if (!tmp) {
       return;
     }
-    throw std::runtime_error("Error while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror());
+    cpq = std::move(*tmp);
   }
-  else if (got != sizeof(tmp)) {
-    throw std::runtime_error("Partial read while reading from the DoH cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode");
+  catch (const std::exception& e) {
+    throw std::runtime_error("Error while reading from the DoH cross-protocol channel:" + std::string(e.what()));
   }
 
-  try {
-    struct timeval now;
-    gettimeofday(&now, nullptr);
+  struct timeval now;
+  gettimeofday(&now, nullptr);
 
-    std::shared_ptr<TCPQuerySender> tqs = tmp->getTCPQuerySender();
-    auto query = std::move(tmp->query);
-    auto downstreamServer = std::move(tmp->downstream);
-    delete tmp;
-    tmp = nullptr;
+  std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender();
+  auto query = std::move(cpq->query);
+  auto downstreamServer = std::move(cpq->downstream);
+  cpq.reset();
 
-    try {
-      auto downstream = t_downstreamDoHConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::move(query.d_proxyProtocolPayload));
-      downstream->queueQuery(tqs, std::move(query));
-    }
-    catch (...) {
-      tqs->notifyIOError(std::move(query.d_idstate), now);
-    }
+  try {
+    auto downstream = t_downstreamDoHConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::move(query.d_proxyProtocolPayload));
+    downstream->queueQuery(tqs, std::move(query));
   }
   catch (...) {
-    delete tmp;
-    tmp = nullptr;
+    tqs->notifyIOError(std::move(query.d_idstate), now);
   }
 }
 
-static void dohClientThread(int crossProtocolPipeFD)
+static void dohClientThread(pdns::channel::Receiver<CrossProtocolQuery>&& receiver)
 {
   setThreadName("dnsdist/dohClie");
 
   try {
-    DoHClientThreadData data;
-    data.mplexer->addReadFD(crossProtocolPipeFD, handleCrossProtocolQuery, &data);
+    DoHClientThreadData data(std::move(receiver));
+    data.mplexer->addReadFD(data.d_receiver.getDescriptor(), handleCrossProtocolQuery, &data);
 
     struct timeval now;
     gettimeofday(&now, nullptr);
@@ -976,40 +969,26 @@ struct DoHClientCollection::DoHWorkerThread
   {
   }
 
-  DoHWorkerThread(int crossProtocolPipe) :
-    d_crossProtocolQueryPipe(crossProtocolPipe)
+  DoHWorkerThread(pdns::channel::Sender<CrossProtocolQuery>&& sender) :
+    d_sender(std::move(sender))
   {
   }
 
   DoHWorkerThread(DoHWorkerThread&& rhs) :
-    d_crossProtocolQueryPipe(rhs.d_crossProtocolQueryPipe)
+    d_sender(std::move(rhs.d_sender))
   {
-    rhs.d_crossProtocolQueryPipe = -1;
   }
 
   DoHWorkerThread& operator=(DoHWorkerThread&& rhs)
   {
-    if (d_crossProtocolQueryPipe != -1) {
-      close(d_crossProtocolQueryPipe);
-    }
-
-    d_crossProtocolQueryPipe = rhs.d_crossProtocolQueryPipe;
-    rhs.d_crossProtocolQueryPipe = -1;
-
+    d_sender = std::move(rhs.d_sender);
     return *this;
   }
 
   DoHWorkerThread(const DoHWorkerThread& rhs) = delete;
   DoHWorkerThread& operator=(const DoHWorkerThread&) = delete;
 
-  ~DoHWorkerThread()
-  {
-    if (d_crossProtocolQueryPipe != -1) {
-      close(d_crossProtocolQueryPipe);
-    }
-  }
-
-  int d_crossProtocolQueryPipe{-1};
+  pdns::channel::Sender<CrossProtocolQuery> d_sender;
 };
 
 DoHClientCollection::DoHClientCollection(size_t numberOfThreads) :
@@ -1024,13 +1003,8 @@ bool DoHClientCollection::passCrossProtocolQueryToThread(std::unique_ptr<CrossPr
   }
 
   uint64_t pos = d_pos++;
-  auto pipe = d_clientThreads.at(pos % d_numberOfThreads).d_crossProtocolQueryPipe;
-  auto tmp = cpq.release();
-
-  if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) {
-    delete tmp;
+  if (!d_clientThreads.at(pos % d_numberOfThreads).d_sender.send(std::move(cpq))) {
     ++g_stats.outgoingDoHQueryPipeFull;
-    tmp = nullptr;
     return false;
   }
 
@@ -1040,69 +1014,35 @@ bool DoHClientCollection::passCrossProtocolQueryToThread(std::unique_ptr<CrossPr
 void DoHClientCollection::addThread()
 {
 #ifdef HAVE_NGHTTP2
-  auto preparePipe = [](int fds[2], const std::string& type) -> bool {
-    if (pipe(fds) < 0) {
-      errlog("Error creating the DoH thread %s pipe: %s", type, stringerror());
-      return false;
-    }
-
-    if (!setNonBlocking(fds[0])) {
-      int err = errno;
-      close(fds[0]);
-      close(fds[1]);
-      errlog("Error setting the DoH thread %s pipe non-blocking: %s", type, stringerror(err));
-      return false;
-    }
-
-    if (!setNonBlocking(fds[1])) {
-      int err = errno;
-      close(fds[0]);
-      close(fds[1]);
-      errlog("Error setting the DoH thread %s pipe non-blocking: %s", type, stringerror(err));
-      return false;
-    }
-
-    if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) {
-      setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize);
-    }
-
-    return true;
-  };
-
-  int crossProtocolFDs[2] = {-1, -1};
-  if (!preparePipe(crossProtocolFDs, "cross-protocol")) {
-    return;
-  }
-
-  vinfolog("Adding DoH Client thread");
+  try {
+    auto [sender, receiver] = pdns::channel::createObjectQueue<CrossProtocolQuery>(true, g_tcpInternalPipeBufferSize);
 
-  {
+    vinfolog("Adding DoH Client thread");
     std::lock_guard<std::mutex> lock(d_mutex);
 
     if (d_numberOfThreads >= d_clientThreads.size()) {
       vinfolog("Adding a new DoH client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of DoH client threads with setMaxDoHClientThreads() in the configuration.", d_numberOfThreads, d_clientThreads.size());
-      close(crossProtocolFDs[0]);
-      close(crossProtocolFDs[1]);
       return;
     }
 
-    /* from now on this side of the pipe will be managed by that object,
-       no need to worry about it */
-    DoHWorkerThread worker(crossProtocolFDs[1]);
+    DoHWorkerThread worker(std::move(sender));
     try {
-      std::thread t1(dohClientThread, crossProtocolFDs[0]);
+      std::thread t1(dohClientThread, std::move(receiver));
       t1.detach();
     }
     catch (const std::runtime_error& e) {
-      /* the thread creation failed, don't leak */
+      /* the thread creation failed */
       errlog("Error creating a DoH thread: %s", e.what());
-      close(crossProtocolFDs[0]);
       return;
     }
 
     d_clientThreads.at(d_numberOfThreads) = std::move(worker);
     ++d_numberOfThreads;
   }
+  catch (const std::exception& e) {
+    errlog("Error creating the DoH channel: %s", e.what());
+    return;
+  }
 #else /* HAVE_NGHTTP2 */
   throw std::runtime_error("DoHClientCollection::addThread() called but nghttp2 support is not available");
 #endif /* HAVE_NGHTTP2 */