From: Remi Gacogne Date: Fri, 2 Dec 2022 14:57:17 +0000 (+0100) Subject: dnsdist: Get rid of TCPCrossProtocolQuerySender X-Git-Tag: dnsdist-1.8.0-rc1~181^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=516a00075fa9da56992fbea0149d274cd4a6813c;p=thirdparty%2Fpdns.git dnsdist: Get rid of TCPCrossProtocolQuerySender We need this construct to deal with cross-protocol queries, like queries received over TCP or DoT, but forwarded over DoH, because the thread dealing with the client and the one dealing with the backend will not be the same in that case, and we do not want to have different threads touching the same TCP connections. So we pass the query and response to the correct thread via pipes. Until now we were allocating an additional object, TCPCrossProtocolQuerySender, to deal with that case, but I noticed that the existing IncomingTCPConnectionState object already does everything we need, except that it needs to know that the response is a cross-protocol one in order to pass it via the pipe instead of treating it in a different way. This can be done by looking if the current thread ID differs from the one that created this object: if it does, we are dealing with a cross-protocol response and should pass it via the pipe, and if it does not we can deal with it directly. This change saves the need to allocate a new object wrapped in a shared pointer for each cross-protocol query, which is quite nice. --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index c9f4120b77..d2c4ba7937 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -491,6 +491,11 @@ void IncomingTCPConnectionState::updateIO(std::shared_ptr state = shared_from_this(); if (response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) { @@ -566,66 +571,11 @@ struct TCPCrossProtocolResponse struct timeval d_now; }; -class TCPCrossProtocolQuerySender : public TCPQuerySender -{ -public: - TCPCrossProtocolQuerySender(std::shared_ptr& state): d_state(state) - { - } - - bool active() const override - { - return d_state->active(); - } - - const ClientState* getClientState() const override - { - return d_state->getClientState(); - } - - void handleResponse(const struct timeval& now, TCPResponse&& response) override - { - if (d_state->d_threadData.crossProtocolResponsesPipe == -1) { - throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender"); - } - - auto ptr = new TCPCrossProtocolResponse(std::move(response), d_state, now); - static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); - ssize_t sent = write(d_state->d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr)); - if (sent != sizeof(ptr)) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - ++g_stats.tcpCrossProtocolResponsePipeFull; - vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full"); - } - else { - vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror()); - } - delete ptr; - } - } - - void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override - { - handleResponse(now, std::move(response)); - } - - void notifyIOError(IDState&& query, const struct timeval& now) override - { - TCPResponse response(PacketBuffer(), std::move(query), nullptr); - handleResponse(now, std::move(response)); - } - -private: - std::shared_ptr d_state; -}; - class TCPCrossProtocolQuery : public CrossProtocolQuery { public: - TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr& ds, std::shared_ptr& sender): d_sender(sender) + TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr ds, std::shared_ptr sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender)) { - query = InternalQuery(std::move(buffer), std::move(ids)); - downstream = ds; proxyProtocolPayloadSize = 0; } @@ -639,9 +589,31 @@ public: } private: - std::shared_ptr d_sender; + std::shared_ptr d_sender; }; +void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response) +{ + if (d_threadData.crossProtocolResponsesPipe == -1) { + throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender"); + } + + std::shared_ptr state = shared_from_this(); + auto ptr = new TCPCrossProtocolResponse(std::move(response), state, now); + static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); + ssize_t sent = write(d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr)); + if (sent != sizeof(ptr)) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + ++g_stats.tcpCrossProtocolResponsePipeFull; + vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full"); + } + else { + vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror()); + } + delete ptr; + } +} + static void handleQuery(std::shared_ptr& state, const struct timeval& now) { if (state->d_querySize < sizeof(dnsheader)) { @@ -784,8 +756,7 @@ static void handleQuery(std::shared_ptr& state, cons proxyProtocolPayload = getProxyProtocolPayload(dq); } - auto incoming = std::make_shared(state); - auto cpq = std::make_unique(std::move(state->d_buffer), std::move(ids), ds, incoming); + auto cpq = std::make_unique(std::move(state->d_buffer), std::move(ids), ds, state); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); ds->passCrossProtocolQuery(std::move(cpq)); diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 6fa76ab6a1..c128a56e81 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -19,7 +19,7 @@ public: class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this { public: - IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique(*threadData.mplexer, d_ci.fd)), d_threadData(threadData) + IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_mainThreadID(std::this_thread::get_id()) { d_origDest.reset(); d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family; @@ -125,6 +125,8 @@ static void handleTimeout(std::shared_ptr& state, bo void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override; void notifyIOError(IDState&& query, const struct timeval& now) override; + void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response); + void terminateClientConnection(); void queueQuery(TCPQuery&& query); @@ -170,6 +172,7 @@ static void handleTimeout(std::shared_ptr& state, bo size_t d_proxyProtocolNeed{0}; size_t d_queriesCount{0}; size_t d_currentQueriesCount{0}; + std::thread::id d_mainThreadID; uint16_t d_querySize{0}; State d_state{State::doingHandshake}; bool d_isXFR{false}; diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index b7b78582d1..6fb1b827a2 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -159,6 +159,11 @@ struct CrossProtocolQuery { } + CrossProtocolQuery(InternalQuery&& query_, std::shared_ptr& downstream_) : + query(std::move(query_)), downstream(downstream_) + { + } + CrossProtocolQuery(CrossProtocolQuery&& rhs) = delete; virtual ~CrossProtocolQuery() {