From: Remi Gacogne Date: Fri, 24 Nov 2023 08:27:39 +0000 (+0100) Subject: dnsdist: Delint dnsdist-tcp.cc X-Git-Tag: dnsdist-1.9.0-alpha4~14^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1950a055306e458c41ec1f4a839f1f0044700656;p=thirdparty%2Fpdns.git dnsdist: Delint dnsdist-tcp.cc --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index f5147ce087..f96740d94e 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -66,7 +66,7 @@ size_t g_maxTCPConnectionDuration{0}; #ifdef __linux__ // On Linux this gives us 128k pending queries (default is 8192 queries), // which should be enough to deal with huge spikes -size_t g_tcpInternalPipeBufferSize{1024*1024}; +size_t g_tcpInternalPipeBufferSize{1048576U}; uint64_t g_maxTCPQueuedConnections{10000}; #else size_t g_tcpInternalPipeBufferSize{0}; @@ -85,11 +85,11 @@ IncomingTCPConnectionState::~IncomingTCPConnectionState() dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(d_ci.remote); if (d_ci.cs != nullptr) { - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); auto diff = now - d_connectionStartTime; - d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0); + d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000 + diff.tv_usec / 1000); } // would have been done when the object is destroyed anyway, @@ -114,16 +114,14 @@ size_t IncomingTCPConnectionState::clearAllDownstreamConnections() return t_downstreamTCPConnectionsManager.clear(); } -std::shared_ptr IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr& ds, const std::unique_ptr>& tlvs, const struct timeval& now) +std::shared_ptr IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr& backend, const std::unique_ptr>& tlvs, const struct timeval& now) { - std::shared_ptr downstream{nullptr}; - - downstream = getOwnedDownstreamConnection(ds, tlvs); + auto downstream = getOwnedDownstreamConnection(backend, tlvs); if (!downstream) { /* we don't have a connection to this backend owned yet, let's get one (it might not be a fresh one, though) */ - downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(d_threadData.mplexer, ds, now, std::string()); - if (ds->d_config.useProxyProtocol) { + downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(d_threadData.mplexer, backend, now, std::string()); + if (backend->d_config.useProxyProtocol) { registerOwnedDownstreamConnection(downstream); } } @@ -133,7 +131,8 @@ std::shared_ptr IncomingTCPConnectionState::getDownstrea static void tcpClientThread(pdns::channel::Receiver&& queryReceiver, pdns::channel::Receiver&& crossProtocolQueryReceiver, pdns::channel::Receiver&& crossProtocolResponseReceiver, pdns::channel::Sender&& crossProtocolResponseSender, std::vector tcpAcceptStates); -TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector tcpAcceptStates): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads) +TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector tcpAcceptStates) : + d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads) { for (size_t idx = 0; idx < maxThreads; idx++) { addTCPClientThread(tcpAcceptStates); @@ -159,8 +158,8 @@ void TCPClientCollection::addTCPClientThread(std::vector& tcpAccep TCPWorkerThread worker(std::move(queryChannelSender), std::move(crossProtocolQueryChannelSender)); try { - std::thread t1(tcpClientThread, std::move(queryChannelReceiver), std::move(crossProtocolQueryChannelReceiver), std::move(crossProtocolResponseChannelReceiver), std::move(crossProtocolResponseChannelSender), tcpAcceptStates); - t1.detach(); + std::thread clientThread(tcpClientThread, std::move(queryChannelReceiver), std::move(crossProtocolQueryChannelReceiver), std::move(crossProtocolResponseChannelReceiver), std::move(crossProtocolResponseChannelSender), tcpAcceptStates); + clientThread.detach(); } catch (const std::runtime_error& e) { errlog("Error creating a TCP thread: %s", e.what()); @@ -182,7 +181,7 @@ static IOState sendQueuedResponses(std::shared_ptr& IOState result = IOState::Done; while (state->active() && !state->d_queuedResponses.empty()) { - DEBUGLOG("queue size is "<d_queuedResponses.size()<<", sending the next one"); + DEBUGLOG("queue size is " << state->d_queuedResponses.size() << ", sending the next one"); TCPResponse resp = std::move(state->d_queuedResponses.front()); state->d_queuedResponses.pop_front(); state->d_state = IncomingTCPConnectionState::State::idle; @@ -204,18 +203,19 @@ void IncomingTCPConnectionState::handleResponseSent(TCPResponse& currentResponse --d_currentQueriesCount; - const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds; - if (currentResponse.d_idstate.selfGenerated == false && ds) { + const auto& backend = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds; + if (!currentResponse.d_idstate.selfGenerated && backend) { const auto& ids = currentResponse.d_idstate; double udiff = ids.queryRealTime.udiff(); - vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), currentResponse.d_buffer.size(), udiff); + vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", backend->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), currentResponse.d_buffer.size(), udiff); - auto backendProtocol = ds->getProtocol(); + auto backendProtocol = backend->getProtocol(); if (backendProtocol == dnsdist::Protocol::DoUDP && !currentResponse.d_idstate.forwardedOverUDP) { backendProtocol = dnsdist::Protocol::DoTCP; } - ::handleResponseSent(ids, udiff, d_ci.remote, ds->d_config.remote, static_cast(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true); - } else { + ::handleResponseSent(ids, udiff, d_ci.remote, backend->d_config.remote, static_cast(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true); + } + else { const auto& ids = currentResponse.d_idstate; ::handleResponseSent(ids, 0., d_ci.remote, ComboAddress(), static_cast(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false); } @@ -231,11 +231,11 @@ static void prependSizeToTCPQuery(PacketBuffer& buffer, size_t proxyProtocolPayl } uint16_t queryLen = proxyProtocolPayloadSize > 0 ? (buffer.size() - proxyProtocolPayloadSize) : buffer.size(); - const uint8_t sizeBytes[] = { static_cast(queryLen / 256), static_cast(queryLen % 256) }; + const std::array sizeBytes{static_cast(queryLen / 256), static_cast(queryLen % 256)}; /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes that could occur if we had to deal with the size during the processing, especially alignment issues */ - buffer.insert(buffer.begin() + proxyProtocolPayloadSize, sizeBytes, sizeBytes + 2); + buffer.insert(buffer.begin() + static_cast(proxyProtocolPayloadSize), sizeBytes.begin(), sizeBytes.end()); } bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) @@ -247,11 +247,11 @@ bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) // for DoH, this is already handled by the underlying library if (!d_ci.cs->dohFrontend && d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) { - DEBUGLOG("not accepting new queries because we already have "<d_maxInFlightQueriesPerConn); + DEBUGLOG("not accepting new queries because we already have " << d_currentQueriesCount << " out of " << d_ci.cs->d_maxInFlightQueriesPerConn); return false; } - if (g_maxTCPQueriesPerConn && d_queriesCount > g_maxTCPQueriesPerConn) { + if (g_maxTCPQueriesPerConn != 0 && d_queriesCount > g_maxTCPQueriesPerConn) { vinfolog("not accepting new queries from %s because it reached the maximum number of queries per conn (%d / %d)", d_ci.remote.toStringWithPort(), d_queriesCount, g_maxTCPQueriesPerConn); return false; } @@ -272,21 +272,21 @@ void IncomingTCPConnectionState::resetForNewQuery() d_state = State::waitingForQuery; } -std::shared_ptr IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr& ds, const std::unique_ptr>& tlvs) +std::shared_ptr IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr& backend, const std::unique_ptr>& tlvs) { - auto it = d_ownedConnectionsToBackend.find(ds); - if (it == d_ownedConnectionsToBackend.end()) { - DEBUGLOG("no owned connection found for "<getName()); + auto connIt = d_ownedConnectionsToBackend.find(backend); + if (connIt == d_ownedConnectionsToBackend.end()) { + DEBUGLOG("no owned connection found for " << backend->getName()); return nullptr; } - for (auto& conn : it->second) { + for (auto& conn : connIt->second) { if (conn->canBeReused(true) && conn->matchesTLVs(tlvs)) { - DEBUGLOG("Got one owned connection accepting more for "<getName()); + DEBUGLOG("Got one owned connection accepting more for " << backend->getName()); conn->setReused(); return conn; } - DEBUGLOG("not accepting more for "<getName()); + DEBUGLOG("not accepting more for " << backend->getName()); } return nullptr; @@ -302,30 +302,29 @@ IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPR { d_state = State::sendingResponse; - uint16_t responseSize = static_cast(response.d_buffer.size()); - const uint8_t sizeBytes[] = { static_cast(responseSize / 256), static_cast(responseSize % 256) }; + const auto responseSize = static_cast(response.d_buffer.size()); + const std::array sizeBytes{static_cast(responseSize / 256), static_cast(responseSize % 256)}; /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes that could occur if we had to deal with the size during the processing, especially alignment issues */ - response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2); + response.d_buffer.insert(response.d_buffer.begin(), sizeBytes.begin(), sizeBytes.end()); d_currentPos = 0; d_currentResponse = std::move(response); try { auto iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size()); if (iostate == IOState::Done) { - DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__); + DEBUGLOG("response sent from " << __PRETTY_FUNCTION__); handleResponseSent(d_currentResponse); return iostate; - } else { - d_lastIOBlocked = true; - DEBUGLOG("partial write"); - return iostate; } + d_lastIOBlocked = true; + DEBUGLOG("partial write"); + return iostate; } catch (const std::exception& e) { vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what()); - DEBUGLOG("Closing TCP client connection: "<tcpDiedSendingResponse; terminateClientConnection(); @@ -362,14 +361,13 @@ void IncomingTCPConnectionState::terminateClientConnection() /* we might already be waiting, but we might also not because sometimes we have already been notified via the descriptor, not received Async again, but the async job still exists.. */ auto state = shared_from_this(); - for (const auto fd : afds) { + for (const auto desc : afds) { try { - state->d_threadData.mplexer->addReadFD(fd, handleAsyncReady, state); + state->d_threadData.mplexer->addReadFD(desc, handleAsyncReady, state); } catch (...) { } } - } } @@ -377,14 +375,13 @@ void IncomingTCPConnectionState::queueResponse(std::shared_ptrd_queuedResponses.emplace_back(std::move(response)); - DEBUGLOG("queueing response, state is "<<(int)state->d_state<<", queue size is now "<d_queuedResponses.size()); + DEBUGLOG("queueing response, state is " << (int)state->d_state << ", queue size is now " << state->d_queuedResponses.size()); // when the response comes from a backend, there is a real possibility that we are currently // idle, and thus not trying to send the response right away would make our ref count go to 0. // Even if we are waiting for a query, we will not wake up before the new query arrives or a // timeout occurs - if (state->d_state == State::idle || - state->d_state == State::waitingForQuery) { + if (state->d_state == State::idle || state->d_state == State::waitingForQuery) { auto iostate = sendQueuedResponses(state, now); if (iostate == IOState::Done && state->active()) { @@ -416,7 +413,7 @@ void IncomingTCPConnectionState::queueResponse(std::shared_ptr>(param); @@ -446,8 +443,8 @@ void IncomingTCPConnectionState::updateIO(std::shared_ptrd_handler.getAsyncFDs(); - for (const auto fd : fds) { - state->d_threadData.mplexer->addReadFD(fd, handleAsyncReady, state); + for (const auto desc : fds) { + state->d_threadData.mplexer->addReadFD(desc, handleAsyncReady, state); } state->d_ioState->update(IOState::Done, handleIOCallback, state); } @@ -498,28 +495,28 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe if (!response.isAsync()) { try { auto& ids = response.d_idstate; - unsigned int qnameWireLength; - std::shared_ptr ds = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr); - if (!ds || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, ds, qnameWireLength)) { + unsigned int qnameWireLength{0}; + std::shared_ptr backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr); + if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend, qnameWireLength)) { state->terminateClientConnection(); return; } - if (ds) { - ++ds->responses; + if (backend != nullptr) { + ++backend->responses; } - DNSResponse dr(ids, response.d_buffer, ds); - dr.d_incomingTCPState = state; + DNSResponse dnsResponse(ids, response.d_buffer, backend); + dnsResponse.d_incomingTCPState = state; - memcpy(&response.d_cleartextDH, dr.getHeader().get(), sizeof(response.d_cleartextDH)); + memcpy(&response.d_cleartextDH, dnsResponse.getHeader().get(), sizeof(response.d_cleartextDH)); - if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) { + if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dnsResponse, false)) { state->terminateClientConnection(); return; } - if (dr.isAsynchronous()) { + if (dnsResponse.isAsynchronous()) { /* we are done for now */ return; } @@ -539,9 +536,15 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe struct TCPCrossProtocolResponse { - TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr& state, const struct timeval& now): d_response(std::move(response)), d_state(state), d_now(now) + TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr& state, const struct timeval& now) : + d_response(std::move(response)), d_state(state), d_now(now) { } + TCPCrossProtocolResponse(const TCPCrossProtocolResponse&) = delete; + TCPCrossProtocolResponse& operator=(const TCPCrossProtocolResponse&) = delete; + TCPCrossProtocolResponse(TCPCrossProtocolResponse&&) = delete; + TCPCrossProtocolResponse& operator=(TCPCrossProtocolResponse&&) = delete; + ~TCPCrossProtocolResponse() = default; TCPResponse d_response; std::shared_ptr d_state; @@ -551,13 +554,15 @@ struct TCPCrossProtocolResponse class TCPCrossProtocolQuery : public CrossProtocolQuery { public: - TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr ds, std::shared_ptr sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender)) - { - } - - ~TCPCrossProtocolQuery() + TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr backend, std::shared_ptr sender) : + CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), backend), d_sender(std::move(sender)) { } + TCPCrossProtocolQuery(const TCPCrossProtocolQuery&) = delete; + TCPCrossProtocolQuery& operator=(const TCPCrossProtocolQuery&) = delete; + TCPCrossProtocolQuery(TCPCrossProtocolQuery&&) = delete; + TCPCrossProtocolQuery& operator=(TCPCrossProtocolQuery&&) = delete; + ~TCPCrossProtocolQuery() override = default; std::shared_ptr getTCPQuerySender() override { @@ -567,37 +572,37 @@ public: DNSQuestion getDQ() override { auto& ids = query.d_idstate; - DNSQuestion dq(ids, query.d_buffer); - dq.d_incomingTCPState = d_sender; - return dq; + DNSQuestion dnsQuestion(ids, query.d_buffer); + dnsQuestion.d_incomingTCPState = d_sender; + return dnsQuestion; } DNSResponse getDR() override { auto& ids = query.d_idstate; - DNSResponse dr(ids, query.d_buffer, downstream); - dr.d_incomingTCPState = d_sender; - return dr; + DNSResponse dnsResponse(ids, query.d_buffer, downstream); + dnsResponse.d_incomingTCPState = d_sender; + return dnsResponse; } private: std::shared_ptr d_sender; }; -std::unique_ptr IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr& ds) +std::unique_ptr IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr& backend) { - return std::make_unique(std::move(query), std::move(state), ds, shared_from_this()); + return std::make_unique(std::move(query), std::move(state), backend, shared_from_this()); } -std::unique_ptr getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq) +std::unique_ptr getTCPCrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion) { - auto state = dq.getIncomingTCPState(); + auto state = dnsQuestion.getIncomingTCPState(); if (!state) { throw std::runtime_error("Trying to create a TCP cross protocol query without a valid TCP state"); } - dq.ids.origID = dq.getHeader()->id; - return std::make_unique(std::move(dq.getMutableData()), std::move(dq.ids), nullptr, std::move(state)); + dnsQuestion.ids.origID = dnsQuestion.getHeader()->id; + return std::make_unique(std::move(dnsQuestion.getMutableData()), std::move(dnsQuestion.ids), nullptr, std::move(state)); } void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response) @@ -669,14 +674,14 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha { /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */ - const dnsheader_aligned dh(query.data()); - if (!checkQueryHeaders(dh.get(), *d_ci.cs)) { + const dnsheader_aligned dnsHeader(query.data()); + if (!checkQueryHeaders(dnsHeader.get(), *d_ci.cs)) { return QueryProcessingResult::InvalidHeaders; } - if (dh->qdcount == 0) { + if (dnsHeader->qdcount == 0) { TCPResponse response; - auto queryID = dh->id; + auto queryID = dnsHeader->id; dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) { header.rcode = RCode::NotImp; header.qr = true; @@ -693,49 +698,52 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha } } - ids.qname = DNSName(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast + ids.qname = DNSName(reinterpret_cast(query.data()), static_cast(query.size()), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); ids.protocol = getProtocol(); if (ids.dnsCryptQuery) { ids.protocol = dnsdist::Protocol::DNSCryptTCP; } - DNSQuestion dq(ids, query); - dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [&ids](dnsheader& header) { + DNSQuestion dnsQuestion(ids, query); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) { const uint16_t* flags = getFlagsFromDNSHeader(&header); ids.origFlags = *flags; return true; }); - dq.d_incomingTCPState = state; - dq.sni = d_handler.getServerNameIndication(); + dnsQuestion.d_incomingTCPState = state; + dnsQuestion.sni = d_handler.getServerNameIndication(); if (d_proxyProtocolValues) { /* we need to copy them, because the next queries received on that connection will need to get the _unaltered_ values */ - dq.proxyProtocolValues = make_unique>(*d_proxyProtocolValues); + dnsQuestion.proxyProtocolValues = make_unique>(*d_proxyProtocolValues); } - if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) { - dq.ids.skipCache = true; + if (dnsQuestion.ids.qtype == QType::AXFR || dnsQuestion.ids.qtype == QType::IXFR) { + dnsQuestion.ids.skipCache = true; } if (forwardViaUDPFirst()) { // if there was no EDNS, we add it with a large buffer size // so we can use UDP to talk to the backend. - const dnsheader_aligned dh(query.data()); - if (!dh->arcount) { + const dnsheader_aligned dnsHeader(query.data()); + if (dnsHeader->arcount == 0U) { if (addEDNS(query, 4096, false, 4096, 0)) { - dq.ids.ednsAdded = true; + dnsQuestion.ids.ednsAdded = true; } } } if (streamID) { auto unit = getDOHUnit(*streamID); - dq.ids.du = std::move(unit); + if (unit) { + dnsQuestion.ids.du = std::move(unit); + } } - std::shared_ptr ds; - auto result = processQuery(dq, d_threadData.holders, ds); + std::shared_ptr backend; + auto result = processQuery(dnsQuestion, d_threadData.holders, backend); if (result == ProcessQueryResult::Asynchronous) { /* we are done for now */ @@ -744,7 +752,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha } if (streamID) { - restoreDOHUnit(std::move(dq.ids.du)); + restoreDOHUnit(std::move(dnsQuestion.ids.du)); } if (result == ProcessQueryResult::Drop) { @@ -752,17 +760,17 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha } // the buffer might have been invalidated by now - uint16_t queryID; + uint16_t queryID{0}; { - const auto dh = dq.getHeader(); - queryID = dh->id; + const auto dnsHeader = dnsQuestion.getHeader(); + queryID = dnsHeader->id; } if (result == ProcessQueryResult::SendAnswer) { TCPResponse response; { - const auto dh = dq.getHeader(); - memcpy(&response.d_cleartextDH, dh.get(), sizeof(response.d_cleartextDH)); + const auto dnsHeader = dnsQuestion.getHeader(); + memcpy(&response.d_cleartextDH, dnsHeader.get(), sizeof(response.d_cleartextDH)); } response.d_idstate = std::move(ids); response.d_idstate.origID = queryID; @@ -776,71 +784,76 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha return QueryProcessingResult::SelfAnswered; } - if (result != ProcessQueryResult::PassToBackend || ds == nullptr) { + if (result != ProcessQueryResult::PassToBackend || backend == nullptr) { return QueryProcessingResult::NoBackend; } - dq.ids.origID = queryID; + dnsQuestion.ids.origID = queryID; ++d_currentQueriesCount; std::string proxyProtocolPayload; - if (ds->isDoH()) { - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), ds->getNameWithAddr()); + if (backend->isDoH()) { + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), backend->getNameWithAddr()); /* we need to do this _before_ creating the cross protocol query because after that the buffer will have been moved */ - if (ds->d_config.useProxyProtocol) { - proxyProtocolPayload = getProxyProtocolPayload(dq); + if (backend->d_config.useProxyProtocol) { + proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion); } - auto cpq = std::make_unique(std::move(query), std::move(ids), ds, state); + auto cpq = std::make_unique(std::move(query), std::move(ids), backend, state); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); - ds->passCrossProtocolQuery(std::move(cpq)); + backend->passCrossProtocolQuery(std::move(cpq)); return QueryProcessingResult::Forwarded; } - else if (!ds->isTCPOnly() && forwardViaUDPFirst()) { - auto unit = getDOHUnit(*streamID); - dq.ids.du = std::move(unit); - if (assignOutgoingUDPQueryToBackend(ds, queryID, dq, query)) { + if (!backend->isTCPOnly() && forwardViaUDPFirst()) { + if (streamID) { + auto unit = getDOHUnit(*streamID); + if (unit) { + dnsQuestion.ids.du = std::move(unit); + } + } + if (assignOutgoingUDPQueryToBackend(backend, queryID, dnsQuestion, query)) { return QueryProcessingResult::Forwarded; } - restoreDOHUnit(std::move(dq.ids.du)); + restoreDOHUnit(std::move(dnsQuestion.ids.du)); // fallback to the normal flow } prependSizeToTCPQuery(query, 0); - auto downstreamConnection = getDownstreamConnection(ds, dq.proxyProtocolValues, now); + auto downstreamConnection = getDownstreamConnection(backend, dnsQuestion.proxyProtocolValues, now); - if (ds->d_config.useProxyProtocol) { + if (backend->d_config.useProxyProtocol) { /* if we ever sent a TLV over a connection, we can never go back */ if (!d_proxyProtocolPayloadHasTLV) { - d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty(); + d_proxyProtocolPayloadHasTLV = dnsQuestion.proxyProtocolValues && !dnsQuestion.proxyProtocolValues->empty(); } - proxyProtocolPayload = getProxyProtocolPayload(dq); + proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion); } - if (dq.proxyProtocolValues) { - downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues)); + if (dnsQuestion.proxyProtocolValues) { + downstreamConnection->setProxyProtocolValuesSent(std::move(dnsQuestion.proxyProtocolValues)); } TCPQuery tcpquery(std::move(query), std::move(ids)); tcpquery.d_proxyProtocolPayload = std::move(proxyProtocolPayload); - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), ds->getNameWithAddr()); + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), backend->getNameWithAddr()); std::shared_ptr incoming = state; downstreamConnection->queueQuery(incoming, std::move(tcpquery)); return QueryProcessingResult::Forwarded; } -void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param) +void IncomingTCPConnectionState::handleIOCallback(int desc, FDMultiplexer::funcparam_t& param) { auto conn = boost::any_cast>(param); - if (fd != conn->d_handler.getDescriptor()) { - throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor())); + if (desc != conn->d_handler.getDescriptor()) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay): __PRETTY_FUNCTION__ is fine + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(desc) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor())); } conn->handleIO(); @@ -879,7 +892,7 @@ IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::hand ++dnsdist::metrics::g_stats.proxyProtocolInvalid; return ProxyProtocolResult::Error; } - else if (remaining < 0) { + if (remaining < 0) { d_proxyProtocolNeed += -remaining; d_buffer.resize(d_currentPos + d_proxyProtocolNeed); /* we need to keep reading, since we might have buffered data */ @@ -902,8 +915,7 @@ IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::hand else { d_lastIOBlocked = true; } - } - while (active() && !d_lastIOBlocked); + } while (active() && !d_lastIOBlocked); return ProxyProtocolResult::Reading; } @@ -932,12 +944,119 @@ IOState IncomingTCPConnectionState::handleHandshake(const struct timeval& now) return iostate; } +IOState IncomingTCPConnectionState::handleIncomingQueryReceived(const struct timeval& now) +{ + DEBUGLOG("query received"); + d_buffer.resize(d_querySize); + + d_state = State::idle; + auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt); + switch (processingResult) { + case QueryProcessingResult::TooSmall: + /* fall-through */ + case QueryProcessingResult::InvalidHeaders: + /* fall-through */ + case QueryProcessingResult::Dropped: + /* fall-through */ + case QueryProcessingResult::NoBackend: + terminateClientConnection(); + ; + default: + break; + } + + /* the state might have been updated in the meantime, we don't want to override it + in that case */ + if (active() && d_state != State::idle) { + if (d_ioState->isWaitingForRead()) { + return IOState::NeedRead; + } + if (d_ioState->isWaitingForWrite()) { + return IOState::NeedWrite; + } + return IOState::Done; + } + return IOState::Done; +}; + +void IncomingTCPConnectionState::handleExceptionDuringIO(const std::exception& exp) +{ + if (d_state == State::idle || d_state == State::waitingForQuery) { + /* no need to increase any counters in that case, the client is simply done with us */ + } + else if (d_state == State::doingHandshake || d_state != State::readingProxyProtocolHeader || d_state == State::waitingForQuery || d_state == State::readingQuerySize || d_state == State::readingQuery) { + ++d_ci.cs->tcpDiedReadingQuery; + } + else if (d_state == State::sendingResponse) { + /* unlikely to happen here, the exception should be handled in sendResponse() */ + ++d_ci.cs->tcpDiedSendingResponse; + } + + if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) { + DEBUGLOG("Got an exception while handling TCP query: " << exp.what()); + vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (d_ioState->isWaitingForRead() ? "reading" : "writing"), d_ci.remote.toStringWithPort(), exp.what()); + } + else { + vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), exp.what()); + DEBUGLOG("Closing TCP client connection: " << exp.what()); + } + /* remove this FD from the IO multiplexer */ + terminateClientConnection(); +} + +bool IncomingTCPConnectionState::readIncomingQuery(const timeval& now, IOState& iostate) +{ + if (!d_lastIOBlocked && (d_state == State::waitingForQuery || d_state == State::readingQuerySize)) { + DEBUGLOG("reading query size"); + d_buffer.resize(sizeof(uint16_t)); + iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t)); + if (d_currentPos > 0) { + /* if we got at least one byte, we can't go around sending responses */ + d_state = State::readingQuerySize; + } + + if (iostate == IOState::Done) { + DEBUGLOG("query size received"); + d_state = State::readingQuery; + d_querySizeReadTime = now; + if (d_queriesCount == 0) { + d_firstQuerySizeReadTime = now; + } + d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1); + if (d_querySize < sizeof(dnsheader)) { + /* go away */ + terminateClientConnection(); + return true; + } + + d_buffer.resize(d_querySize); + d_currentPos = 0; + } + else { + d_lastIOBlocked = true; + } + } + + if (!d_lastIOBlocked && d_state == State::readingQuery) { + DEBUGLOG("reading query"); + iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize); + if (iostate == IOState::Done) { + iostate = handleIncomingQueryReceived(now); + } + else { + d_lastIOBlocked = true; + } + } + + return false; +} + void IncomingTCPConnectionState::handleIO() { // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read // even though the underlying socket is not ready, so we need to actually ask for the data first IOState iostate = IOState::Done; - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); do { @@ -947,7 +1066,7 @@ void IncomingTCPConnectionState::handleIO() if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort()); // will be handled by the ioGuard - //handleNewIOState(state, IOState::Done, fd, handleIOCallback); + // handleNewIOState(state, IOState::Done, fd, handleIOCallback); return; } @@ -991,77 +1110,9 @@ void IncomingTCPConnectionState::handleIO() } } - if (!d_lastIOBlocked && (d_state == State::waitingForQuery || - d_state == State::readingQuerySize)) { - DEBUGLOG("reading query size"); - d_buffer.resize(sizeof(uint16_t)); - iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t)); - if (d_currentPos > 0) { - /* if we got at least one byte, we can't go around sending responses */ - d_state = State::readingQuerySize; - } - - if (iostate == IOState::Done) { - DEBUGLOG("query size received"); - d_state = State::readingQuery; - d_querySizeReadTime = now; - if (d_queriesCount == 0) { - d_firstQuerySizeReadTime = now; - } - d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1); - if (d_querySize < sizeof(dnsheader)) { - /* go away */ - terminateClientConnection(); - return; - } - - d_buffer.resize(d_querySize); - d_currentPos = 0; - } - else { - d_lastIOBlocked = true; - } - } - - if (!d_lastIOBlocked && d_state == State::readingQuery) { - DEBUGLOG("reading query"); - iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize); - if (iostate == IOState::Done) { - DEBUGLOG("query received"); - d_buffer.resize(d_querySize); - - d_state = State::idle; - auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt); - switch (processingResult) { - case QueryProcessingResult::TooSmall: - /* fall-through */ - case QueryProcessingResult::InvalidHeaders: - /* fall-through */ - case QueryProcessingResult::Dropped: - /* fall-through */ - case QueryProcessingResult::NoBackend: - terminateClientConnection(); - break; - default: - break; - } - - /* the state might have been updated in the meantime, we don't want to override it - in that case */ - if (active() && d_state != State::idle) { - if (d_ioState->isWaitingForRead()) { - iostate = IOState::NeedRead; - } - else if (d_ioState->isWaitingForWrite()) { - iostate = IOState::NeedWrite; - } - else { - iostate = IOState::Done; - } - } - } - else { - d_lastIOBlocked = true; + if (!d_lastIOBlocked && (d_state == State::waitingForQuery || d_state == State::readingQuerySize || d_state == State::readingQuery)) { + if (readIncomingQuery(now, iostate)) { + return; } } @@ -1069,7 +1120,7 @@ void IncomingTCPConnectionState::handleIO() DEBUGLOG("sending response"); iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size()); if (iostate == IOState::Done) { - DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__); + DEBUGLOG("response sent from " << __PRETTY_FUNCTION__); handleResponseSent(d_currentResponse); d_state = State::idle; } @@ -1078,12 +1129,7 @@ void IncomingTCPConnectionState::handleIO() } } - if (active() && - !d_lastIOBlocked && - iostate == IOState::Done && - (d_state == State::idle || - d_state == State::waitingForQuery)) - { + if (active() && !d_lastIOBlocked && iostate == IOState::Done && (d_state == State::idle || d_state == State::waitingForQuery)) { // try sending queued responses DEBUGLOG("send responses, if any"); auto state = shared_from_this(); @@ -1103,47 +1149,16 @@ void IncomingTCPConnectionState::handleIO() } } - if (d_state != State::idle && - d_state != State::doingHandshake && - d_state != State::readingProxyProtocolHeader && - d_state != State::waitingForQuery && - d_state != State::readingQuerySize && - d_state != State::readingQuery && - d_state != State::sendingResponse) { + if (d_state != State::idle && d_state != State::doingHandshake && d_state != State::readingProxyProtocolHeader && d_state != State::waitingForQuery && d_state != State::readingQuerySize && d_state != State::readingQuery && d_state != State::sendingResponse) { vinfolog("Unexpected state %d in handleIOCallback", static_cast(d_state)); } } - catch (const std::exception& e) { + catch (const std::exception& exp) { /* most likely an EOF because the other end closed the connection, but it might also be a real IO error or something else. Let's just drop the connection */ - if (d_state == State::idle || - d_state == State::waitingForQuery) { - /* no need to increase any counters in that case, the client is simply done with us */ - } - else if (d_state == State::doingHandshake || - d_state != State::readingProxyProtocolHeader || - d_state == State::waitingForQuery || - d_state == State::readingQuerySize || - d_state == State::readingQuery) { - ++d_ci.cs->tcpDiedReadingQuery; - } - else if (d_state == State::sendingResponse) { - /* unlikely to happen here, the exception should be handled in sendResponse() */ - ++d_ci.cs->tcpDiedSendingResponse; - } - - if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) { - DEBUGLOG("Got an exception while handling TCP query: "<isWaitingForRead() ? "reading" : "writing"), d_ci.remote.toStringWithPort(), e.what()); - } - else { - vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what()); - DEBUGLOG("Closing TCP client connection: "<active() && iostate != IOState::Done) { // we need to update the state right away, nobody will do that for us - updateIO(state, iostate, now); + updateIO(state, iostate, now); } } catch (const std::exception& e) { @@ -1214,7 +1228,7 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptrd_ci.remote.toStringWithPort()); DEBUGLOG("client timeout"); - DEBUGLOG("Processed "<d_queriesCount<<" queries, current count is "<d_currentQueriesCount<<", "<d_ownedConnectionsToBackend.size()<<" owned connections, "<d_queuedResponses.size()<<" response queued"); + DEBUGLOG("Processed " << state->d_queriesCount << " queries, current count is " << state->d_currentQueriesCount << ", " << state->d_ownedConnectionsToBackend.size() << " owned connections, " << state->d_queuedResponses.size() << " response queued"); if (write || state->d_currentQueriesCount == 0) { ++state->d_ci.cs->tcpClientTimeouts; @@ -1230,7 +1244,7 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptr(param); + auto* threadData = boost::any_cast(param); std::unique_ptr citmp{nullptr}; try { @@ -1246,7 +1260,7 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param g_tcpclientthreads->decrementQueuedCount(); - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); if (citmp->cs->dohFrontend) { @@ -1263,7 +1277,7 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param) { - auto threadData = boost::any_cast(param); + auto* threadData = boost::any_cast(param); std::unique_ptr cpq{nullptr}; try { @@ -1277,7 +1291,7 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par throw std::runtime_error("Error while reading from the TCP cross-protocol channel: " + std::string(e.what())); } - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); std::shared_ptr tqs = cpq->getTCPQuerySender(); @@ -1300,7 +1314,7 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param) { - auto threadData = boost::any_cast(param); + auto* threadData = boost::any_cast(param); std::unique_ptr cpr{nullptr}; try { @@ -1314,7 +1328,7 @@ static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& throw std::runtime_error("Error while reading from the TCP cross-protocol response: " + std::string(e.what())); } - auto response = std::move(*cpr); + auto& response = *cpr; try { if (response.d_response.d_buffer.empty()) { @@ -1334,14 +1348,113 @@ static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& struct TCPAcceptorParam { - ClientState& cs; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + ClientState& clientState; ComboAddress local; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) LocalStateHolder& acl; int socket{-1}; }; static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData); +static void scanForTimeouts(const TCPClientThreadData& data, const timeval& now) +{ + auto expiredReadConns = data.mplexer->getTimeouts(now, false); + for (const auto& cbData : expiredReadConns) { + if (cbData.second.type() == typeid(std::shared_ptr)) { + auto state = boost::any_cast>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + state->handleTimeout(state, false); + } + } +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) + else if (cbData.second.type() == typeid(std::shared_ptr)) { + auto state = boost::any_cast>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (read) from remote H2 client %s", state->d_ci.remote.toStringWithPort()); + std::shared_ptr parentState = state; + state->handleTimeout(parentState, false); + } + } +#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ + else if (cbData.second.type() == typeid(std::shared_ptr)) { + auto conn = boost::any_cast>(cbData.second); + vinfolog("Timeout (read) from remote backend %s", conn->getBackendName()); + conn->handleTimeout(now, false); + } + } + + auto expiredWriteConns = data.mplexer->getTimeouts(now, true); + for (const auto& cbData : expiredWriteConns) { + if (cbData.second.type() == typeid(std::shared_ptr)) { + auto state = boost::any_cast>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + state->handleTimeout(state, true); + } + } +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) + else if (cbData.second.type() == typeid(std::shared_ptr)) { + auto state = boost::any_cast>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (write) from remote H2 client %s", state->d_ci.remote.toStringWithPort()); + std::shared_ptr parentState = state; + state->handleTimeout(parentState, true); + } + } +#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ + else if (cbData.second.type() == typeid(std::shared_ptr)) { + auto conn = boost::any_cast>(cbData.second); + vinfolog("Timeout (write) from remote backend %s", conn->getBackendName()); + conn->handleTimeout(now, true); + } + } +} + +static void dumpTCPStates(const TCPClientThreadData& data) +{ + /* just to keep things clean in the output, debug only */ + static std::mutex s_lock; + std::lock_guard lck(s_lock); + if (g_tcpStatesDumpRequested > 0) { + /* no race here, we took the lock so it can only be increased in the meantime */ + --g_tcpStatesDumpRequested; + infolog("Dumping the TCP states, as requested:"); + data.mplexer->runForAllWatchedFDs([](bool isRead, int desc, const FDMultiplexer::funcparam_t& param, struct timeval ttd) { + timeval lnow{}; + gettimeofday(&lnow, nullptr); + if (ttd.tv_sec > 0) { + infolog("- Descriptor %d is in %s state, TTD in %d", desc, (isRead ? "read" : "write"), (ttd.tv_sec - lnow.tv_sec)); + } + else { + infolog("- Descriptor %d is in %s state, no TTD set", desc, (isRead ? "read" : "write")); + } + + if (param.type() == typeid(std::shared_ptr)) { + auto state = boost::any_cast>(param); + infolog(" - %s", state->toString()); + } +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) + else if (param.type() == typeid(std::shared_ptr)) { + auto state = boost::any_cast>(param); + infolog(" - %s", state->toString()); + } +#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ + else if (param.type() == typeid(std::shared_ptr)) { + auto conn = boost::any_cast>(param); + infolog(" - %s", conn->toString()); + } + else if (param.type() == typeid(TCPClientThreadData*)) { + infolog(" - Worker thread pipe"); + } + }); + infolog("The TCP/DoT client cache has %d active and %d idle outgoing connections cached", t_downstreamTCPConnectionsManager.getActiveCount(), t_downstreamTCPConnectionsManager.getIdleCount()); + } +} + +// NOLINTNEXTLINE(performance-unnecessary-value-param): you are wrong, clang-tidy, go home static void tcpClientThread(pdns::channel::Receiver&& queryReceiver, pdns::channel::Receiver&& crossProtocolQueryReceiver, pdns::channel::Receiver&& crossProtocolResponseReceiver, pdns::channel::Sender&& crossProtocolResponseSender, std::vector tcpAcceptStates) { /* we get launched with a pipe on which we receive file descriptors from clients that we own @@ -1373,17 +1486,16 @@ static void tcpClientThread(pdns::channel::Receiver&& queryRecei } auto acceptCallback = [&data](int socket, FDMultiplexer::funcparam_t& funcparam) { - auto acceptorParam = boost::any_cast(funcparam); + const auto* acceptorParam = boost::any_cast(funcparam); acceptNewConnection(*acceptorParam, &data); }; - for (size_t idx = 0; idx < acceptParams.size(); idx++) { - const auto& param = acceptParams.at(idx); + for (const auto& param : acceptParams) { setNonBlocking(param.socket); data.mplexer->addReadFD(param.socket, acceptCallback, ¶m); } - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); time_t lastTimeoutScan = now.tv_sec; @@ -1395,97 +1507,10 @@ static void tcpClientThread(pdns::channel::Receiver&& queryRecei if (now.tv_sec > lastTimeoutScan) { lastTimeoutScan = now.tv_sec; - auto expiredReadConns = data.mplexer->getTimeouts(now, false); - for (const auto& cbData : expiredReadConns) { - if (cbData.second.type() == typeid(std::shared_ptr)) { - auto state = boost::any_cast>(cbData.second); - if (cbData.first == state->d_handler.getDescriptor()) { - vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); - state->handleTimeout(state, false); - } - } -#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) - else if (cbData.second.type() == typeid(std::shared_ptr)) { - auto state = boost::any_cast>(cbData.second); - if (cbData.first == state->d_handler.getDescriptor()) { - vinfolog("Timeout (read) from remote H2 client %s", state->d_ci.remote.toStringWithPort()); - std::shared_ptr parentState = state; - state->handleTimeout(parentState, false); - } - } -#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ - else if (cbData.second.type() == typeid(std::shared_ptr)) { - auto conn = boost::any_cast>(cbData.second); - vinfolog("Timeout (read) from remote backend %s", conn->getBackendName()); - conn->handleTimeout(now, false); - } - } - - auto expiredWriteConns = data.mplexer->getTimeouts(now, true); - for (const auto& cbData : expiredWriteConns) { - if (cbData.second.type() == typeid(std::shared_ptr)) { - auto state = boost::any_cast>(cbData.second); - if (cbData.first == state->d_handler.getDescriptor()) { - vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); - state->handleTimeout(state, true); - } - } -#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) - else if (cbData.second.type() == typeid(std::shared_ptr)) { - auto state = boost::any_cast>(cbData.second); - if (cbData.first == state->d_handler.getDescriptor()) { - vinfolog("Timeout (write) from remote H2 client %s", state->d_ci.remote.toStringWithPort()); - std::shared_ptr parentState = state; - state->handleTimeout(parentState, true); - } - } -#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ - else if (cbData.second.type() == typeid(std::shared_ptr)) { - auto conn = boost::any_cast>(cbData.second); - vinfolog("Timeout (write) from remote backend %s", conn->getBackendName()); - conn->handleTimeout(now, true); - } - } + scanForTimeouts(data, now); if (g_tcpStatesDumpRequested > 0) { - /* just to keep things clean in the output, debug only */ - static std::mutex s_lock; - std::lock_guard lck(s_lock); - if (g_tcpStatesDumpRequested > 0) { - /* no race here, we took the lock so it can only be increased in the meantime */ - --g_tcpStatesDumpRequested; - infolog("Dumping the TCP states, as requested:"); - data.mplexer->runForAllWatchedFDs([](bool isRead, int fd, const FDMultiplexer::funcparam_t& param, struct timeval ttd) - { - struct timeval lnow; - gettimeofday(&lnow, nullptr); - if (ttd.tv_sec > 0) { - infolog("- Descriptor %d is in %s state, TTD in %d", fd, (isRead ? "read" : "write"), (ttd.tv_sec-lnow.tv_sec)); - } - else { - infolog("- Descriptor %d is in %s state, no TTD set", fd, (isRead ? "read" : "write")); - } - - if (param.type() == typeid(std::shared_ptr)) { - auto state = boost::any_cast>(param); - infolog(" - %s", state->toString()); - } -#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) - else if (param.type() == typeid(std::shared_ptr)) { - auto state = boost::any_cast>(param); - infolog(" - %s", state->toString()); - } -#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ - else if (param.type() == typeid(std::shared_ptr)) { - auto conn = boost::any_cast>(param); - infolog(" - %s", conn->toString()); - } - else if (param.type() == typeid(TCPClientThreadData*)) { - infolog(" - Worker thread pipe"); - } - }); - infolog("The TCP/DoT client cache has %d active and %d idle outgoing connections cached", t_downstreamTCPConnectionsManager.getActiveCount(), t_downstreamTCPConnectionsManager.getIdleCount()); - } + dumpTCPStates(data); } } } @@ -1501,9 +1526,9 @@ static void tcpClientThread(pdns::channel::Receiver&& queryRecei static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData) { - auto& cs = param.cs; + auto& clientState = param.clientState; auto& acl = param.acl; - const bool checkACL = !cs.dohFrontend || (!cs.dohFrontend->d_trustForwardedForHeader && cs.dohFrontend->d_earlyACLDrop); + const bool checkACL = clientState.dohFrontend == nullptr || (!clientState.dohFrontend->d_trustForwardedForHeader && clientState.dohFrontend->d_earlyACLDrop); const int socket = param.socket; bool tcpClientCountIncremented = false; ComboAddress remote; @@ -1512,16 +1537,18 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa tcpClientCountIncremented = false; try { socklen_t remlen = remote.getSocklen(); - ConnectionInfo ci(&cs); + ConnectionInfo connInfo(&clientState); #ifdef HAVE_ACCEPT4 - ci.fd = accept4(socket, reinterpret_cast(&remote), &remlen, SOCK_NONBLOCK); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + connInfo.fd = accept4(socket, reinterpret_cast(&remote), &remlen, SOCK_NONBLOCK); #else - ci.fd = accept(socket, reinterpret_cast(&remote), &remlen); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + connInfo.fd = accept(socket, reinterpret_cast(&remote), &remlen); #endif // will be decremented when the ConnectionInfo object is destroyed, no matter the reason - auto concurrentConnections = ++cs.tcpCurrentConnections; + auto concurrentConnections = ++clientState.tcpCurrentConnections; - if (ci.fd < 0) { + if (connInfo.fd < 0) { throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str()); } @@ -1531,22 +1558,22 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa return; } - if (cs.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > cs.d_tcpConcurrentConnectionsLimit) { + if (clientState.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > clientState.d_tcpConcurrentConnectionsLimit) { vinfolog("Dropped TCP connection from %s because of concurrent connections limit", remote.toStringWithPort()); return; } - if (concurrentConnections > cs.tcpMaxConcurrentConnections.load()) { - cs.tcpMaxConcurrentConnections.store(concurrentConnections); + if (concurrentConnections > clientState.tcpMaxConcurrentConnections.load()) { + clientState.tcpMaxConcurrentConnections.store(concurrentConnections); } #ifndef HAVE_ACCEPT4 - if (!setNonBlocking(ci.fd)) { + if (!setNonBlocking(connInfo.fd)) { return; } #endif - setTCPNoDelay(ci.fd); // disable NAGLE + setTCPNoDelay(connInfo.fd); // disable NAGLE if (g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) { vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort()); @@ -1561,27 +1588,27 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa vinfolog("Got TCP connection from %s", remote.toStringWithPort()); - ci.remote = remote; + connInfo.remote = remote; if (threadData == nullptr) { - if (!g_tcpclientthreads->passConnectionToThread(std::make_unique(std::move(ci)))) { + if (!g_tcpclientthreads->passConnectionToThread(std::make_unique(std::move(connInfo)))) { if (tcpClientCountIncremented) { dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote); } } } else { - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); - if (ci.cs->dohFrontend) { + if (connInfo.cs->dohFrontend) { #if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) - auto state = std::make_shared(std::move(ci), *threadData, now); + auto state = std::make_shared(std::move(connInfo), *threadData, now); state->handleIO(); #endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ } else { - auto state = std::make_shared(std::move(ci), *threadData, now); + auto state = std::make_shared(std::move(connInfo), *threadData, now); state->handleIO(); } } @@ -1592,14 +1619,15 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote); } } - catch (...){} + catch (...) { + } } /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and they will hand off to worker threads & spawn more of them if required */ #ifndef USE_SINGLE_ACCEPTOR_THREAD -void tcpAcceptorThread(std::vector states) +void tcpAcceptorThread(const std::vector& states) { setThreadName("dnsdist/tcpAcce"); @@ -1607,7 +1635,7 @@ void tcpAcceptorThread(std::vector states) std::vector params; params.reserve(states.size()); - for (auto& state : states) { + for (const auto& state : states) { params.emplace_back(TCPAcceptorParam{*state, state->local, acl, state->tcpFD}); for (const auto& [addr, socket] : state->d_additionalAddresses) { params.emplace_back(TCPAcceptorParam{*state, addr, acl, socket}); @@ -1621,19 +1649,18 @@ void tcpAcceptorThread(std::vector states) } else { auto acceptCallback = [](int socket, FDMultiplexer::funcparam_t& funcparam) { - auto acceptorParam = boost::any_cast(funcparam); + const auto* acceptorParam = boost::any_cast(funcparam); acceptNewConnection(*acceptorParam, nullptr); }; auto mplexer = std::unique_ptr(FDMultiplexer::getMultiplexerSilent(params.size())); - for (size_t idx = 0; idx < params.size(); idx++) { - const auto& param = params.at(idx); + for (const auto& param : params) { mplexer->addReadFD(param.socket, acceptCallback, ¶m); } - struct timeval tv; + timeval now{}; while (true) { - mplexer->run(&tv, -1); + mplexer->run(&now, -1); } } } diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 3122776560..676f80e1ba 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -1112,7 +1112,7 @@ struct LocalHolders LocalStateHolder pools; }; -void tcpAcceptorThread(std::vector states); +void tcpAcceptorThread(const std::vector& states); void setLuaNoSideEffect(); // if nothing has been declared, set that there are no side effects void setLuaSideEffect(); // set to report a side effect, cancelling all _no_ side effect calls diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 2c081fd27e..c6410df0c9 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -116,14 +116,14 @@ public: return false; } - std::shared_ptr getOwnedDownstreamConnection(const std::shared_ptr& ds, const std::unique_ptr>& tlvs); - std::shared_ptr getDownstreamConnection(std::shared_ptr& ds, const std::unique_ptr>& tlvs, const struct timeval& now); + std::shared_ptr getOwnedDownstreamConnection(const std::shared_ptr& backend, const std::unique_ptr>& tlvs); + std::shared_ptr getDownstreamConnection(std::shared_ptr& backend, const std::unique_ptr>& tlvs, const struct timeval& now); void registerOwnedDownstreamConnection(std::shared_ptr& conn); static size_t clearAllDownstreamConnections(); - static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); - static void handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param); + static void handleIOCallback(int desc, FDMultiplexer::funcparam_t& param); + static void handleAsyncReady(int desc, FDMultiplexer::funcparam_t& param); static void updateIO(std::shared_ptr& state, IOState newState, const struct timeval& now); static void queueResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response, bool fromBackend); @@ -172,7 +172,7 @@ public: throw std::runtime_error("Restoring a DOHUnit state to a generic TCP/DoT connection is not supported"); } - std::unique_ptr getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr& ds); + std::unique_ptr getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr& backend); std::string toString() const { @@ -182,6 +182,9 @@ public: } dnsdist::Protocol getProtocol() const; + IOState handleIncomingQueryReceived(const struct timeval& now); + void handleExceptionDuringIO(const std::exception& exp); + bool readIncomingQuery(const timeval& now, IOState& iostate); enum class State : uint8_t { starting, doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ }; diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index 199cbcab35..f3d827ebf3 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -218,7 +218,7 @@ struct CrossProtocolQuery class TCPClientCollection { public: - TCPClientCollection(size_t maxThreads, std::vector tcpStates); + TCPClientCollection(size_t maxThreads, std::vector tcpAcceptStates); bool passConnectionToThread(std::unique_ptr&& conn) { @@ -307,4 +307,4 @@ private: extern std::unique_ptr g_tcpclientthreads; -std::unique_ptr getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq); +std::unique_ptr getTCPCrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion);