From: Remi Gacogne Date: Fri, 16 Dec 2022 17:31:33 +0000 (+0100) Subject: dnsdist: Implement async processing of queries and responses X-Git-Tag: dnsdist-1.8.0-rc1~86^2~16 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0859e14a8360f0af5b6a4fe1817f48a7d74182e4;p=thirdparty%2Fpdns.git dnsdist: Implement async processing of queries and responses --- diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index 5b42f986f5..cec3a329ab 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -152,6 +152,7 @@ struct InternalQueryState bool dnssecOK{false}; bool useZeroScope{false}; bool forwardedOverUDP{false}; + bool selfGenerated{false}; }; struct IDState diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index 7e5ecb8df6..e0839bcd13 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -22,6 +22,7 @@ #include "config.h" #include "threadname.hh" #include "dnsdist.hh" +#include "dnsdist-async.hh" #include "dnsdist-ecs.hh" #include "dnsdist-lua.hh" #include "dnsdist-lua-ffi.hh" @@ -492,19 +493,24 @@ public: DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - auto lock = g_lua.lock(); try { - auto ret = d_func(dq); - if (ruleresult) { - if (boost::optional rule = std::get<1>(ret)) { - *ruleresult = *rule; - } - else { - // default to empty string - ruleresult->clear(); + DNSAction::Action result; + { + auto lock = g_lua.lock(); + auto ret = d_func(dq); + if (ruleresult) { + if (boost::optional rule = std::get<1>(ret)) { + *ruleresult = *rule; + } + else { + // default to empty string + ruleresult->clear(); + } } + result = static_cast(std::get<0>(ret)); } - return static_cast(std::get<0>(ret)); + dnsdist::handleQueuedAsynchronousEvents(); + return result; } catch (const std::exception &e) { warnlog("LuaAction failed inside Lua, returning ServFail: %s", e.what()); } catch (...) { @@ -529,19 +535,24 @@ public: {} DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override { - auto lock = g_lua.lock(); try { - auto ret = d_func(dr); - if (ruleresult) { - if (boost::optional rule = std::get<1>(ret)) { - *ruleresult = *rule; - } - else { - // default to empty string - ruleresult->clear(); + DNSResponseAction::Action result; + { + auto lock = g_lua.lock(); + auto ret = d_func(dr); + if (ruleresult) { + if (boost::optional rule = std::get<1>(ret)) { + *ruleresult = *rule; + } + else { + // default to empty string + ruleresult->clear(); + } } + result = static_cast(std::get<0>(ret)); } - return static_cast(std::get<0>(ret)); + dnsdist::handleQueuedAsynchronousEvents(); + return result; } catch (const std::exception &e) { warnlog("LuaResponseAction failed inside Lua, returning ServFail: %s", e.what()); } catch (...) { @@ -571,18 +582,23 @@ public: { dnsdist_ffi_dnsquestion_t dqffi(dq); try { - auto lock = g_lua.lock(); - auto ret = d_func(&dqffi); - if (ruleresult) { - if (dqffi.result) { - *ruleresult = *dqffi.result; - } - else { - // default to empty string - ruleresult->clear(); + DNSAction::Action result; + { + auto lock = g_lua.lock(); + auto ret = d_func(&dqffi); + if (ruleresult) { + if (dqffi.result) { + *ruleresult = *dqffi.result; + } + else { + // default to empty string + ruleresult->clear(); + } } + result = static_cast(ret); } - return static_cast(ret); + dnsdist::handleQueuedAsynchronousEvents(); + return result; } catch (const std::exception &e) { warnlog("LuaFFIAction failed inside Lua, returning ServFail: %s", e.what()); } catch (...) { @@ -636,6 +652,7 @@ public: ruleresult->clear(); } } + dnsdist::handleQueuedAsynchronousEvents(); return static_cast(ret); } catch (const std::exception &e) { @@ -681,18 +698,23 @@ public: { dnsdist_ffi_dnsresponse_t drffi(dr); try { - auto lock = g_lua.lock(); - auto ret = d_func(&drffi); - if (ruleresult) { - if (drffi.result) { - *ruleresult = *drffi.result; - } - else { - // default to empty string - ruleresult->clear(); + DNSResponseAction::Action result; + { + auto lock = g_lua.lock(); + auto ret = d_func(&drffi); + if (ruleresult) { + if (drffi.result) { + *ruleresult = *drffi.result; + } + else { + // default to empty string + ruleresult->clear(); + } } + result = static_cast(ret); } - return static_cast(ret); + dnsdist::handleQueuedAsynchronousEvents(); + return result; } catch (const std::exception &e) { warnlog("LuaFFIResponseAction failed inside Lua, returning ServFail: %s", e.what()); } catch (...) { @@ -746,6 +768,7 @@ public: ruleresult->clear(); } } + dnsdist::handleQueuedAsynchronousEvents(); return static_cast(ret); } catch (const std::exception &e) { diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index fbb04c2d82..2677168f44 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -246,8 +246,8 @@ static void handleResponseSent(std::shared_ptr& stat --state->d_currentQueriesCount; - if (currentResponse.d_selfGenerated == false && currentResponse.d_connection && currentResponse.d_connection->getDS()) { - const auto& ds = currentResponse.d_connection->getDS(); + const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds; + if (currentResponse.d_idstate.selfGenerated == false && ds) { const auto& ids = currentResponse.d_idstate; double udiff = ids.queryRealTime.udiff(); vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), currentResponse.d_buffer.size(), udiff); @@ -498,7 +498,7 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe std::shared_ptr state = shared_from_this(); - if (response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) { + if (!response.isAsync() && response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) { // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool as no one else will be able to use it anyway if (!response.d_connection->willBeReusable(true)) { // if it can't be reused even by us, well @@ -527,32 +527,40 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe return; } - try { - auto& ids = response.d_idstate; - unsigned int qnameWireLength; - if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) { - state->terminateClientConnection(); - return; - } + if (!response.isAsync()) { + try { + auto& ids = response.d_idstate; + unsigned int qnameWireLength; + if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) { + state->terminateClientConnection(); + return; + } - if (response.d_connection->getDS()) { - ++response.d_connection->getDS()->responses; - } + if (response.d_connection->getDS()) { + ++response.d_connection->getDS()->responses; + } - DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS()); + DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS()); + dr.d_incomingTCPState = state; - memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH)); + memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH)); - if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) { + if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) { + state->terminateClientConnection(); + return; + } + + if (dr.isAsynchronous()) { + /* we are done for now */ + return; + } + } + catch (const std::exception& e) { + vinfolog("Unexpected exception while handling response from backend: %s", e.what()); state->terminateClientConnection(); return; } } - catch (const std::exception& e) { - vinfolog("Unexpected exception while handling response from backend: %s", e.what()); - state->terminateClientConnection(); - return; - } ++g_stats.responses; ++state->d_ci.cs->responses; @@ -574,7 +582,7 @@ struct TCPCrossProtocolResponse class TCPCrossProtocolQuery : public CrossProtocolQuery { public: - TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr& ds, std::shared_ptr sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender)) + TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr ds, std::shared_ptr sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender)) { proxyProtocolPayloadSize = 0; } @@ -588,10 +596,37 @@ public: return d_sender; } + DNSQuestion getDQ() override + { + auto& ids = query.d_idstate; + DNSQuestion dq(ids, query.d_buffer); + dq.d_incomingTCPState = d_sender; + return dq; + } + + DNSResponse getDR() override + { + auto& ids = query.d_idstate; + DNSResponse dr(ids, query.d_buffer, downstream); + dr.d_incomingTCPState = d_sender; + return dr; + } + private: std::shared_ptr d_sender; }; +std::unique_ptr getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq) +{ + auto state = dq.getIncomingTCPState(); + if (!state) { + throw std::runtime_error("Trying to create a TCP cross protocol query without a valid TCP state"); + } + + dq.ids.origID = dq.getHeader()->id; + return std::make_unique(std::move(dq.getMutableData()), std::move(dq.ids), nullptr, std::move(state)); +} + void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response) { if (d_threadData.crossProtocolResponsesPipe == -1) { @@ -674,7 +709,7 @@ static void handleQuery(std::shared_ptr& state, cons TCPResponse response; dh->rcode = RCode::NotImp; dh->qr = true; - response.d_selfGenerated = true; + response.d_idstate.selfGenerated = true; response.d_buffer = std::move(state->d_buffer); state->d_state = IncomingTCPConnectionState::State::idle; ++state->d_currentQueriesCount; @@ -695,8 +730,9 @@ static void handleQuery(std::shared_ptr& state, cons DNSQuestion dq(ids, state->d_buffer); const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); ids.origFlags = *flags; - + dq.d_incomingTCPState = state; dq.sni = state->d_handler.getServerNameIndication(); + if (state->d_proxyProtocolValues) { /* we need to copy them, because the next queries received on that connection will need to get the _unaltered_ values */ @@ -708,12 +744,17 @@ static void handleQuery(std::shared_ptr& state, cons } std::shared_ptr ds; - auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, ds); + auto result = processQuery(dq, state->d_threadData.holders, ds); if (result == ProcessQueryResult::Drop) { state->terminateClientConnection(); return; } + else if (result == ProcessQueryResult::Asynchronous) { + /* we are done for now */ + ++state->d_currentQueriesCount; + return; + } // the buffer might have been invalidated by now const dnsheader* dh = dq.getHeader(); @@ -722,6 +763,7 @@ static void handleQuery(std::shared_ptr& state, cons memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH)); response.d_idstate = std::move(ids); response.d_idstate.origID = dh->id; + response.d_idstate.selfGenerated = true; response.d_idstate.cs = state->d_ci.cs; response.d_buffer = std::move(state->d_buffer); @@ -1399,6 +1441,7 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa } if (cs.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > cs.d_tcpConcurrentConnectionsLimit) { + vinfolog("Dropped TCP connection from %s because of concurrent connections limit", remote.toStringWithPort()); return; } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index a1a6061eae..ccf32fc8aa 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -48,6 +48,7 @@ #endif #include "dnsdist.hh" +#include "dnsdist-async.hh" #include "dnsdist-cache.hh" #include "dnsdist-carbon.hh" #include "dnsdist-console.hh" @@ -513,18 +514,14 @@ static bool applyRulesToResponse(const std::vector& r return true; } -bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& insertedRespRuleActions, DNSResponse& dr, bool muted) +bool processResponseAfterRules(PacketBuffer& response, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted) { - if (!applyRulesToResponse(respRuleActions, dr)) { - return false; - } - bool zeroScope = false; if (!fixUpResponse(response, dr.ids.qname, dr.ids.origFlags, dr.ids.ednsAdded, dr.ids.ecsAdded, dr.ids.useZeroScope ? &zeroScope : nullptr)) { return false; } - if (dr.ids.packetCache && !dr.ids.skipCache && response.size() <= s_maxPacketCacheEntrySize) { + if (dr.ids.packetCache && !dr.ids.selfGenerated && !dr.ids.skipCache && response.size() <= s_maxPacketCacheEntrySize) { if (!dr.ids.useZeroScope) { /* if the query was not suitable for zero-scope, for example because it had an existing ECS entry so the hash is @@ -547,7 +544,7 @@ bool processResponse(PacketBuffer& response, const std::vectorinsert(cacheKey, zeroScope ? boost::none : dr.ids.subnet, dr.ids.cacheFlags, dr.ids.dnssecOK, dr.ids.qname, dr.ids.qtype, dr.ids.qclass, response, dr.ids.forwardedOverUDP, dr.getHeader()->rcode, dr.ids.tempFailureTTL); - if (!applyRulesToResponse(insertedRespRuleActions, dr)) { + if (!applyRulesToResponse(cacheInsertedRespRuleActions, dr)) { return false; } } @@ -569,6 +566,19 @@ bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted) +{ + if (!applyRulesToResponse(respRuleActions, dr)) { + return false; + } + + if (dr.isAsynchronous()) { + return true; + } + + return processResponseAfterRules(response, cacheInsertedRespRuleActions, dr, muted); +} + static size_t getInitialUDPPacketBufferSize() { static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "The incoming buffer size should not be larger than s_initialUDPPacketBufferSize"); @@ -593,7 +603,7 @@ static size_t getMaximumIncomingPacketSize(const ClientState& cs) return s_udpIncomingBufferSize + g_proxyProtocolMaximumSize; } -static bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote) +bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote) { #ifndef DISABLE_DELAY_PIPE if (delayMsec && g_delay) { @@ -640,7 +650,7 @@ void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, doLatencyStats(incomingProtocol, udiff); } -static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, const std::shared_ptr& ds, bool selfGenerated) +static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, const std::shared_ptr& ds, bool isAsync, bool selfGenerated) { DNSResponse dr(ids, response, ds); @@ -658,8 +668,14 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re dnsheader cleartextDH; memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); - if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) { - return; + if (!isAsync) { + if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) { + return; + } + + if (dr.isAsynchronous()) { + return; + } } ++g_stats.responses; @@ -757,7 +773,7 @@ void responderThread(std::shared_ptr dss) continue; } - handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false); + handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, false); } } catch (const std::exception& e) { @@ -830,6 +846,10 @@ static void spoofPacketFromString(DNSQuestion& dq, const string& spoofContent) bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop) { + if (dq.isAsynchronous()) { + return false; + } + switch(action) { case DNSAction::Action::Allow: return true; @@ -1215,10 +1235,12 @@ static void queueResponse(const ClientState& cs, const PacketBuffer& response, c #endif /* DISABLE_RECVMMSG */ /* self-generated responses or cache hits */ -static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit) +static bool prepareOutgoingResponse(LocalHolders& holders, const ClientState& cs, DNSQuestion& dq, bool cacheHit) { std::shared_ptr ds{nullptr}; DNSResponse dr(dq.ids, dq.getMutableData(), ds); + dr.d_incomingTCPState = dq.d_incomingTCPState; + dr.ids.selfGenerated = true; if (!applyRulesToResponse(cacheHit ? *holders.cacheHitRespRuleactions : *holders.selfAnsweredRespRuleactions, dr)) { return false; @@ -1230,6 +1252,14 @@ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQ ac(&dr, &result); } + if (cacheHit) { + ++g_stats.cacheHits; + } + + if (dr.isAsynchronous()) { + return false; + } + #ifdef HAVE_DNSCRYPT if (!cs.muted) { if (!encryptResponse(dq.getMutableData(), dq.getMaximumSize(), dq.overTCP(), dq.ids.dnsCryptQuery)) { @@ -1238,28 +1268,14 @@ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQ } #endif /* HAVE_DNSCRYPT */ - if (cacheHit) { - ++g_stats.cacheHits; - } - return true; } -ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) +ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr& selectedBackend) { const uint16_t queryId = ntohs(dq.getHeader()->id); try { - /* we need an accurate ("real") value for the response and - to store into the IDS, but not for insertion into the - rings for example */ - struct timespec now; - gettime(&now); - - if (!applyRulesToQuery(holders, dq, now)) { - return ProcessQueryResult::Drop; - } - if (dq.getHeader()->qr) { // something turned it into a response fixUpQueryTurnedResponse(dq, dq.ids.origFlags); @@ -1329,7 +1345,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size()); - if (!prepareOutgoingResponse(holders, cs, dq, true)) { + if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, true)) { return ProcessQueryResult::Drop; } @@ -1342,7 +1358,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& else if (dq.ids.protocol == dnsdist::Protocol::DoH && !forwardedOverUDP) { /* do a second-lookup for UDP responses, but we do not want TC=1 answers */ if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKeyUDP, dq.ids.subnet, dq.ids.dnssecOK, true, allowExpired, false, false, false)) { - if (!prepareOutgoingResponse(holders, cs, dq, true)) { + if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, true)) { return ProcessQueryResult::Drop; } @@ -1369,7 +1385,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& fixUpQueryTurnedResponse(dq, dq.ids.origFlags); - if (!prepareOutgoingResponse(holders, cs, dq, false)) { + if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, false)) { return ProcessQueryResult::Drop; } ++g_stats.responses; @@ -1394,7 +1410,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& return ProcessQueryResult::PassToBackend; } catch (const std::exception& e){ - vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.overTCP() ? "TCP" : "UDP"), dq.ids.origRemote.toStringWithPort(), queryId, e.what()); + vinfolog("Got an error while parsing a %s query (after applying rules) from %s, id %d: %s", (dq.overTCP() ? "TCP" : "UDP"), dq.ids.origRemote.toStringWithPort(), queryId, e.what()); } return ProcessQueryResult::Drop; } @@ -1402,7 +1418,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& class UDPTCPCrossQuerySender : public TCPQuerySender { public: - UDPTCPCrossQuerySender(const ClientState& cs, const std::shared_ptr& ds): d_cs(cs), d_ds(ds) + UDPTCPCrossQuerySender() { } @@ -1415,14 +1431,9 @@ public: return true; } - const ClientState* getClientState() const override - { - return &d_cs; - } - void handleResponse(const struct timeval& now, TCPResponse&& response) override { - if (!d_ds && !response.d_selfGenerated) { + if (!response.d_ds && !response.d_idstate.selfGenerated) { throw std::runtime_error("Passing a cross-protocol answer originated from UDP without a valid downstream"); } @@ -1431,7 +1442,7 @@ public: static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated); + handleResponseForUDPClient(ids, response.d_buffer, *localRespRuleActions, *localCacheInsertedRespRuleActions, response.d_ds, response.isAsync(), response.d_idstate.selfGenerated); } void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override @@ -1443,23 +1454,23 @@ public: { // nothing to do } -private: - const ClientState& d_cs; - const std::shared_ptr d_ds{nullptr}; }; class UDPCrossProtocolQuery : public CrossProtocolQuery { public: - UDPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr& ds) + UDPCrossProtocolQuery(PacketBuffer&& buffer_, InternalQueryState&& ids_, std::shared_ptr ds): CrossProtocolQuery(InternalQuery(std::move(buffer_), std::move(ids_)), ds) { - uint16_t z = 0; - getEDNSUDPPayloadSizeAndZ(reinterpret_cast(buffer.data()), buffer.size(), &ids.udpPayloadSize, &z); - if (ids.udpPayloadSize < 512) { - ids.udpPayloadSize = 512; + auto& ids = query.d_idstate; + const auto& buffer = query.d_buffer; + + if (ids.udpPayloadSize == 0) { + uint16_t z = 0; + getEDNSUDPPayloadSizeAndZ(reinterpret_cast(buffer.data()), buffer.size(), &ids.udpPayloadSize, &z); + if (ids.udpPayloadSize < 512) { + ids.udpPayloadSize = 512; + } } - query = InternalQuery(std::move(buffer), std::move(ids)); - downstream = ds; } ~UDPCrossProtocolQuery() @@ -1468,11 +1479,48 @@ public: std::shared_ptr getTCPQuerySender() override { - auto sender = std::make_shared(*query.d_idstate.cs, downstream); - return sender; + return s_sender; } +private: + static std::shared_ptr s_sender; }; +std::shared_ptr UDPCrossProtocolQuery::s_sender = std::make_shared(); + +std::unique_ptr getUDPCrossProtocolQueryFromDQ(DNSQuestion& dq); +std::unique_ptr getUDPCrossProtocolQueryFromDQ(DNSQuestion& dq) +{ + dq.ids.origID = dq.getHeader()->id; + return std::make_unique(std::move(dq.getMutableData()), std::move(dq.ids), nullptr); +} + +ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr& selectedBackend) +{ + const uint16_t queryId = ntohs(dq.getHeader()->id); + + try { + /* we need an accurate ("real") value for the response and + to store into the IDS, but not for insertion into the + rings for example */ + struct timespec now; + gettime(&now); + + if (!applyRulesToQuery(holders, dq, now)) { + return ProcessQueryResult::Drop; + } + + if (dq.isAsynchronous()) { + return ProcessQueryResult::Asynchronous; + } + + return processQueryAfterRules(dq, holders, selectedBackend); + } + catch (const std::exception& e){ + vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.overTCP() ? "TCP" : "UDP"), dq.ids.origRemote.toStringWithPort(), queryId, e.what()); + } + return ProcessQueryResult::Drop; +} + bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest) { bool doh = dq.ids.du != nullptr; @@ -1601,9 +1649,9 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } std::shared_ptr ss{nullptr}; - auto result = processQuery(dq, cs, holders, ss); + auto result = processQuery(dq, holders, ss); - if (result == ProcessQueryResult::Drop) { + if (result == ProcessQueryResult::Drop || result == ProcessQueryResult::Asynchronous) { return; } @@ -1622,7 +1670,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */ sendUDPResponse(cs.udpFD, query, dq.ids.delayMsec, dest, remote); - handleResponseSent(ids, 0., remote, ComboAddress(), query.size(), *dh, dnsdist::Protocol::DoUDP); + handleResponseSent(dq.ids.qname, dq.ids.qtype, 0., remote, ComboAddress(), query.size(), *dh, dnsdist::Protocol::DoUDP, dnsdist::Protocol::DoUDP); return; } @@ -2586,6 +2634,8 @@ int main(int argc, char** argv) #endif } + dnsdist::g_asyncHolder = std::make_unique(); + auto todo = setupLua(*(g_lua.lock()), false, false, g_cmdLine.config); auto localPools = g_pools.getCopy(); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index a425018e6e..0741ca266e 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -61,6 +61,10 @@ extern bool g_ECSOverride; using QTag = std::unordered_map; +class IncomingTCPConnectionState; + +struct ClientState; + struct DNSQuestion { DNSQuestion(InternalQueryState& ids_, PacketBuffer& data_): @@ -69,6 +73,7 @@ struct DNSQuestion DNSQuestion(const DNSQuestion&) = delete; DNSQuestion& operator=(const DNSQuestion&) = delete; DNSQuestion(DNSQuestion&&) = default; + virtual ~DNSQuestion() = default; std::string getTrailingData() const; bool setTrailingData(const std::string&); @@ -139,6 +144,21 @@ struct DNSQuestion return ids.queryRealTime.d_start; } + bool isAsynchronous() const + { + return asynchronous; + } + + std::shared_ptr getIncomingTCPState() const + { + return d_incomingTCPState; + } + + ClientState* getFrontend() const + { + return ids.cs; + } + protected: PacketBuffer& data; @@ -147,14 +167,18 @@ public: std::unique_ptr ecs{nullptr}; std::string sni; /* Server Name Indication, if any (DoT or DoH) */ mutable std::unique_ptr ednsOptions; /* this needs to be mutable because it is parsed just in time, when DNSQuestion is read-only */ + std::shared_ptr d_incomingTCPState{nullptr}; std::unique_ptr> proxyProtocolValues{nullptr}; uint16_t ecsPrefixLength; uint8_t ednsRCode{0}; bool ecsOverride; bool useECS{true}; bool addXPF{true}; + bool asynchronous{false}; }; +struct DownstreamState; + struct DNSResponse : DNSQuestion { DNSResponse(InternalQueryState& ids_, PacketBuffer& data_, const std::shared_ptr& downstream): @@ -1183,8 +1207,6 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n void resetLuaSideEffect(); // reset to indeterminate state bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote, unsigned int& qnameWireLength); -bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& insertedRespRuleActions, DNSResponse& dr, bool muted); -bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop); bool checkQueryHeaders(const struct dnsheader* dh, ClientState& cs); @@ -1203,11 +1225,16 @@ extern std::set g_capabilitiesToRetain; static const uint16_t s_udpIncomingBufferSize{1500}; // don't accept UDP queries larger than this value static const size_t s_maxPacketCacheEntrySize{4096}; // don't cache responses larger than this value -enum class ProcessQueryResult : uint8_t { Drop, SendAnswer, PassToBackend }; -ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend); +enum class ProcessQueryResult : uint8_t { Drop, SendAnswer, PassToBackend, Asynchronous }; +ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr& selectedBackend); +ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr& selectedBackend); +bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& insertedRespRuleActions, DNSResponse& dr, bool muted); +bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop); +bool processResponseAfterRules(PacketBuffer& response, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted); bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest); ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const PacketBuffer& request, bool healthCheck = false); +bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote); void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol); void handleResponseSent(const InternalQueryState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol); diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index ef7ac181a6..7310e4050d 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -134,6 +134,7 @@ dnsdist_SOURCES = \ dns.cc dns.hh \ dns_random.hh \ dnscrypt.cc dnscrypt.hh \ + dnsdist-async.cc dnsdist-async.hh \ dnsdist-backend.cc \ dnsdist-cache.cc dnsdist-cache.hh \ dnsdist-carbon.cc dnsdist-carbon.hh \ @@ -147,6 +148,7 @@ dnsdist_SOURCES = \ dnsdist-ecs.cc dnsdist-ecs.hh \ dnsdist-healthchecks.cc dnsdist-healthchecks.hh \ dnsdist-idstate.hh \ + dnsdist-internal-queries.cc dnsdist-internal-queries.hh \ dnsdist-kvs.hh dnsdist-kvs.cc \ dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \ dnsdist-lua-actions.cc \ @@ -242,6 +244,7 @@ testrunner_SOURCES = \ credentials.cc credentials.hh \ dns.cc dns.hh \ dnscrypt.cc dnscrypt.hh \ + dnsdist-async.cc dnsdist-async.hh \ dnsdist-backend.cc \ dnsdist-cache.cc dnsdist-cache.hh \ dnsdist-dnsparser.cc dnsdist-dnsparser.hh \ @@ -304,6 +307,7 @@ testrunner_SOURCES = \ test-dnsdist-connections-cache.cc \ test-dnsdist-dnsparser.cc \ test-dnsdist_cc.cc \ + test-dnsdistasync.cc \ test-dnsdistbackend_cc.cc \ test-dnsdistdynblocks_hh.cc \ test-dnsdistkvs_cc.cc \ diff --git a/pdns/dnsdistdist/dnsdist-async.cc b/pdns/dnsdistdist/dnsdist-async.cc new file mode 100644 index 0000000000..fc9174401c --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-async.cc @@ -0,0 +1,428 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#include "dnsdist-async.hh" +#include "dnsdist-internal-queries.hh" +#include "dolog.hh" +#include "threadname.hh" + +namespace dnsdist +{ + +AsynchronousHolder::AsynchronousHolder(bool failOpen) +{ + d_data = std::make_shared(); + d_data->d_failOpen = failOpen; + + int fds[2] = {-1, -1}; + if (pipe(fds) < 0) { + throw std::runtime_error("Error creating the AsynchronousHolder pipe: " + stringerror()); + } + + for (size_t idx = 0; idx < (sizeof(fds) / sizeof(*fds)); idx++) { + if (!setNonBlocking(fds[idx])) { + int err = errno; + close(fds[0]); + close(fds[1]); + throw std::runtime_error("Error setting the AsynchronousHolder pipe non-blocking: " + stringerror(err)); + } + } + + d_data->d_notifyPipe = FDWrapper(fds[1]); + d_data->d_watchPipe = FDWrapper(fds[0]); + + std::thread main([data = this->d_data] { mainThread(data); }); + main.detach(); +} + +AsynchronousHolder::~AsynchronousHolder() +{ + try { + stop(); + } + catch (...) { + } +} + +bool AsynchronousHolder::notify() const +{ + const char data = 0; + bool failed = false; + do { + auto written = write(d_data->d_notifyPipe.getHandle(), &data, sizeof(data)); + if (written == 0) { + break; + } + if (written > 0 && static_cast(written) == sizeof(data)) { + return true; + } + if (errno != EINTR) { + failed = true; + } + } while (!failed); + + return false; +} + +bool AsynchronousHolder::wait(const AsynchronousHolder::Data& data, FDMultiplexer& mplexer, std::vector& readyFDs, int atMostMs) +{ + readyFDs.clear(); + mplexer.getAvailableFDs(readyFDs, atMostMs); + if (readyFDs.size() == 0) { + /* timeout */ + return true; + } + + while (true) { + /* we might have been notified several times, let's read + as much as possible before returning */ + char dummy = 0; + auto got = read(data.d_watchPipe.getHandle(), &dummy, sizeof(dummy)); + if (got == 0) { + break; + } + if (got > 0 && static_cast(got) != sizeof(dummy)) { + continue; + } + if (got == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + break; + } + } + + return false; +} + +void AsynchronousHolder::stop() +{ + { + auto content = d_data->d_content.lock(); + d_data->d_done = true; + } + + notify(); +} + +void AsynchronousHolder::mainThread(std::shared_ptr data) +{ + setThreadName("dnsdist/async"); + struct timeval now; + std::list>> expiredEvents; + + auto mplexer = std::unique_ptr(FDMultiplexer::getMultiplexerSilent(1)); + mplexer->addReadFD(data->d_watchPipe.getHandle(), [](int, FDMultiplexer::funcparam_t&) {}); + std::vector readyFDs; + + while (true) { + bool shouldWait = true; + int timeout = -1; + { + auto content = data->d_content.lock(); + if (data->d_done) { + return; + } + + if (!content->empty()) { + gettimeofday(&now, nullptr); + struct timeval next = getNextTTD(*content); + if (next <= now) { + pickupExpired(*content, now, expiredEvents); + shouldWait = false; + } + else { + auto remainingUsec = uSec(next - now); + timeout = std::round(remainingUsec / 1000.0); + if (timeout == 0 && remainingUsec > 0) { + /* if we have less than 1 ms, let's wait at least 1 ms */ + timeout = 1; + } + } + } + } + + if (shouldWait) { + auto timedOut = wait(*data, *mplexer, readyFDs, timeout); + if (timedOut) { + auto content = data->d_content.lock(); + gettimeofday(&now, nullptr); + pickupExpired(*content, now, expiredEvents); + } + } + + while (!expiredEvents.empty()) { + auto [queryID, query] = std::move(expiredEvents.front()); + expiredEvents.pop_front(); + if (!data->d_failOpen) { + vinfolog("Asynchronous query %d has expired at %d.%d, notifying the sender", queryID, now.tv_sec, now.tv_usec); + auto sender = query->getTCPQuerySender(); + if (sender) { + sender->notifyIOError(std::move(query->query.d_idstate), now); + } + } + else { + vinfolog("Asynchronous query %d has expired at %d.%d, resuming", queryID, now.tv_sec, now.tv_usec); + resumeQuery(std::move(query)); + } + } + } +} + +void AsynchronousHolder::push(uint16_t asyncID, uint16_t queryID, const struct timeval& ttd, std::unique_ptr&& query) +{ + bool needNotify = false; + { + auto content = d_data->d_content.lock(); + if (!content->empty()) { + /* the thread is already waiting on a TTD expiry in addition to notifications, + let's not wake it unless our TTD comes before the current one */ + const struct timeval next = getNextTTD(*content); + if (ttd < next) { + needNotify = true; + } + } + else { + /* the thread is currently only waiting for a notify */ + needNotify = true; + } + content->insert({std::move(query), ttd, asyncID, queryID}); + } + + if (needNotify) { + notify(); + } +} + +std::unique_ptr AsynchronousHolder::get(uint16_t asyncID, uint16_t queryID) +{ + /* no need to notify, worst case the thread wakes up for nothing because this was the next TTD */ + auto content = d_data->d_content.lock(); + auto it = content->find(std::tie(queryID, asyncID)); + if (it == content->end()) { + struct timeval now; + gettimeofday(&now, nullptr); + vinfolog("Asynchronous object %d not found at %d.%d", queryID, now.tv_sec, now.tv_usec); + return nullptr; + } + + auto result = std::move(it->d_query); + content->erase(it); + return result; +} + +void AsynchronousHolder::pickupExpired(content_t& content, const struct timeval& now, std::list>>& events) +{ + auto& idx = content.get(); + for (auto it = idx.begin(); it != idx.end() && it->d_ttd < now;) { + events.emplace_back(it->d_queryID, std::move(it->d_query)); + it = idx.erase(it); + } +} + +struct timeval AsynchronousHolder::getNextTTD(const content_t& content) +{ + if (content.empty()) { + throw std::runtime_error("AsynchronousHolder::getNextTTD() called on an empty holder"); + } + + return content.get().begin()->d_ttd; +} + +bool AsynchronousHolder::empty() +{ + return d_data->d_content.read_only_lock()->empty(); +} + +static bool resumeResponse(std::unique_ptr&& response) +{ + try { + auto& ids = response->query.d_idstate; + DNSResponse dr = response->getDR(); + + LocalHolders holders; + auto result = processResponseAfterRules(response->query.d_buffer, *holders.cacheInsertedRespRuleActions, dr, ids.cs->muted); + if (!result) { + /* easy */ + return true; + } + + auto sender = response->getTCPQuerySender(); + if (sender) { + struct timeval now; + gettimeofday(&now, nullptr); + + TCPResponse resp(std::move(response->query.d_buffer), std::move(response->query.d_idstate), nullptr, response->downstream); + resp.d_async = true; + sender->handleResponse(now, std::move(resp)); + } + } + catch (const std::exception& e) { + vinfolog("Got exception while resuming cross-protocol response: %s", e.what()); + return false; + } + + return true; +} + +static LockGuarded>> s_asynchronousEventsQueue; + +bool queueQueryResumptionEvent(std::unique_ptr&& query) +{ + s_asynchronousEventsQueue.lock()->push_back(std::move(query)); + return true; +} + +void handleQueuedAsynchronousEvents() +{ + while (true) { + std::unique_ptr query; + { + // we do not want to hold the lock while resuming + auto queue = s_asynchronousEventsQueue.lock(); + if (queue->empty()) { + return; + } + + query = std::move(queue->front()); + queue->pop_front(); + } + if (query && !resumeQuery(std::move(query))) { + vinfolog("Unable to resume asynchronous query event"); + } + } +} + +bool resumeQuery(std::unique_ptr&& query) +{ + if (query->d_isResponse) { + return resumeResponse(std::move(query)); + } + + auto& ids = query->query.d_idstate; + DNSQuestion dq = query->getDQ(); + LocalHolders holders; + + auto result = processQueryAfterRules(dq, holders, query->downstream); + if (result == ProcessQueryResult::Drop) { + /* easy */ + return true; + } + else if (result == ProcessQueryResult::PassToBackend) { + if (query->downstream == nullptr) { + return false; + } + +#ifdef HAVE_DNS_OVER_HTTPS + if (dq.ids.du != nullptr) { + dq.ids.du->downstream = query->downstream; + } +#endif + + if (query->downstream->isTCPOnly() || !(dq.getProtocol().isUDP() || dq.getProtocol() == dnsdist::Protocol::DoH)) { + query->downstream->passCrossProtocolQuery(std::move(query)); + return true; + } + + auto queryID = dq.getHeader()->id; + /* at this point 'du', if it is not nullptr, is owned by the DoHCrossProtocolQuery + which will stop existing when we return, so we need to increment the reference count + */ + return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dq, query->query.d_buffer, ids.origDest); + } + else if (result == ProcessQueryResult::SendAnswer) { + auto sender = query->getTCPQuerySender(); + if (!sender) { + return false; + } + + struct timeval now; + gettimeofday(&now, nullptr); + + TCPResponse response(std::move(query->query.d_buffer), std::move(query->query.d_idstate), nullptr, query->downstream); + response.d_async = true; + response.d_idstate.selfGenerated = true; + + try { + sender->handleResponse(now, std::move(response)); + return true; + } + catch (const std::exception& e) { + vinfolog("Got exception while resuming cross-protocol self-answered query: %s", e.what()); + return false; + } + } + else if (result == ProcessQueryResult::Asynchronous) { + /* nope */ + errlog("processQueryAfterRules returned 'asynchronous' while trying to resume an already asynchronous query"); + return false; + } + + return false; +} + +bool suspendQuery(DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) +{ + if (!g_asyncHolder) { + return false; + } + + struct timeval now; + gettimeofday(&now, nullptr); + struct timeval ttd = now; + ttd.tv_sec += timeoutMs / 1000; + ttd.tv_usec += (timeoutMs % 1000) * 1000; + if (ttd.tv_usec >= 1000000) { + ttd.tv_sec++; + ttd.tv_usec -= 1000000; + } + + vinfolog("Suspending asynchronous query %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec); + auto query = getInternalQueryFromDQ(dq, false); + + g_asyncHolder->push(asyncID, queryID, ttd, std::move(query)); + return true; +} + +bool suspendResponse(DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) +{ + if (!g_asyncHolder) { + return false; + } + + struct timeval now; + gettimeofday(&now, nullptr); + struct timeval ttd = now; + ttd.tv_sec += timeoutMs / 1000; + ttd.tv_usec += (timeoutMs % 1000) * 1000; + if (ttd.tv_usec >= 1000000) { + ttd.tv_sec++; + ttd.tv_usec -= 1000000; + } + + vinfolog("Suspending asynchronous response %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec); + auto query = getInternalQueryFromDQ(dr, true); + query->d_isResponse = true; + query->downstream = dr.d_downstream; + + g_asyncHolder->push(asyncID, queryID, ttd, std::move(query)); + return true; +} + +std::unique_ptr g_asyncHolder; +} diff --git a/pdns/dnsdistdist/dnsdist-async.hh b/pdns/dnsdistdist/dnsdist-async.hh new file mode 100644 index 0000000000..5a8c0908f5 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-async.hh @@ -0,0 +1,98 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include + +#include +#include +#include + +#include "dnsdist-tcp.hh" + +namespace dnsdist +{ +class AsynchronousHolder +{ +public: + AsynchronousHolder(bool failOpen = true); + ~AsynchronousHolder(); + void push(uint16_t asyncID, uint16_t queryID, const struct timeval& ttd, std::unique_ptr&& query); + std::unique_ptr get(uint16_t asyncID, uint16_t queryID); + bool empty(); + void stop(); + +private: + struct TTDTag + { + }; + struct IDTag + { + }; + + struct Entry + { + /* not used by any of the indexes, so mutable */ + mutable std::unique_ptr d_query; + struct timeval d_ttd; + uint16_t d_asyncID; + uint16_t d_queryID; + }; + + typedef multi_index_container< + Entry, + indexed_by< + ordered_unique, + composite_key< + Entry, + member, + member>>, + ordered_non_unique, + member>>> + content_t; + + static void pickupExpired(content_t&, const struct timeval& now, std::list>>& expiredEvents); + static struct timeval getNextTTD(const content_t&); + + struct Data + { + LockGuarded d_content; + FDWrapper d_notifyPipe; + FDWrapper d_watchPipe; + bool d_failOpen{true}; + bool d_done{false}; + }; + std::shared_ptr d_data{nullptr}; + + static void mainThread(std::shared_ptr data); + static bool wait(const Data& data, FDMultiplexer& mplexer, std::vector& readyFDs, int atMostMs); + bool notify() const; +}; + +bool suspendQuery(DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs); +bool suspendResponse(DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs); +bool queueQueryResumptionEvent(std::unique_ptr&& query); +bool resumeQuery(std::unique_ptr&& query); +void handleQueuedAsynchronousEvents(); + +extern std::unique_ptr g_asyncHolder; +} diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.cc b/pdns/dnsdistdist/dnsdist-healthchecks.cc index a6328f168c..480bfd1960 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.cc +++ b/pdns/dnsdistdist/dnsdist-healthchecks.cc @@ -140,11 +140,6 @@ public: return true; } - const ClientState* getClientState() const override - { - return nullptr; - } - void handleResponse(const struct timeval& now, TCPResponse&& response) override { d_data->d_buffer = std::move(response.d_buffer); diff --git a/pdns/dnsdistdist/dnsdist-internal-queries.cc b/pdns/dnsdistdist/dnsdist-internal-queries.cc new file mode 100644 index 0000000000..49f95e42b4 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-internal-queries.cc @@ -0,0 +1,45 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#include "dnsdist-internal-queries.hh" +#include "dnsdist-tcp.hh" +#include "doh.hh" + +std::unique_ptr getUDPCrossProtocolQueryFromDQ(DNSQuestion& dq); + +namespace dnsdist +{ +std::unique_ptr getInternalQueryFromDQ(DNSQuestion& dq, bool isResponse) +{ + auto protocol = dq.getProtocol(); + if (protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP) { + return getUDPCrossProtocolQueryFromDQ(dq); + } +#ifdef HAVE_DNS_OVER_HTTPS + else if (protocol == dnsdist::Protocol::DoH) { + return getDoHCrossProtocolQueryFromDQ(dq, isResponse); + } +#endif + else { + return getTCPCrossProtocolQueryFromDQ(dq); + } +} +} diff --git a/pdns/dnsdistdist/dnsdist-internal-queries.hh b/pdns/dnsdistdist/dnsdist-internal-queries.hh new file mode 100644 index 0000000000..46634aa11a --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-internal-queries.hh @@ -0,0 +1,30 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include +#include "dnsdist.hh" + +namespace dnsdist +{ +std::unique_ptr getInternalQueryFromDQ(DNSQuestion& dq, bool isResponse); +} diff --git a/pdns/dnsdistdist/dnsdist-lua-bindings-network.cc b/pdns/dnsdistdist/dnsdist-lua-bindings-network.cc index 3d2c21c4fb..e66a13986a 100644 --- a/pdns/dnsdistdist/dnsdist-lua-bindings-network.cc +++ b/pdns/dnsdistdist/dnsdist-lua-bindings-network.cc @@ -20,6 +20,7 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ #include "dnsdist.hh" +#include "dnsdist-async.hh" #include "dnsdist-lua.hh" #include "dnsdist-lua-network.hh" #include "dolog.hh" @@ -66,8 +67,11 @@ void setupLuaBindingsNetwork(LuaContext& luaCtx, bool client) } return listener->addUnixListeningEndpoint(path, endpointID, [cb](dnsdist::NetworkListener::EndpointID endpoint, std::string&& dgram, const std::string& from) { - auto lock = g_lua.lock(); - cb(endpoint, dgram, from); + { + auto lock = g_lua.lock(); + cb(endpoint, dgram, from); + } + dnsdist::handleQueuedAsynchronousEvents(); }); }); diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h b/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h index 3a509058a2..741bd3aedc 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h +++ b/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h @@ -67,6 +67,7 @@ void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init) __attribute__ ((visibility ("default"))); uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); +uint16_t dnsdist_ffi_dnsquestion_get_id(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); @@ -85,11 +86,13 @@ uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquest bool dnsdist_ffi_dnsquestion_get_do(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); void dnsdist_ffi_dnsquestion_get_sni(const dnsdist_ffi_dnsquestion_t* dq, const char** sni, size_t* sniSize) __attribute__ ((visibility ("default"))); const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, const char* label) __attribute__ ((visibility ("default"))); +size_t dnsdist_ffi_dnsquestion_get_tag_raw(const dnsdist_ffi_dnsquestion_t* dq, const char* label, char* buffer, size_t bufferSize) __attribute__ ((visibility ("default"))); const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq, void* buffer, size_t bufferSize) __attribute__ ((visibility ("default"))); +uint64_t dnsdist_ffi_dnsquestion_get_elapsed_us(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); // returns the length of the resulting 'out' array. 'out' is not set if the length is 0 size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* ref, const dnsdist_ffi_ednsoption_t** out) __attribute__ ((visibility ("default"))); @@ -106,6 +109,7 @@ void dnsdist_ffi_dnsquestion_set_ecs_prefix_length(dnsdist_ffi_dnsquestion_t* dq void dnsdist_ffi_dnsquestion_set_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t tempFailureTTL) __attribute__ ((visibility ("default"))); void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value) __attribute__ ((visibility ("default"))); +void dnsdist_ffi_dnsquestion_set_tag_raw(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value, size_t valueSize) __attribute__ ((visibility ("default"))); void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, size_t bodyLen, const char* contentType) __attribute__ ((visibility ("default"))); @@ -148,6 +152,14 @@ void dnsdist_ffi_dnsresponse_limit_ttl(dnsdist_ffi_dnsresponse_t* dr, uint32_t m void dnsdist_ffi_dnsresponse_set_max_returned_ttl(dnsdist_ffi_dnsresponse_t* dr, uint32_t max) __attribute__ ((visibility ("default"))); void dnsdist_ffi_dnsresponse_clear_records_type(dnsdist_ffi_dnsresponse_t* dr, uint16_t qtype) __attribute__ ((visibility ("default"))); +bool dnsdist_ffi_dnsquestion_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) __attribute__ ((visibility ("default"))); +bool dnsdist_ffi_dnsresponse_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) __attribute__ ((visibility ("default"))); + +bool dnsdist_ffi_resume_from_async(uint16_t asyncID, uint16_t queryID, const char* tag, size_t tagSize, const char* tagValue, size_t tagValueSize, bool useCache) __attribute__ ((visibility ("default"))); +bool dnsdist_ffi_drop_from_async(uint16_t asyncID, uint16_t queryID) __attribute__ ((visibility ("default"))); +bool dnsdist_ffi_set_answer_from_async(uint16_t asyncID, uint16_t queryID, const char* raw, size_t rawSize) __attribute__ ((visibility ("default"))); +bool dnsdist_ffi_set_rcode_from_async(uint16_t asyncID, uint16_t queryID, uint8_t rcode, bool clearAnswers) __attribute__ ((visibility ("default"))); + typedef struct dnsdist_ffi_proxy_protocol_value { const char* value; uint16_t size; diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index a0f782710a..f204b0aca0 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -20,6 +20,7 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ +#include "dnsdist-async.hh" #include "dnsdist-dnsparser.hh" #include "dnsdist-lua-ffi.hh" #include "dnsdist-mac-address.hh" @@ -39,6 +40,14 @@ uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq) return dq->dq->ids.qclass; } +uint16_t dnsdist_ffi_dnsquestion_get_id(const dnsdist_ffi_dnsquestion_t* dq) +{ + if (dq == nullptr) { + return 0; + } + return ntohs(dq->dq->getHeader()->id); +} + static void dnsdist_ffi_comboaddress_to_raw(const ComboAddress& ca, const void** addr, size_t* addrSize) { if (ca.isIPv4()) { @@ -66,7 +75,6 @@ size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq, if (dq == nullptr) { return 0; } - auto ret = dnsdist::MacAddressesCache::get(dq->dq->ids.origRemote, reinterpret_cast(buffer), bufferSize); if (ret != 0) { return 0; @@ -75,6 +83,15 @@ size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq, return 6; } +uint64_t dnsdist_ffi_dnsquestion_get_elapsed_us(const dnsdist_ffi_dnsquestion_t* dq) +{ + if (dq == nullptr) { + return 0; + } + + return dq->dq->ids.queryRealTime.udiff(); +} + void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits) { dq->maskedRemote = Netmask(dq->dq->ids.origRemote, bits).getMaskedNetwork(); @@ -223,7 +240,7 @@ const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, { const char * result = nullptr; - if (dq->dq->ids.qTag != nullptr) { + if (dq != nullptr && dq->dq != nullptr && dq->dq->ids.qTag != nullptr) { const auto it = dq->dq->ids.qTag->find(label); if (it != dq->dq->ids.qTag->cend()) { result = it->second.c_str(); @@ -233,6 +250,25 @@ const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, return result; } +size_t dnsdist_ffi_dnsquestion_get_tag_raw(const dnsdist_ffi_dnsquestion_t* dq, const char* label, char* buffer, size_t bufferSize) +{ + if (dq == nullptr || dq->dq == nullptr || dq->dq->ids.qTag == nullptr || label == nullptr || buffer == nullptr || bufferSize == 0) { + return 0; + } + + const auto it = dq->dq->ids.qTag->find(label); + if (it == dq->dq->ids.qTag->cend()) { + return 0; + } + + if (it->second.size() > bufferSize) { + return 0; + } + + memcpy(buffer, it->second.c_str(), it->second.size()); + return it->second.size(); +} + const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpPath) { @@ -380,7 +416,7 @@ size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* dq, c size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_tag_t** out) { - if (dq->dq->ids.qTag == nullptr || dq->dq->ids.qTag->size() == 0) { + if (dq == nullptr || dq->dq == nullptr || dq->dq->ids.qTag == nullptr || dq->dq->ids.qTag->size() == 0) { return 0; } @@ -470,6 +506,11 @@ void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* dq->dq->setTag(label, value); } +void dnsdist_ffi_dnsquestion_set_tag_raw(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value, size_t valueSize) +{ + dq->dq->setTag(label, std::string(value, valueSize)); +} + size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out) { dq->trailingData = dq->dq->getTrailingData(); @@ -649,6 +690,168 @@ void dnsdist_ffi_dnsresponse_clear_records_type(dnsdist_ffi_dnsresponse_t* dr, u } } +bool dnsdist_ffi_dnsquestion_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) +{ + try { + dq->dq->asynchronous = true; + dnsdist::suspendQuery(*dq->dq, asyncID, queryID, timeoutMs); + return true; + } + catch (const std::exception& e) { + vinfolog("Error in dnsdist_ffi_dnsquestion_set_async: %s", e.what()); + } + catch (...) { + vinfolog("Exception in dnsdist_ffi_dnsquestion_set_async"); + } + + return false; +} + +bool dnsdist_ffi_dnsresponse_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) +{ + try { + dq->dq->asynchronous = true; + auto dr = dynamic_cast(dq->dq); + if (!dr) { + vinfolog("Passed a DNSQuestion instead of a DNSResponse to dnsdist_ffi_dnsresponse_set_async"); + return false; + } + + dnsdist::suspendResponse(*dr, asyncID, queryID, timeoutMs); + return true; + } + catch (const std::exception& e) { + vinfolog("Error in dnsdist_ffi_dnsresponse_set_async: %s", e.what()); + } + catch (...) { + vinfolog("Exception in dnsdist_ffi_dnsresponse_set_async"); + } + return false; +} + +bool dnsdist_ffi_resume_from_async(uint16_t asyncID, uint16_t queryID, const char* tag, size_t tagSize, const char* tagValue, size_t tagValueSize, bool useCache) +{ + if (!dnsdist::g_asyncHolder) { + vinfolog("Unable to resume, no asynchronous holder"); + return false; + } + + auto query = dnsdist::g_asyncHolder->get(asyncID, queryID); + if (!query) { + vinfolog("Unable to resume, no object found for asynchronous ID %d and query ID %d", asyncID, queryID); + return false; + } + + auto& ids = query->query.d_idstate; + if (tag != nullptr && tagSize > 0) { + if (!ids.qTag) { + ids.qTag = std::make_unique(); + } + (*ids.qTag)[std::string(tag, tagSize)] = std::string(tagValue, tagValueSize); + } + + ids.skipCache = !useCache; + + return dnsdist::queueQueryResumptionEvent(std::move(query)); +} + +bool dnsdist_ffi_set_rcode_from_async(uint16_t asyncID, uint16_t queryID, uint8_t rcode, bool clearAnswers) +{ + if (!dnsdist::g_asyncHolder) { + return false; + } + + auto query = dnsdist::g_asyncHolder->get(asyncID, queryID); + if (!query) { + vinfolog("Unable to resume with a custom response code, no object found for asynchronous ID %d and query ID %d", asyncID, queryID); + return false; + } + + const auto qnameLength = query->query.d_idstate.qname.wirelength(); + auto& buffer = query->query.d_buffer; + if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) { + return false; + } + + EDNS0Record edns0; + bool hadEDNS = false; + if (clearAnswers) { + hadEDNS = getEDNS0Record(buffer, edns0); + } + + auto dh = reinterpret_cast(buffer.data()); + dh->rcode = rcode; + dh->ad = false; + dh->aa = false; + dh->ra = dh->rd; + dh->qr = true; + + if (clearAnswers) { + dh->ancount = 0; + dh->nscount = 0; + dh->arcount = 0; + buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)); + if (hadEDNS) { + if (!addEDNS(buffer, query->query.d_idstate.protocol.isUDP() ? 4096 : std::numeric_limits::max(), edns0.extFlags & htons(EDNS_HEADER_FLAG_DO), g_PayloadSizeSelfGenAnswers, 0)) { + return false; + } + } + } + + query->query.d_idstate.skipCache = true; + + return dnsdist::queueQueryResumptionEvent(std::move(query)); +} + +bool dnsdist_ffi_drop_from_async(uint16_t asyncID, uint16_t queryID) +{ + if (!dnsdist::g_asyncHolder) { + return false; + } + + auto query = dnsdist::g_asyncHolder->get(asyncID, queryID); + if (!query) { + vinfolog("Unable to drop, no object found for asynchronous ID %d and query ID %d", asyncID, queryID); + return false; + } + + auto sender = query->getTCPQuerySender(); + if (!sender) { + return false; + } + + struct timeval now; + gettimeofday(&now, nullptr); + sender->notifyIOError(std::move(query->query.d_idstate), now); + + return true; +} + +bool dnsdist_ffi_set_answer_from_async(uint16_t asyncID, uint16_t queryID, const char* raw, size_t rawSize) +{ + if (rawSize < sizeof(dnsheader)) { + return false; + } + if (!dnsdist::g_asyncHolder) { + return false; + } + + auto query = dnsdist::g_asyncHolder->get(asyncID, queryID); + if (!query) { + vinfolog("Unable to resume with a custom answer, no object found for asynchronous ID %d and query ID %d", asyncID, queryID); + return false; + } + + auto oldId = reinterpret_cast(query->query.d_buffer.data())->id; + query->query.d_buffer.clear(); + query->query.d_buffer.insert(query->query.d_buffer.begin(), raw, raw + rawSize); + reinterpret_cast(query->query.d_buffer.data())->id = oldId; + + query->query.d_idstate.skipCache = true; + + return dnsdist::queueQueryResumptionEvent(std::move(query)); +} + static constexpr char s_lua_ffi_code[] = R"FFICodeContent( local ffi = require("ffi") local C = ffi.C @@ -797,6 +1000,10 @@ size_t dnsdist_ffi_packetcache_get_domain_list_by_addr(const char* poolName, con vinfolog("Error parsing address passed to dnsdist_ffi_packetcache_get_domain_list_by_addr: %s", e.what()); return 0; } + catch (const PDNSException& e) { + vinfolog("Error parsing address passed to dnsdist_ffi_packetcache_get_domain_list_by_addr: %s", e.reason); + return 0; + } const auto localPools = g_pools.getCopy(); auto it = localPools.find(poolName); @@ -1037,6 +1244,10 @@ size_t dnsdist_ffi_ring_get_entries_by_addr(const char* addr, dnsdist_ffi_ring_e vinfolog("Unable to convert address in dnsdist_ffi_ring_get_entries_by_addr: %s", e.what()); return 0; } + catch (const PDNSException& e) { + vinfolog("Unable to convert address in dnsdist_ffi_ring_get_entries_by_addr: %s", e.reason); + return 0; + } auto list = std::make_unique(); diff --git a/pdns/dnsdistdist/dnsdist-lua-network.cc b/pdns/dnsdistdist/dnsdist-lua-network.cc index 819137a400..6892888008 100644 --- a/pdns/dnsdistdist/dnsdist-lua-network.cc +++ b/pdns/dnsdistdist/dnsdist-lua-network.cc @@ -24,11 +24,12 @@ #include "dnsdist-lua-network.hh" #include "dolog.hh" +#include "threadname.hh" namespace dnsdist { NetworkListener::NetworkListener() : - d_mplexer(std::unique_ptr(FDMultiplexer::getMultiplexerSilent())) + d_mplexer(std::unique_ptr(FDMultiplexer::getMultiplexerSilent(10))) { } @@ -131,6 +132,7 @@ void NetworkListener::runOnce(struct timeval& now, uint32_t timeout) void NetworkListener::mainThread() { + setThreadName("dnsdist/lua-net"); struct timeval now; while (true) { diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 56884d0f31..5c745eac1c 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -148,7 +148,7 @@ void DoHConnectionToBackend::handleResponse(PendingRequest&& request) } } - request.d_sender->handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this())); + request.d_sender->handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this(), d_ds)); } catch (const std::exception& e) { vinfolog("Got exception while handling response for cross-protocol DoH: %s", e.what()); diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index 9da42acfb7..a6ab7002f2 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -38,7 +38,7 @@ ConnectionToBackend::~ConnectionToBackend() bool ConnectionToBackend::reconnect() { std::unique_ptr tlsSession{nullptr}; - if (d_handler) { + if (d_handler) { DEBUGLOG("closing socket "<getDescriptor()); if (d_handler->isTLS()) { if (d_handler->hasTLSSessionBeenResumed()) { @@ -73,18 +73,18 @@ bool ConnectionToBackend::reconnect() DEBUGLOG("Opening TCP connection to backend "<getNameWithAddr()); ++d_ds->tcpNewConnections; try { - auto socket = std::make_unique(d_ds->d_config.remote.sin4.sin_family, SOCK_STREAM, 0); - DEBUGLOG("result of socket() is "<getHandle()); + auto socket = Socket(d_ds->d_config.remote.sin4.sin_family, SOCK_STREAM, 0); + DEBUGLOG("result of socket() is "<getHandle()); + setTCPNoDelay(socket.getHandle()); #ifdef SO_BINDTODEVICE if (!d_ds->d_config.sourceItfName.empty()) { - int res = setsockopt(socket->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->d_config.sourceItfName.c_str(), d_ds->d_config.sourceItfName.length()); + int res = setsockopt(socket.getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->d_config.sourceItfName.c_str(), d_ds->d_config.sourceItfName.length()); if (res != 0) { vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror()); } @@ -92,19 +92,18 @@ bool ConnectionToBackend::reconnect() #endif if (!IsAnyAddress(d_ds->d_config.sourceAddr)) { - SSetsockopt(socket->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1); + SSetsockopt(socket.getHandle(), SOL_SOCKET, SO_REUSEADDR, 1); #ifdef IP_BIND_ADDRESS_NO_PORT if (d_ds->d_config.ipBindAddrNoPort) { - SSetsockopt(socket->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); + SSetsockopt(socket.getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); } #endif - socket->bind(d_ds->d_config.sourceAddr, false); + socket.bind(d_ds->d_config.sourceAddr, false); } - - socket->setNonBlocking(); + socket.setNonBlocking(); gettimeofday(&d_connectionStartTime, nullptr); - auto handler = std::make_unique(d_ds->d_config.d_tlsSubjectName, d_ds->d_config.d_tlsSubjectIsAddr, socket->releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec); + auto handler = std::make_unique(d_ds->d_config.d_tlsSubjectName, d_ds->d_config.d_tlsSubjectIsAddr, socket.releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec); if (!tlsSession && d_ds->d_tlsCtx) { tlsSession = g_sessionCache.getSession(d_ds->getID(), d_connectionStartTime.tv_sec); } @@ -591,15 +590,13 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F auto pendingResponses = std::move(d_pendingResponses); d_pendingResponses.clear(); - auto increaseCounters = [reason](std::shared_ptr& sender) { + auto increaseCounters = [reason](const ClientState* cs) { if (reason == FailureReason::timeout) { - const ClientState* cs = sender->getClientState(); if (cs) { ++cs->tcpDownstreamTimeouts; } } else if (reason == FailureReason::gaveUp) { - const ClientState* cs = sender->getClientState(); if (cs) { ++cs->tcpGaveUp; } @@ -608,25 +605,25 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F try { if (d_state == State::sendingQueryToBackend) { + increaseCounters(d_currentQuery.d_query.d_idstate.cs); auto sender = d_currentQuery.d_sender; if (sender->active()) { - increaseCounters(sender); sender->notifyIOError(std::move(d_currentQuery.d_query.d_idstate), now); } } for (auto& query : pendingQueries) { + increaseCounters(query.d_query.d_idstate.cs); auto sender = query.d_sender; if (sender->active()) { - increaseCounters(sender); sender->notifyIOError(std::move(query.d_query.d_idstate), now); } } for (auto& response : pendingResponses) { + increaseCounters(response.second.d_query.d_idstate.cs); auto sender = response.second.d_sender; if (sender->active()) { - increaseCounters(sender); sender->notifyIOError(std::move(response.second.d_query.d_idstate), now); } } @@ -672,6 +669,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrd_ds; /* we don't move the whole IDS because we will need for the responses to come */ response.d_idstate.qtype = it->second.d_query.d_idstate.qtype; response.d_idstate.qname = it->second.d_query.d_idstate.qname; @@ -728,7 +726,7 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptractive()) { DEBUGLOG("passing response to client connection for "<handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn)); + sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn, conn->d_ds)); } if (!d_pendingQueries.empty()) { diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index e48d52317d..59c4df410d 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -139,11 +139,6 @@ static void handleTimeout(std::shared_ptr& state, bo return d_ioState != nullptr; } - const ClientState* getClientState() const override - { - return d_ci.cs; - } - std::string toString() const { ostringstream o; diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index 43d615dea2..3d11f1a4f4 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -72,8 +72,9 @@ struct ConnectionInfo int fd{-1}; }; -struct InternalQuery +class InternalQuery { +public: InternalQuery() { } @@ -119,15 +120,26 @@ struct TCPResponse : public TCPQuery memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); } - TCPResponse(PacketBuffer&& buffer, InternalQueryState&& state, std::shared_ptr conn) : - TCPQuery(std::move(buffer), std::move(state)), d_connection(conn) + TCPResponse(PacketBuffer&& buffer, InternalQueryState&& state, std::shared_ptr conn, std::shared_ptr ds) : + TCPQuery(std::move(buffer), std::move(state)), d_connection(conn), d_ds(ds) { - memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); + if (d_buffer.size() >= sizeof(dnsheader)) { + memcpy(&d_cleartextDH, reinterpret_cast(d_buffer.data()), sizeof(d_cleartextDH)); + } + else { + memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); + } + } + + bool isAsync() const + { + return d_async; } std::shared_ptr d_connection{nullptr}; + std::shared_ptr d_ds{nullptr}; dnsheader d_cleartextDH; - bool d_selfGenerated{false}; + bool d_async{false}; }; class TCPQuerySender @@ -138,7 +150,6 @@ public: } virtual bool active() const = 0; - virtual const ClientState* getClientState() const = 0; virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0; virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0; virtual void notifyIOError(InternalQueryState&& query, const struct timeval& now) = 0; @@ -170,11 +181,24 @@ struct CrossProtocolQuery } virtual std::shared_ptr getTCPQuerySender() = 0; + virtual DNSQuestion getDQ() + { + auto& ids = query.d_idstate; + DNSQuestion dq(ids, query.d_buffer); + return dq; + } + + virtual DNSResponse getDR() + { + auto& ids = query.d_idstate; + DNSResponse dr(ids, query.d_buffer, downstream); + return dr; + } InternalQuery query; std::shared_ptr downstream{nullptr}; size_t proxyProtocolPayloadSize{0}; - bool isXFR{false}; + bool d_isResponse{false}; }; class TCPClientCollection @@ -278,3 +302,5 @@ private: }; extern std::unique_ptr g_tcpclientthreads; + +std::unique_ptr getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq); diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index cc4c549bf1..d505e1a0c1 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -433,7 +433,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo class DoHTCPCrossQuerySender : public TCPQuerySender { public: - DoHTCPCrossQuerySender(const ClientState& cs): d_cs(cs) + DoHTCPCrossQuerySender() { } @@ -442,11 +442,6 @@ public: return true; } - const ClientState* getClientState() const override - { - return &d_cs; - } - void handleResponse(const struct timeval& now, TCPResponse&& response) override { if (!response.d_idstate.du) { @@ -462,32 +457,40 @@ public: du->ids = std::move(response.d_idstate); DNSResponse dr(du->ids, du->response, du->downstream); - static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); - static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - dnsheader cleartextDH; memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); - dr.ids.du = std::move(du); + if (!response.isAsync()) { + static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); + static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { - if (dr.ids.du) { - dr.ids.du->status_code = 503; - sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); + dr.ids.du = std::move(du); + + if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { + if (dr.ids.du) { + dr.ids.du->status_code = 503; + sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); + } + return; } - return; - } - du = std::move(dr.ids.du); + if (dr.isAsynchronous()) { + return; + } - double udiff = du->ids.queryRealTime.udiff(); - vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); + du = std::move(dr.ids.du); + } + + if (!du->ids.selfGenerated) { + double udiff = du->ids.queryRealTime.udiff(); + vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); - auto backendProtocol = du->downstream->getProtocol(); - if (backendProtocol == dnsdist::Protocol::DoUDP && du->tcp) { - backendProtocol = dnsdist::Protocol::DoTCP; + auto backendProtocol = du->downstream->getProtocol(); + if (backendProtocol == dnsdist::Protocol::DoUDP && du->tcp) { + backendProtocol = dnsdist::Protocol::DoTCP; + } + handleResponseSent(du->ids, udiff, du->ids.origRemote, du->downstream->d_config.remote, du->response.size(), cleartextDH, backendProtocol); } - handleResponseSent(du->ids, udiff, du->ids.origRemote, du->downstream->d_config.remote, du->response.size(), cleartextDH, backendProtocol); ++g_stats.responses; if (du->ids.cs) { @@ -517,16 +520,23 @@ public: du->status_code = 502; sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response"); } -protected: - const ClientState& d_cs; }; class DoHCrossProtocolQuery : public CrossProtocolQuery { public: - DoHCrossProtocolQuery(DOHUnitUniquePtr&& du) + DoHCrossProtocolQuery(DOHUnitUniquePtr&& du, bool isResponse) { - query = InternalQuery(std::move(du->query), std::move(du->ids)); + if (isResponse) { + /* happens when a response becomes async */ + query = InternalQuery(std::move(du->response), std::move(du->ids)); + } + else { + /* we need to duplicate the query here because we might need + the existing query later if we get a truncated answer */ + query = InternalQuery(PacketBuffer(du->query), std::move(du->ids)); + } + /* it might have been moved when we moved du->ids */ if (du) { query.d_idstate.du = std::move(du); @@ -551,16 +561,61 @@ public: std::shared_ptr getTCPQuerySender() override { query.d_idstate.du->downstream = downstream; - auto sender = std::make_shared(*query.d_idstate.cs); - return sender; + return s_sender; + } + + DNSQuestion getDQ() override + { + auto& ids = query.d_idstate; + DNSQuestion dq(ids, query.d_buffer); + return dq; } + DNSResponse getDR() override + { + auto& ids = query.d_idstate; + DNSResponse dr(ids, query.d_buffer, downstream); + return dr; + } + DOHUnitUniquePtr&& releaseDU() { return std::move(query.d_idstate.du); } + +private: + static std::shared_ptr s_sender; }; +std::shared_ptr DoHCrossProtocolQuery::s_sender = std::make_shared(); + +std::unique_ptr getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse) +{ + if (!dq.ids.du) { + throw std::runtime_error("Trying to create a DoH cross protocol query without a valid DoH unit"); + } + + auto du = std::move(dq.ids.du); + if (&dq.ids != &du->ids) { + du->ids = std::move(dq.ids); + } + + du->ids.origID = dq.getHeader()->id; + + if (!isResponse) { + if (du->query.data() != dq.getMutableData().data()) { + du->query = std::move(dq.getMutableData()); + } + } + else { + if (du->response.data() != dq.getMutableData().data()) { + du->response = std::move(dq.getMutableData()); + } + } + + return std::make_unique(std::move(du), isResponse); +} + /* We are not in the main DoH thread but in the DoH 'client' thread. */ @@ -650,6 +705,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) queryId = ntohs(dh->id); } + auto downstream = du->downstream; du->ids.qname = DNSName(reinterpret_cast(du->query.data()), du->query.size(), sizeof(dnsheader), false, &du->ids.qtype, &du->ids.qclass); DNSQuestion dq(du->ids, du->query); const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); @@ -657,22 +713,24 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) du->ids.cs = &cs; dq.sni = std::move(du->sni); - auto result = processQuery(dq, cs, holders, du->downstream); + auto result = processQuery(dq, holders, downstream); if (result == ProcessQueryResult::Drop) { du->status_code = 403; handleImmediateResponse(std::move(du), "DoH dropped query"); return; } - - if (result == ProcessQueryResult::SendAnswer) { + else if (result == ProcessQueryResult::Asynchronous) { + return; + } + else if (result == ProcessQueryResult::SendAnswer) { if (du->response.empty()) { du->response = std::move(du->query); } if (du->response.size() >= sizeof(dnsheader) && du->contentType.empty()) { auto dh = reinterpret_cast(du->response.data()); - handleResponseSent(ids.qname, QType(ids.qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH); + handleResponseSent(du->ids.qname, QType(du->ids.qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH); } handleImmediateResponse(std::move(du), "DoH self-answered response"); return; @@ -684,7 +742,6 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) return; } - auto downstream = du->downstream; if (downstream == nullptr) { du->status_code = 502; handleImmediateResponse(std::move(du), "DoH no backend available"); @@ -705,7 +762,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) du->tcp = true; /* this moves du->ids, careful! */ - auto cpq = std::make_unique(std::move(du)); + auto cpq = std::make_unique(std::move(du), false); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); if (downstream->passCrossProtocolQuery(std::move(cpq))) { @@ -1302,7 +1359,7 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) du->truncated = false; du->response.clear(); - auto cpq = std::make_unique(std::move(du)); + auto cpq = std::make_unique(std::move(du), false); if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) { continue; @@ -1636,14 +1693,14 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, const dnsheader* dh = reinterpret_cast(du->response.data()); if (!dh->tc) { static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); - static thread_local LocalStateHolder> localcacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); + static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); DNSResponse dr(du->ids, du->response, du->downstream); dnsheader cleartextDH; memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); dr.ids.du = std::move(du); - if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localcacheInsertedRespRuleActions, dr, false)) { + if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { if (dr.ids.du) { dr.ids.du->status_code = 503; sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); @@ -1651,6 +1708,10 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, return; } + if (dr.isAsynchronous()) { + return; + } + du = std::move(dr.ids.du); double udiff = du->ids.queryRealTime.udiff(); vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); diff --git a/pdns/dnsdistdist/test-dnsdistasync.cc b/pdns/dnsdistdist/test-dnsdistasync.cc new file mode 100644 index 0000000000..7cf9df5f1e --- /dev/null +++ b/pdns/dnsdistdist/test-dnsdistasync.cc @@ -0,0 +1,165 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#define BOOST_TEST_DYN_LINK +#define BOOST_TEST_NO_MAIN + +#include + +#include "dnsdist-async.hh" + +BOOST_AUTO_TEST_SUITE(test_dnsdistasync) + +class DummyQuerySender : public TCPQuerySender +{ +public: + bool active() const override + { + return true; + } + + void handleResponse(const struct timeval&, TCPResponse&&) override + { + } + + void handleXFRResponse(const struct timeval&, TCPResponse&&) override + { + } + + void notifyIOError(InternalQueryState&&, const struct timeval&) override + { + errorRaised = true; + } + + bool errorRaised{false}; +}; + +struct DummyCrossProtocolQuery : public CrossProtocolQuery +{ + DummyCrossProtocolQuery() : + CrossProtocolQuery() + { + d_sender = std::make_shared(); + } + + std::shared_ptr getTCPQuerySender() override + { + return d_sender; + } + + std::shared_ptr d_sender; +}; + +BOOST_AUTO_TEST_CASE(test_Basic) +{ + auto holder = std::make_unique(); + BOOST_CHECK(holder->empty()); + + { + auto query = holder->get(0, 0); + BOOST_CHECK(query == nullptr); + } + + { + uint16_t asyncID = 1; + uint16_t queryID = 42; + struct timeval ttd; + gettimeofday(&ttd, nullptr); + // timeout in 100 ms + ttd.tv_usec += 100000; + + holder->push(asyncID, queryID, ttd, std::make_unique()); + BOOST_CHECK(!holder->empty()); + + auto query = holder->get(0, 0); + BOOST_CHECK(query == nullptr); + + query = holder->get(asyncID, queryID); + BOOST_CHECK(holder->empty()); + + query = holder->get(asyncID, queryID); + BOOST_CHECK(query == nullptr); + + // sleep for 200 ms, to be sure the main thread has + // been awakened + usleep(200000); + } + + holder->stop(); +} + +BOOST_AUTO_TEST_CASE(test_TimeoutFailClose) +{ + auto holder = std::make_unique(false); + uint16_t asyncID = 1; + uint16_t queryID = 42; + struct timeval ttd; + gettimeofday(&ttd, nullptr); + // timeout in 10 ms + ttd.tv_usec += 10000; + + std::shared_ptr sender{nullptr}; + { + auto query = std::make_unique(); + sender = query->d_sender; + BOOST_REQUIRE(sender != nullptr); + holder->push(asyncID, queryID, ttd, std::move(query)); + BOOST_CHECK(!holder->empty()); + } + + // sleep for 20 ms, to be sure + usleep(20000); + + BOOST_CHECK(holder->empty()); + BOOST_CHECK(sender->errorRaised); + + holder->stop(); +} + +BOOST_AUTO_TEST_CASE(test_AddingExpiredEvent) +{ + auto holder = std::make_unique(false); + uint16_t asyncID = 1; + uint16_t queryID = 42; + struct timeval ttd; + gettimeofday(&ttd, nullptr); + // timeout was 10 ms ago, for some reason (long processing time, CPU starvation...) + ttd.tv_usec -= 10000; + + std::shared_ptr sender{nullptr}; + { + auto query = std::make_unique(); + sender = query->d_sender; + BOOST_REQUIRE(sender != nullptr); + holder->push(asyncID, queryID, ttd, std::move(query)); + BOOST_CHECK(!holder->empty()); + } + + // sleep for 20 ms + usleep(20000); + + BOOST_CHECK(holder->empty()); + BOOST_CHECK(sender->errorRaised); + + holder->stop(); +} + +BOOST_AUTO_TEST_SUITE_END(); diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc index f9abd334f8..41d9992cda 100644 --- a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc @@ -599,11 +599,6 @@ public: return true; } - const ClientState* getClientState() const override - { - return nullptr; - } - void handleResponse(const struct timeval& now, TCPResponse&& response) override { if (d_customHandler) { diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index e427bf7717..7da8b3153b 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -63,12 +63,12 @@ void handleResponseSent(const InternalQueryState& ids, double udiff, const Combo { } -static std::function& selectedBackend)> s_processQuery; +static std::function& selectedBackend)> s_processQuery; -ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) +ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr& selectedBackend) { if (s_processQuery) { - return s_processQuery(dq, cs, holders, selectedBackend); + return s_processQuery(dq, selectedBackend); } return ProcessQueryResult::Drop; @@ -496,7 +496,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, query.size() - 2 }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { return ProcessQueryResult::Drop; }; @@ -518,7 +518,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { // Would be nicer to actually turn it into a response return ProcessQueryResult::SendAnswer; }; @@ -550,7 +550,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { // Would be nicer to actually turn it into a response return ProcessQueryResult::SendAnswer; }; @@ -578,7 +578,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, query.size() - 2 }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { throw std::runtime_error("Something unexpected happened"); }; @@ -605,7 +605,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) s_steps.push_back({ ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }); s_steps.push_back({ ExpectedStep::ExpectedRequest::closeClient, IOState::Done }); - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { // Would be nicer to actually turn it into a response return ProcessQueryResult::SendAnswer; }; @@ -627,7 +627,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, query.size() - 2 - 2 }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { /* should not be reached */ BOOST_CHECK(false); return ProcessQueryResult::SendAnswer; @@ -665,7 +665,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) { ExpectedStep::ExpectedRequest::writeToClient, IOState::NeedWrite, 1 }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { return ProcessQueryResult::SendAnswer; }; @@ -701,7 +701,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 0 }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { return ProcessQueryResult::SendAnswer; }; @@ -759,7 +759,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered) { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { return ProcessQueryResult::SendAnswer; }; @@ -789,7 +789,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered) { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, s_proxyProtocolMinimumHeaderSize }, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { return ProcessQueryResult::SendAnswer; }; @@ -816,7 +816,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered) { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, proxyPayload.size() - s_proxyProtocolMinimumHeaderSize - 1}, { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { return ProcessQueryResult::SendAnswer; }; @@ -895,7 +895,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) /* closing a connection to the backend */ { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -935,7 +935,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) /* closing a connection to the backend */ { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -974,7 +974,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) /* closing a connection to the backend */ { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1017,7 +1017,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) /* closing a connection to the backend */ { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1045,7 +1045,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) /* closing client connection */ { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { return ProcessQueryResult::SendAnswer; }; s_processResponse = [](PacketBuffer& response, DNSResponse& dr, bool muted) -> bool { @@ -1082,7 +1082,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) /* closing backend connection */ { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1150,7 +1150,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1212,7 +1212,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; @@ -1248,7 +1248,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; @@ -1295,7 +1295,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1352,7 +1352,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1408,7 +1408,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1467,7 +1467,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1519,7 +1519,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1579,7 +1579,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1620,7 +1620,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1682,7 +1682,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) /* close the connection with the client */ s_steps.push_back({ ExpectedStep::ExpectedRequest::closeClient, IOState::Done }); - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1702,6 +1702,43 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) #endif } + { + /* 2 queries on the same connection, asynchronously handled, check that we only read the first one (no OOOR as maxInFlight is 0) */ + TEST_INIT("=> 2 queries on the same connection, async"); + + size_t count = 2; + + s_readBuffer = query; + + for (size_t idx = 0; idx < count; idx++) { + appendPayloadEditingID(s_readBuffer, query, idx); + appendPayloadEditingID(s_backendReadBuffer, query, idx); + } + + s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 2 }, + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, query.size() - 2 }, + /* close the connection with the client */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done } + }; + + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + selectedBackend = backend; + dq.asynchronous = true; + /* note that we do nothing with the query, we just tell the frontend it was dealt with */ + return ProcessQueryResult::Asynchronous; + }; + s_processResponse = [](PacketBuffer& response, DNSResponse& dr, bool muted) -> bool { + return true; + }; + + auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); + IncomingTCPConnectionState::handleIO(state, now); + BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); + + /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */ + IncomingTCPConnectionState::clearAllDownstreamConnections(); + } } BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) @@ -1871,7 +1908,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -1994,7 +2031,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend,&responses](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend,&responses](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { static size_t count = 0; if (count++ == 3) { /* self answered */ @@ -2183,7 +2220,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -2255,7 +2292,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; counter = 0; - s_processQuery = [backend,&counter](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend,&counter](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { if (counter == 0) { ++counter; selectedBackend = backend; @@ -2338,7 +2375,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; counter = 0; - s_processQuery = [backend,&counter](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend,&counter](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { if (counter == 0) { ++counter; selectedBackend = backend; @@ -2459,7 +2496,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -2611,7 +2648,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -2818,7 +2855,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -2992,7 +3029,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = proxyEnabledBackend; return ProcessQueryResult::PassToBackend; }; @@ -3256,7 +3293,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -3382,7 +3419,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 }, }; - s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = proxyEnabledBackend; return ProcessQueryResult::PassToBackend; }; @@ -3467,7 +3504,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 }, }; - s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = proxyEnabledBackend; return ProcessQueryResult::PassToBackend; }; @@ -3532,7 +3569,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -3723,7 +3760,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend1](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend1](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend1; return ProcessQueryResult::PassToBackend; }; @@ -3808,7 +3845,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) { ExpectedStep::ExpectedRequest::closeClient, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -4040,7 +4077,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR) { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done }, }; - s_processQuery = [backend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { selectedBackend = backend; return ProcessQueryResult::PassToBackend; }; @@ -4063,6 +4100,65 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR) /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */ BOOST_CHECK_EQUAL(IncomingTCPConnectionState::clearAllDownstreamConnections(), 5U); } + + { + /* 2 queries on the same connection, asynchronously handled, check that we only read all of them (OOOR as maxInFlight is 65535) */ + TEST_INIT("=> 2 queries on the same connection, async with OOOR"); + + size_t count = 2; + + s_readBuffer = queries.at(0); + + for (size_t idx = 0; idx < count; idx++) { + appendPayloadEditingID(s_readBuffer, queries.at(idx), idx); + appendPayloadEditingID(s_backendReadBuffer, queries.at(idx), idx); + } + + bool timeout = false; + s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 2 }, + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, queries.at(0).size() - 2 }, + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 2 }, + { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, queries.at(1).size() - 2 }, + { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0, [&timeout](int desc) { + timeout = true; + }}, + /* close the connection with the client */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done } + }; + + s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr& selectedBackend) -> ProcessQueryResult { + selectedBackend = backend; + dq.asynchronous = true; + /* note that we do nothing with the query, we just tell the frontend it was dealt with */ + return ProcessQueryResult::Asynchronous; + }; + s_processResponse = [](PacketBuffer& response, DNSResponse& dr, bool muted) -> bool { + return true; + }; + + auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); + IncomingTCPConnectionState::handleIO(state, now); + while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { + threadData.mplexer->run(&now); + } + + struct timeval later = now; + later.tv_sec += g_tcpRecvTimeout + 1; + auto expiredConns = threadData.mplexer->getTimeouts(later); + BOOST_CHECK_EQUAL(expiredConns.size(), 1U); + for (const auto& cbData : expiredConns) { + if (cbData.second.type() == typeid(std::shared_ptr)) { + auto cbState = boost::any_cast>(cbData.second); + cbState->handleTimeout(cbState, false); + } + } + + BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); + + /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */ + IncomingTCPConnectionState::clearAllDownstreamConnections(); + } } BOOST_AUTO_TEST_SUITE_END(); diff --git a/pdns/doh.hh b/pdns/doh.hh index 62e7f83d29..325776bcdc 100644 --- a/pdns/doh.hh +++ b/pdns/doh.hh @@ -188,6 +188,7 @@ struct DOHUnit void release() { } + size_t proxyProtocolPayloadSize{0}; uint16_t status_code{200}; }; @@ -273,6 +274,11 @@ struct DOHUnit void handleUDPResponseForDoH(std::unique_ptr&&, PacketBuffer&& response, InternalQueryState&& state); +struct CrossProtocolQuery; +struct DNSQuestion; + +std::unique_ptr getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse); + #endif /* HAVE_DNS_OVER_HTTPS */ using DOHUnitUniquePtr = std::unique_ptr; diff --git a/pdns/lock.hh b/pdns/lock.hh index e8bd82988d..b37bf28ae4 100644 --- a/pdns/lock.hh +++ b/pdns/lock.hh @@ -288,7 +288,7 @@ public: return LockGuardedHolder(d_value, d_mutex); } - LockGuardedHolder read_only_lock() const + LockGuardedHolder read_only_lock() { return LockGuardedHolder(d_value, d_mutex); } diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index dacc245387..c4fe42b8aa 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -27,6 +27,8 @@ #include "dnsdist.hh" #include "dnsdist-ecs.hh" +#include "dnsdist-internal-queries.hh" +#include "dnsdist-tcp.hh" #include "dnsdist-xpf.hh" #include "dolog.hh" @@ -37,7 +39,17 @@ #include "ednscookies.hh" #include "ednssubnet.hh" -bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&) +ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr& selectedBackend) +{ + return ProcessQueryResult::Drop; +} + +bool processResponseAfterRules(PacketBuffer& response, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted) +{ + return false; +} + +bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote) { return false; } @@ -47,6 +59,18 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint1 return false; } +namespace dnsdist { +std::unique_ptr getInternalQueryFromDQ(DNSQuestion& dq, bool isResponse) +{ + return nullptr; +} +} + +bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&) +{ + return false; +} + BOOST_AUTO_TEST_SUITE(test_dnsdist_cc) static const uint16_t ECSSourcePrefixV4 = 24;