From: Remi Gacogne Date: Thu, 17 Mar 2022 15:26:08 +0000 (+0100) Subject: dnsdist: Refactoring of the TCP/TLS workers using channels X-Git-Tag: rec-5.0.0-alpha1~161^2~22 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=dd5381f58f90bbbcf650825d65b79cee6243486c;p=thirdparty%2Fpdns.git dnsdist: Refactoring of the TCP/TLS workers using channels --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index a5af69e2ce..17309a4f66 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -118,7 +118,7 @@ std::shared_ptr IncomingTCPConnectionState::getDownstrea return downstream; } -static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD, std::vector tcpAcceptStates); +static void tcpClientThread(pdns::channel::Receiver&& queryReceiver, pdns::channel::Receiver&& crossProtocolQueryReceiver, pdns::channel::Receiver&& crossProtocolResponseReceiver, pdns::channel::Sender&& crossProtocolResponseSender, std::vector tcpAcceptStates); TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector tcpAcceptStates): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads) { @@ -129,83 +129,37 @@ TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector& tcpAcceptStates) { - auto preparePipe = [](int fds[2], const std::string& type) -> bool { - if (pipe(fds) < 0) { - errlog("Error creating the TCP thread %s pipe: %s", type, stringerror()); - return false; - } - - if (!setNonBlocking(fds[0])) { - int err = errno; - close(fds[0]); - close(fds[1]); - errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err)); - return false; - } - - if (!setNonBlocking(fds[1])) { - int err = errno; - close(fds[0]); - close(fds[1]); - errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err)); - return false; - } - - if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) { - setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize); - } - - return true; - }; + try { + auto [queryChannelSender, queryChannelReceiver] = pdns::channel::createObjectQueue(true, g_tcpInternalPipeBufferSize); - int pipefds[2] = { -1, -1}; - if (!preparePipe(pipefds, "communication")) { - return; - } + auto [crossProtocolQueryChannelSender, crossProtocolQueryChannelReceiver] = pdns::channel::createObjectQueue(true, g_tcpInternalPipeBufferSize); - int crossProtocolQueriesFDs[2] = { -1, -1}; - if (!preparePipe(crossProtocolQueriesFDs, "cross-protocol queries")) { - return; - } + auto [crossProtocolResponseChannelSender, crossProtocolResponseChannelReceiver] = pdns::channel::createObjectQueue(true, g_tcpInternalPipeBufferSize); - int crossProtocolResponsesFDs[2] = { -1, -1}; - if (!preparePipe(crossProtocolResponsesFDs, "cross-protocol responses")) { - return; - } + vinfolog("Adding TCP Client thread"); - vinfolog("Adding TCP Client thread"); - - { if (d_numthreads >= d_tcpclientthreads.size()) { vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size()); - close(crossProtocolQueriesFDs[0]); - close(crossProtocolQueriesFDs[1]); - close(crossProtocolResponsesFDs[0]); - close(crossProtocolResponsesFDs[1]); - close(pipefds[0]); - close(pipefds[1]); return; } - /* from now on this side of the pipe will be managed by that object, - no need to worry about it */ - TCPWorkerThread worker(pipefds[1], crossProtocolQueriesFDs[1], crossProtocolResponsesFDs[1]); + TCPWorkerThread worker(std::move(queryChannelSender), std::move(crossProtocolQueryChannelSender)); + try { - std::thread t1(tcpClientThread, pipefds[0], crossProtocolQueriesFDs[0], crossProtocolResponsesFDs[0], crossProtocolResponsesFDs[1], tcpAcceptStates); + std::thread t1(tcpClientThread, std::move(queryChannelReceiver), std::move(crossProtocolQueryChannelReceiver), std::move(crossProtocolResponseChannelReceiver), std::move(crossProtocolResponseChannelSender), tcpAcceptStates); t1.detach(); } catch (const std::runtime_error& e) { - /* the thread creation failed, don't leak */ errlog("Error creating a TCP thread: %s", e.what()); - close(pipefds[0]); - close(crossProtocolQueriesFDs[0]); - close(crossProtocolResponsesFDs[0]); return; } d_tcpclientthreads.at(d_numthreads) = std::move(worker); ++d_numthreads; } + catch (const std::exception& e) { + errlog("Error creating TCP worker: %", e.what()); + } } std::unique_ptr g_tcpclientthreads; @@ -620,23 +574,16 @@ std::unique_ptr getTCPCrossProtocolQueryFromDQ(DNSQuestion& void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response) { - if (d_threadData.crossProtocolResponsesPipe == -1) { - throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender"); - } - std::shared_ptr state = shared_from_this(); - auto ptr = new TCPCrossProtocolResponse(std::move(response), state, now); - static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); - ssize_t sent = write(d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr)); - if (sent != sizeof(ptr)) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { + try { + auto ptr = std::make_unique(std::move(response), state, now); + if (!state->d_threadData.crossProtocolResponseSender.send(std::move(ptr))) { ++g_stats.tcpCrossProtocolResponsePipeFull; vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full"); } - else { - vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror()); - } - delete ptr; + } + catch (const std::exception& e) { + vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror()); } } @@ -1151,111 +1098,82 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param { auto threadData = boost::any_cast(param); - ConnectionInfo* citmp{nullptr}; - - ssize_t got = read(pipefd, &citmp, sizeof(citmp)); - if (got == 0) { - throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); - } - else if (got == -1) { - if (errno == EAGAIN || errno == EINTR) { + std::unique_ptr citmp{nullptr}; + try { + auto tmp = threadData->queryReceiver.receive(); + if (!tmp) { return; } - throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + citmp = std::move(*tmp); } - else if (got != sizeof(citmp)) { - throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + catch (const std::exception& e) { + throw std::runtime_error("Error while reading from the TCP query channel: " + std::string(e.what())); } - try { - g_tcpclientthreads->decrementQueuedCount(); + g_tcpclientthreads->decrementQueuedCount(); - struct timeval now; - gettimeofday(&now, nullptr); - auto state = std::make_shared(std::move(*citmp), *threadData, now); - delete citmp; - citmp = nullptr; - - IncomingTCPConnectionState::handleIO(state, now); - } - catch (...) { - delete citmp; - citmp = nullptr; - throw; - } + struct timeval now; + gettimeofday(&now, nullptr); + auto state = std::make_shared(std::move(*citmp), *threadData, now); + IncomingTCPConnectionState::handleIO(state, now); } static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param) { auto threadData = boost::any_cast(param); - CrossProtocolQuery* tmp{nullptr}; - ssize_t got = read(pipefd, &tmp, sizeof(tmp)); - if (got == 0) { - throw std::runtime_error("EOF while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); - } - else if (got == -1) { - if (errno == EAGAIN || errno == EINTR) { + std::unique_ptr cpq{nullptr}; + try { + auto tmp = threadData->crossProtocolQueryReceiver.receive(); + if (!tmp) { return; } - throw std::runtime_error("Error while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + cpq = std::move(*tmp); } - else if (got != sizeof(tmp)) { - throw std::runtime_error("Partial read while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + catch (const std::exception& e) { + throw std::runtime_error("Error while reading from the TCP cross-protocol channel: " + std::string(e.what())); } - try { - struct timeval now; - gettimeofday(&now, nullptr); + struct timeval now; + gettimeofday(&now, nullptr); - std::shared_ptr tqs = tmp->getTCPQuerySender(); - auto query = std::move(tmp->query); - auto downstreamServer = std::move(tmp->downstream); - auto proxyProtocolPayloadSize = tmp->proxyProtocolPayloadSize; - delete tmp; - tmp = nullptr; + std::shared_ptr tqs = cpq->getTCPQuerySender(); + auto query = std::move(cpq->query); + auto downstreamServer = std::move(cpq->downstream); + auto proxyProtocolPayloadSize = cpq->proxyProtocolPayloadSize; - try { - auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string()); + try { + auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string()); - prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); - query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize; + prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); + query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize; - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr()); + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr()); - downstream->queueQuery(tqs, std::move(query)); - } - catch (...) { - tqs->notifyIOError(std::move(query.d_idstate), now); - } + downstream->queueQuery(tqs, std::move(query)); } catch (...) { - delete tmp; - tmp = nullptr; + tqs->notifyIOError(std::move(query.d_idstate), now); } } static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param) { - TCPCrossProtocolResponse* tmp{nullptr}; + auto threadData = boost::any_cast(param); - ssize_t got = read(pipefd, &tmp, sizeof(tmp)); - if (got == 0) { - throw std::runtime_error("EOF while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); - } - else if (got == -1) { - if (errno == EAGAIN || errno == EINTR) { + std::unique_ptr cpr{nullptr}; + try { + auto tmp = threadData->crossProtocolResponseReceiver.receive(); + if (!tmp) { return; } - throw std::runtime_error("Error while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + cpr = std::move(*tmp); } - else if (got != sizeof(tmp)) { - throw std::runtime_error("Partial read while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + catch (const std::exception& e) { + throw std::runtime_error("Error while reading from the TCP cross-protocol response: " + std::string(e.what())); } - auto response = std::move(*tmp); - delete tmp; - tmp = nullptr; + auto response = std::move(*cpr); try { if (response.d_response.d_buffer.empty()) { @@ -1283,7 +1201,7 @@ struct TCPAcceptorParam static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData); -static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD, std::vector tcpAcceptStates) +static void tcpClientThread(pdns::channel::Receiver&& queryReceiver, pdns::channel::Receiver&& crossProtocolQueryReceiver, pdns::channel::Receiver&& crossProtocolResponseReceiver, pdns::channel::Sender&& crossProtocolResponseSender, std::vector tcpAcceptStates) { /* we get launched with a pipe on which we receive file descriptors from clients that we own from that point on */ @@ -1292,11 +1210,14 @@ static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int cros try { TCPClientThreadData data; - /* this is the writing end! */ - data.crossProtocolResponsesPipe = crossProtocolResponsesWritePipeFD; - data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data); - data.mplexer->addReadFD(crossProtocolQueriesPipeFD, handleCrossProtocolQuery, &data); - data.mplexer->addReadFD(crossProtocolResponsesListenPipeFD, handleCrossProtocolResponse, &data); + data.crossProtocolResponseSender = std::move(crossProtocolResponseSender); + data.queryReceiver = std::move(queryReceiver); + data.crossProtocolQueryReceiver = std::move(crossProtocolQueryReceiver); + data.crossProtocolResponseReceiver = std::move(crossProtocolResponseReceiver); + + data.mplexer->addReadFD(data.queryReceiver.getDescriptor(), handleIncomingTCPQuery, &data); + data.mplexer->addReadFD(data.crossProtocolQueryReceiver.getDescriptor(), handleCrossProtocolQuery, &data); + data.mplexer->addReadFD(data.crossProtocolResponseReceiver.getDescriptor(), handleCrossProtocolResponse, &data); /* only used in single acceptor mode for now */ auto acl = g_ACL.getLocal(); diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 59c4df410d..04a5ff4da6 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -3,6 +3,8 @@ #include "dolog.hh" #include "dnsdist-tcp.hh" +class TCPCrossProtocolResponse; + class TCPClientThreadData { public: @@ -15,7 +17,10 @@ public: LocalStateHolder> localRespRuleActions; LocalStateHolder> localCacheInsertedRespRuleActions; std::unique_ptr mplexer{nullptr}; - int crossProtocolResponsesPipe{-1}; + pdns::channel::Receiver queryReceiver; + pdns::channel::Receiver crossProtocolQueryReceiver; + pdns::channel::Receiver crossProtocolResponseReceiver; + pdns::channel::Sender crossProtocolResponseSender; }; class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index 3d11f1a4f4..04fb377374 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -22,6 +22,7 @@ #pragma once #include +#include "channel.hh" #include "iputils.hh" #include "dnsdist.hh" @@ -213,20 +214,16 @@ public: } uint64_t pos = d_pos++; - auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_newConnectionPipe.getHandle(); - auto tmp = conn.release(); - /* we need to increment this counter _before_ writing to the pipe, otherwise there is a very real possiblity that the other end decrement the counter before we can increment it, leading to an underflow */ ++d_queued; - if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) { + if (!d_tcpclientthreads.at(pos % d_numthreads).d_querySender.send(std::move(conn))) { --d_queued; ++g_stats.tcpQueryPipeFull; - delete tmp; - tmp = nullptr; return false; } + return true; } @@ -237,13 +234,8 @@ public: } uint64_t pos = d_pos++; - auto pipe = d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQueriesPipe.getHandle(); - auto tmp = cpq.release(); - - if (write(pipe, &tmp, sizeof(tmp)) != sizeof(tmp)) { + if (!d_tcpclientthreads.at(pos % d_numthreads).d_crossProtocolQuerySender.send(std::move(cpq))) { ++g_stats.tcpCrossProtocolQueryPipeFull; - delete tmp; - tmp = nullptr; return false; } @@ -279,8 +271,8 @@ private: { } - TCPWorkerThread(int newConnPipe, int crossProtocolQueriesPipe, int crossProtocolResponsesPipe) : - d_newConnectionPipe(newConnPipe), d_crossProtocolQueriesPipe(crossProtocolQueriesPipe), d_crossProtocolResponsesPipe(crossProtocolResponsesPipe) + TCPWorkerThread(pdns::channel::Sender&& querySender, pdns::channel::Sender&& crossProtocolQuerySender) : + d_querySender(std::move(querySender)), d_crossProtocolQuerySender(std::move(crossProtocolQuerySender)) { } @@ -289,9 +281,8 @@ private: TCPWorkerThread(const TCPWorkerThread& rhs) = delete; TCPWorkerThread& operator=(const TCPWorkerThread&) = delete; - FDWrapper d_newConnectionPipe; - FDWrapper d_crossProtocolQueriesPipe; - FDWrapper d_crossProtocolResponsesPipe; + pdns::channel::Sender d_querySender; + pdns::channel::Sender d_crossProtocolQuerySender; }; std::vector d_tcpclientthreads;