From: Remi Gacogne Date: Mon, 31 Jul 2023 15:07:05 +0000 (+0200) Subject: dnsdist: Implement incoming DoH support via nghttp2 X-Git-Tag: rec-5.0.0-alpha1~19^2~34 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=7e8a05fa482317dc40e17be9e9dafa29992363fd;p=thirdparty%2Fpdns.git dnsdist: Implement incoming DoH support via nghttp2 --- diff --git a/pdns/dnsdist-doh-common.hh b/pdns/dnsdist-doh-common.hh index 44ad826a88..41166de9f3 100644 --- a/pdns/dnsdist-doh-common.hh +++ b/pdns/dnsdist-doh-common.hh @@ -77,10 +77,6 @@ struct DOHFrontend DOHFrontend() { } - DOHFrontend(std::shared_ptr tlsCtx) : - d_tlsContext(std::move(tlsCtx)) - { - } virtual ~DOHFrontend() { @@ -126,6 +122,7 @@ struct DOHFrontend #endif bool d_sendCacheControlHeaders{true}; bool d_trustForwardedForHeader{false}; + bool d_earlyACLDrop{true}; /* whether we require the query path to exactly match one of the configured ones, or accept everything below these paths. */ bool d_exactPathMatching{true}; diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index a29c17c3ad..f71a9bbf4d 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -284,7 +284,7 @@ public: struct timeval now; gettimeofday(&now, nullptr); - sender->notifyIOError(std::move(object->query.d_idstate), now); + sender->notifyIOError(now, TCPResponse(std::move(object->query))); return true; } diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index c829c2e1b5..fedd9c5d86 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -2337,14 +2337,34 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) setLuaSideEffect(); auto frontend = std::make_shared(); + if (getOptionalValue(vars, "library", frontend->d_library) == 0) { +#ifdef HAVE_NGHTTP2 + frontend->d_library = "nghttp2"; +#else /* HAVE_NGHTTP2 */ + frontend->d_library = "h2o"; +#endif /* HAVE_NGHTTP2 */ + } + if (frontend->d_library == "h2o") { #ifdef 
HAVE_LIBH2OEVLOOP - frontend = std::make_shared(); - frontend->d_library = "h2o"; + frontend = std::make_shared(); + frontend->d_library = "h2o"; #else /* HAVE_LIBH2OEVLOOP */ - errlog("DOH bind %s is configured to use libh2o but the library is not available", addr); - return; + errlog("DOH bind %s is configured to use libh2o but the library is not available", addr); + return; #endif /* HAVE_LIBH2OEVLOOP */ + } + else if (frontend->d_library == "nghttp2") { +#ifndef HAVE_NGHTTP2 + errlog("DOH bind %s is configured to use nghttp2 but the library is not available", addr); + return; +#endif /* HAVE_NGHTTP2 */ + } + else { + errlog("DOH bind %s is configured to use an unknown library ('%s')", addr, frontend->d_library); + return; + } + bool useTLS = true; if (certFiles && !certFiles->empty()) { if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) { return; @@ -2355,6 +2375,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) else { frontend->d_tlsContext.d_addr = ComboAddress(addr, 80); infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort()); + useTLS = false; } if (urls) { @@ -2385,6 +2406,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections); getOptionalValue(vars, "idleTimeout", frontend->d_idleTimeout); getOptionalValue(vars, "serverTokens", frontend->d_serverTokens); + getOptionalValue(vars, "provider", frontend->d_tlsContext.d_provider); + boost::algorithm::to_lower(frontend->d_tlsContext.d_provider); LuaAssociativeTable customResponseHeaders; if (getOptionalValue(vars, "customResponseHeaders", customResponseHeaders) > 0) { @@ -2397,6 +2420,7 @@ static void setupLuaConfig(LuaContext& 
luaCtx, bool client, bool configCheck) getOptionalValue(vars, "sendCacheControlHeaders", frontend->d_sendCacheControlHeaders); getOptionalValue(vars, "keepIncomingHeaders", frontend->d_keepIncomingHeaders); getOptionalValue(vars, "trustForwardedForHeader", frontend->d_trustForwardedForHeader); + getOptionalValue(vars, "earlyACLDrop", frontend->d_earlyACLDrop); getOptionalValue(vars, "internalPipeBufferSize", frontend->d_internalPipeBufferSize); getOptionalValue(vars, "exactPathMatching", frontend->d_exactPathMatching); @@ -2432,6 +2456,21 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) checkAllParametersConsumed("addDOHLocal", vars); } + + if (useTLS && frontend->d_library == "nghttp2") { + if (!frontend->d_tlsContext.d_provider.empty()) { + vinfolog("Loading TLS provider '%s'", frontend->d_tlsContext.d_provider); + } + else { +#ifdef HAVE_LIBSSL + const std::string provider("openssl"); +#else + const std::string provider("gnutls"); +#endif + vinfolog("Loading default TLS provider '%s'", provider); + } + } + g_dohlocals.push_back(frontend); auto cs = std::make_unique(frontend->d_tlsContext.d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus); cs->dohFrontend = frontend; @@ -2648,10 +2687,11 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } else { #ifdef HAVE_LIBSSL - vinfolog("Loading default TLS provider 'openssl'"); + const std::string provider("openssl"); #else - vinfolog("Loading default TLS provider 'gnutls'"); + const std::string provider("gnutls"); #endif + vinfolog("Loading default TLS provider '%s'", provider); } // only works pre-startup, so no sync necessary auto cs = std::make_unique(frontend->d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus); diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 14af2564e3..751ba98672 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -27,6 +27,7 @@ #include "dnsdist.hh" #include 
"dnsdist-concurrent-connections.hh" #include "dnsdist-ecs.hh" +#include "dnsdist-nghttp2-in.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-rings.hh" #include "dnsdist-tcp.hh" @@ -96,6 +97,17 @@ IncomingTCPConnectionState::~IncomingTCPConnectionState() d_handler.close(); } +dnsdist::Protocol IncomingTCPConnectionState::getProtocol() const +{ + if (d_ci.cs->dohFrontend) { + return dnsdist::Protocol::DoH; + } + if (d_handler.isTLS()) { + return dnsdist::Protocol::DoT; + } + return dnsdist::Protocol::DoTCP; +} + size_t IncomingTCPConnectionState::clearAllDownstreamConnections() { return t_downstreamTCPConnectionsManager.clear(); @@ -173,7 +185,7 @@ static IOState sendQueuedResponses(std::shared_ptr& TCPResponse resp = std::move(state->d_queuedResponses.front()); state->d_queuedResponses.pop_front(); state->d_state = IncomingTCPConnectionState::State::idle; - result = state->sendResponse(state, now, std::move(resp)); + result = state->sendResponse(now, std::move(resp)); if (result != IOState::Done) { return result; } @@ -183,28 +195,28 @@ static IOState sendQueuedResponses(std::shared_ptr& return IOState::Done; } -static void handleResponseSent(std::shared_ptr& state, TCPResponse& currentResponse) +void IncomingTCPConnectionState::handleResponseSent(TCPResponse& currentResponse) { if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) { return; } - --state->d_currentQueriesCount; + --d_currentQueriesCount; const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds; if (currentResponse.d_idstate.selfGenerated == false && ds) { const auto& ids = currentResponse.d_idstate; double udiff = ids.queryRealTime.udiff(); - vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_handler.isTLS() ? 
"DoT" : "TCP"), currentResponse.d_buffer.size(), udiff); + vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), currentResponse.d_buffer.size(), udiff); auto backendProtocol = ds->getProtocol(); - if (backendProtocol == dnsdist::Protocol::DoUDP) { + if (backendProtocol == dnsdist::Protocol::DoUDP && !currentResponse.d_idstate.forwardedOverUDP) { backendProtocol = dnsdist::Protocol::DoTCP; } - ::handleResponseSent(ids, udiff, state->d_ci.remote, ds->d_config.remote, static_cast(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true); + ::handleResponseSent(ids, udiff, d_ci.remote, ds->d_config.remote, static_cast(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true); } else { const auto& ids = currentResponse.d_idstate; - ::handleResponseSent(ids, 0., state->d_ci.remote, ComboAddress(), static_cast(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false); + ::handleResponseSent(ids, 0., d_ci.remote, ComboAddress(), static_cast(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false); } currentResponse.d_buffer.clear(); @@ -232,7 +244,8 @@ bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) return false; } - if (d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) { + // for DoH, this is already handled by the underlying library + if (!d_ci.cs->dohFrontend && d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) { DEBUGLOG("not accepting new queries because we already have "<d_maxInFlightQueriesPerConn); return false; } @@ -284,9 +297,9 @@ void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_p } /* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */ -IOState 
IncomingTCPConnectionState::sendResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response) +IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response) { - state->d_state = IncomingTCPConnectionState::State::sendingResponse; + d_state = IncomingTCPConnectionState::State::sendingResponse; uint16_t responseSize = static_cast(response.d_buffer.size()); const uint8_t sizeBytes[] = { static_cast(responseSize / 256), static_cast(responseSize % 256) }; @@ -294,27 +307,27 @@ IOState IncomingTCPConnectionState::sendResponse(std::shared_ptrd_currentPos = 0; - state->d_currentResponse = std::move(response); + d_currentPos = 0; + d_currentResponse = std::move(response); try { - auto iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size()); + auto iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size()); if (iostate == IOState::Done) { DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__); - handleResponseSent(state, state->d_currentResponse); + handleResponseSent(d_currentResponse); return iostate; } else { - state->d_lastIOBlocked = true; + d_lastIOBlocked = true; DEBUGLOG("partial write"); return iostate; } } catch (const std::exception& e) { - vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what()); + vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what()); DEBUGLOG("Closing TCP client connection: "<d_ci.cs->tcpDiedSendingResponse; + ++d_ci.cs->tcpDiedSendingResponse; - state->terminateClientConnection(); + terminateClientConnection(); return IOState::Done; } @@ -408,9 +421,7 @@ void IncomingTCPConnectionState::handleAsyncReady(int fd, FDMultiplexer::funcpar if (state->active()) { /* and now we restart our own I/O state machine */ - struct timeval now; - gettimeofday(&now, nullptr); - handleIO(state, 
now); + state->handleIO(); } else { /* we were only waiting for the engine to come back, @@ -476,16 +487,17 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe try { auto& ids = response.d_idstate; unsigned int qnameWireLength; - if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) { + std::shared_ptr ds = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr); + if (!ds || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, ds, qnameWireLength)) { state->terminateClientConnection(); return; } - if (response.d_connection->getDS()) { - ++response.d_connection->getDS()->responses; + if (ds) { + ++ds->responses; } - DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS()); + DNSResponse dr(ids, response.d_buffer, ds); dr.d_incomingTCPState = state; memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH)); @@ -529,7 +541,6 @@ class TCPCrossProtocolQuery : public CrossProtocolQuery public: TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr ds, std::shared_ptr sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender)) { - proxyProtocolPayloadSize = 0; } ~TCPCrossProtocolQuery() @@ -561,6 +572,11 @@ private: std::shared_ptr d_sender; }; +std::unique_ptr IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr& ds) +{ + return std::make_unique(std::move(query), std::move(state), ds, shared_from_this()); +} + std::unique_ptr getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq) { auto state = dq.getIncomingTCPState(); @@ -587,60 +603,63 @@ void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeva } } -static void handleQuery(std::shared_ptr& state, const struct timeval& 
now) +IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::handleQuery(PacketBuffer&& queryIn, const struct timeval& now, std::optional streamID) { - if (state->d_querySize < sizeof(dnsheader)) { + auto query = std::move(queryIn); + if (query.size() < sizeof(dnsheader)) { ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++state->d_ci.cs->nonCompliantQueries; - state->terminateClientConnection(); - return; + ++d_ci.cs->nonCompliantQueries; + return QueryProcessingResult::TooSmall; } - ++state->d_queriesCount; - ++state->d_ci.cs->queries; + ++d_queriesCount; + ++d_ci.cs->queries; ++dnsdist::metrics::g_stats.queries; - if (state->d_handler.isTLS()) { - auto tlsVersion = state->d_handler.getTLSVersion(); + if (d_handler.isTLS()) { + auto tlsVersion = d_handler.getTLSVersion(); switch (tlsVersion) { case LibsslTLSVersion::TLS10: - ++state->d_ci.cs->tls10queries; + ++d_ci.cs->tls10queries; break; case LibsslTLSVersion::TLS11: - ++state->d_ci.cs->tls11queries; + ++d_ci.cs->tls11queries; break; case LibsslTLSVersion::TLS12: - ++state->d_ci.cs->tls12queries; + ++d_ci.cs->tls12queries; break; case LibsslTLSVersion::TLS13: - ++state->d_ci.cs->tls13queries; + ++d_ci.cs->tls13queries; break; default: - ++state->d_ci.cs->tlsUnknownqueries; + ++d_ci.cs->tlsUnknownqueries; } } + auto state = shared_from_this(); InternalQueryState ids; - ids.origDest = state->d_proxiedDestination; - ids.origRemote = state->d_proxiedRemote; - ids.cs = state->d_ci.cs; + ids.origDest = d_proxiedDestination; + ids.origRemote = d_proxiedRemote; + ids.cs = d_ci.cs; ids.queryRealTime.start(); + if (streamID) { + ids.d_streamID = *streamID; + } - auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true); + auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true); if (dnsCryptResponse) { TCPResponse response; - state->d_state = 
IncomingTCPConnectionState::State::idle; - ++state->d_currentQueriesCount; - state->queueResponse(state, now, std::move(response)); - return; + d_state = IncomingTCPConnectionState::State::idle; + ++d_currentQueriesCount; + queueResponse(state, now, std::move(response)); + return QueryProcessingResult::SelfAnswered; } { /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */ - auto* dh = reinterpret_cast(state->d_buffer.data()); - if (!checkQueryHeaders(dh, *state->d_ci.cs)) { - state->terminateClientConnection(); - return; + auto* dh = reinterpret_cast(query.data()); + if (!checkQueryHeaders(dh, *d_ci.cs)) { + return QueryProcessingResult::InvalidHeaders; } if (dh->qdcount == 0) { @@ -648,81 +667,105 @@ static void handleQuery(std::shared_ptr& state, cons dh->rcode = RCode::NotImp; dh->qr = true; response.d_idstate.selfGenerated = true; - response.d_buffer = std::move(state->d_buffer); - state->d_state = IncomingTCPConnectionState::State::idle; - ++state->d_currentQueriesCount; - state->queueResponse(state, now, std::move(response)); - return; + response.d_buffer = std::move(query); + d_state = IncomingTCPConnectionState::State::idle; + ++d_currentQueriesCount; + queueResponse(state, now, std::move(response)); + return QueryProcessingResult::Empty; } } - ids.qname = DNSName(reinterpret_cast(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); - ids.protocol = dnsdist::Protocol::DoTCP; + ids.qname = DNSName(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + ids.protocol = getProtocol(); if (ids.dnsCryptQuery) { ids.protocol = dnsdist::Protocol::DNSCryptTCP; } - else if (state->d_handler.isTLS()) { - ids.protocol = dnsdist::Protocol::DoT; - } - DNSQuestion dq(ids, state->d_buffer); + DNSQuestion dq(ids, query); const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); ids.origFlags = *flags; dq.d_incomingTCPState = state; - 
dq.sni = state->d_handler.getServerNameIndication(); + dq.sni = d_handler.getServerNameIndication(); - if (state->d_proxyProtocolValues) { + if (d_proxyProtocolValues) { /* we need to copy them, because the next queries received on that connection will need to get the _unaltered_ values */ - dq.proxyProtocolValues = make_unique>(*state->d_proxyProtocolValues); + dq.proxyProtocolValues = make_unique>(*d_proxyProtocolValues); } if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) { dq.ids.skipCache = true; } - std::shared_ptr ds; - auto result = processQuery(dq, state->d_threadData.holders, ds); + if (forwardViaUDPFirst()) { + // if there was no EDNS, we add it with a large buffer size + // so we can use UDP to talk to the backend. + auto dh = const_cast(reinterpret_cast(query.data())); + if (!dh->arcount) { + if (addEDNS(query, 4096, false, 4096, 0)) { + dq.ids.ednsAdded = true; + } + } + } - if (result == ProcessQueryResult::Drop) { - state->terminateClientConnection(); - return; + if (streamID) { + auto unit = getDOHUnit(*streamID); + dq.ids.du = std::move(unit); } - else if (result == ProcessQueryResult::Asynchronous) { + + std::shared_ptr ds; + auto result = processQuery(dq, d_threadData.holders, ds); + + if (result == ProcessQueryResult::Asynchronous) { /* we are done for now */ - ++state->d_currentQueriesCount; - return; + ++d_currentQueriesCount; + return QueryProcessingResult::Asynchronous; + } + + if (streamID) { + restoreDOHUnit(std::move(dq.ids.du)); + } + + if (result == ProcessQueryResult::Drop) { + return QueryProcessingResult::Dropped; } // the buffer might have been invalidated by now - const dnsheader* dh = dq.getHeader(); + uint16_t queryID; + { + const dnsheader* dh = dq.getHeader(); + queryID = dh->id; + } + if (result == ProcessQueryResult::SendAnswer) { TCPResponse response; - memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH)); + { + const dnsheader* dh = dq.getHeader(); + memcpy(&response.d_cleartextDH, dh, 
sizeof(response.d_cleartextDH)); + } response.d_idstate = std::move(ids); - response.d_idstate.origID = dh->id; + response.d_idstate.origID = queryID; response.d_idstate.selfGenerated = true; - response.d_idstate.cs = state->d_ci.cs; - response.d_buffer = std::move(state->d_buffer); + response.d_idstate.cs = d_ci.cs; + response.d_buffer = std::move(query); - state->d_state = IncomingTCPConnectionState::State::idle; - ++state->d_currentQueriesCount; - state->queueResponse(state, now, std::move(response)); - return; + d_state = IncomingTCPConnectionState::State::idle; + ++d_currentQueriesCount; + queueResponse(state, now, std::move(response)); + return QueryProcessingResult::SelfAnswered; } if (result != ProcessQueryResult::PassToBackend || ds == nullptr) { - state->terminateClientConnection(); - return; + return QueryProcessingResult::NoBackend; } - dq.ids.origID = dh->id; + dq.ids.origID = queryID; - ++state->d_currentQueriesCount; + ++d_currentQueriesCount; std::string proxyProtocolPayload; if (ds->isDoH()) { - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? 
"DoT" : "TCP"), state->d_buffer.size(), ds->getNameWithAddr()); + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), ds->getNameWithAddr()); /* we need to do this _before_ creating the cross protocol query because after that the buffer will have been moved */ @@ -730,21 +773,30 @@ static void handleQuery(std::shared_ptr& state, cons proxyProtocolPayload = getProxyProtocolPayload(dq); } - auto cpq = std::make_unique(std::move(state->d_buffer), std::move(ids), ds, state); + auto cpq = std::make_unique(std::move(query), std::move(ids), ds, state); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); ds->passCrossProtocolQuery(std::move(cpq)); - return; + return QueryProcessingResult::Forwarded; + } + else if (!ds->isTCPOnly() && forwardViaUDPFirst()) { + auto unit = getDOHUnit(*streamID); + dq.ids.du = std::move(unit); + if (assignOutgoingUDPQueryToBackend(ds, queryID, dq, query)) { + return QueryProcessingResult::Forwarded; + } + restoreDOHUnit(std::move(dq.ids.du)); + // fallback to the normal flow } - prependSizeToTCPQuery(state->d_buffer, 0); + prependSizeToTCPQuery(query, 0); - auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now); + auto downstreamConnection = getDownstreamConnection(ds, dq.proxyProtocolValues, now); if (ds->d_config.useProxyProtocol) { /* if we ever sent a TLV over a connection, we can never go back */ - if (!state->d_proxyProtocolPayloadHasTLV) { - state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty(); + if (!d_proxyProtocolPayloadHasTLV) { + d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty(); } proxyProtocolPayload = getProxyProtocolPayload(dq); @@ -754,12 +806,13 @@ static void handleQuery(std::shared_ptr& state, cons 
downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues)); } - TCPQuery query(std::move(state->d_buffer), std::move(ids)); - query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); + TCPQuery tcpquery(std::move(query), std::move(ids)); + tcpquery.d_proxyProtocolPayload = std::move(proxyProtocolPayload); - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getNameWithAddr()); + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), ds->getNameWithAddr()); std::shared_ptr incoming = state; - downstreamConnection->queueQuery(incoming, std::move(query)); + downstreamConnection->queueQuery(incoming, std::move(tcpquery)); + return QueryProcessingResult::Forwarded; } void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param) @@ -769,159 +822,194 @@ void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcpar throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor())); } - struct timeval now; - gettimeofday(&now, nullptr); - handleIO(conn, now); + conn->handleIO(); } -void IncomingTCPConnectionState::handleIO(std::shared_ptr& state, const struct timeval& now) +void IncomingTCPConnectionState::handleHandshakeDone(const struct timeval& now) +{ + if (d_handler.isTLS()) { + if (!d_handler.hasTLSSessionBeenResumed()) { + ++d_ci.cs->tlsNewSessions; + } + else { + ++d_ci.cs->tlsResumptions; + } + if (d_handler.getResumedFromInactiveTicketKey()) { + 
++d_ci.cs->tlsInactiveTicketKey; + } + if (d_handler.getUnknownTicketKey()) { + ++d_ci.cs->tlsUnknownTicketKey; + } + } + + d_handshakeDoneTime = now; +} + +IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::handleProxyProtocolPayload() +{ + do { + DEBUGLOG("reading proxy protocol header"); + auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed); + if (iostate == IOState::Done) { + d_buffer.resize(d_currentPos); + ssize_t remaining = isProxyHeaderComplete(d_buffer); + if (remaining == 0) { + vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", d_ci.remote.toStringWithPort()); + ++dnsdist::metrics::g_stats.proxyProtocolInvalid; + return ProxyProtocolResult::Error; + } + else if (remaining < 0) { + d_proxyProtocolNeed += -remaining; + d_buffer.resize(d_currentPos + d_proxyProtocolNeed); + /* we need to keep reading, since we might have buffered data */ + } + else { + /* proxy header received */ + std::vector proxyProtocolValues; + if (!handleProxyProtocol(d_ci.remote, true, *d_threadData.holders.acl, d_buffer, d_proxiedRemote, d_proxiedDestination, proxyProtocolValues)) { + vinfolog("Error handling the Proxy Protocol received from TCP client %s", d_ci.remote.toStringWithPort()); + return ProxyProtocolResult::Error; + } + + if (!proxyProtocolValues.empty()) { + d_proxyProtocolValues = make_unique>(std::move(proxyProtocolValues)); + } + + return ProxyProtocolResult::Done; + } + } + else { + d_lastIOBlocked = true; + } + } + while (active() && !d_lastIOBlocked); + + return ProxyProtocolResult::Reading; +} + +void IncomingTCPConnectionState::handleIO() { // why do we loop? 
Because the TLS layer does buffering, and thus can have data ready to read // even though the underlying socket is not ready, so we need to actually ask for the data first IOState iostate = IOState::Done; + struct timeval now; + gettimeofday(&now, nullptr); + do { iostate = IOState::Done; - IOStateGuard ioGuard(state->d_ioState); + IOStateGuard ioGuard(d_ioState); - if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { - vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); + if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { + vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort()); // will be handled by the ioGuard //handleNewIOState(state, IOState::Done, fd, handleIOCallback); return; } - state->d_lastIOBlocked = false; + d_lastIOBlocked = false; try { - if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) { + if (d_state == IncomingTCPConnectionState::State::doingHandshake) { DEBUGLOG("doing handshake"); - iostate = state->d_handler.tryHandshake(); + iostate = d_handler.tryHandshake(); if (iostate == IOState::Done) { DEBUGLOG("handshake done"); - if (state->d_handler.isTLS()) { - if (!state->d_handler.hasTLSSessionBeenResumed()) { - ++state->d_ci.cs->tlsNewSessions; - } - else { - ++state->d_ci.cs->tlsResumptions; - } - if (state->d_handler.getResumedFromInactiveTicketKey()) { - ++state->d_ci.cs->tlsInactiveTicketKey; - } - if (state->d_handler.getUnknownTicketKey()) { - ++state->d_ci.cs->tlsUnknownTicketKey; - } - } + handleHandshakeDone(now); - state->d_handshakeDoneTime = now; - if (expectProxyProtocolFrom(state->d_ci.remote)) { - state->d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader; - state->d_buffer.resize(s_proxyProtocolMinimumHeaderSize); - state->d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; + if 
(expectProxyProtocolFrom(d_ci.remote)) { + d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader; + d_buffer.resize(s_proxyProtocolMinimumHeaderSize); + d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; } else { - state->d_state = IncomingTCPConnectionState::State::readingQuerySize; + d_state = IncomingTCPConnectionState::State::readingQuerySize; } } else { - state->d_lastIOBlocked = true; + d_lastIOBlocked = true; } } - if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) { - do { - DEBUGLOG("reading proxy protocol header"); - iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_proxyProtocolNeed); - if (iostate == IOState::Done) { - state->d_buffer.resize(state->d_currentPos); - ssize_t remaining = isProxyHeaderComplete(state->d_buffer); - if (remaining == 0) { - vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", state->d_ci.remote.toStringWithPort()); - ++dnsdist::metrics::g_stats.proxyProtocolInvalid; - break; - } - else if (remaining < 0) { - state->d_proxyProtocolNeed += -remaining; - state->d_buffer.resize(state->d_currentPos + state->d_proxyProtocolNeed); - /* we need to keep reading, since we might have buffered data */ - iostate = IOState::NeedRead; - } - else { - /* proxy header received */ - std::vector proxyProtocolValues; - if (!handleProxyProtocol(state->d_ci.remote, true, *state->d_threadData.holders.acl, state->d_buffer, state->d_proxiedRemote, state->d_proxiedDestination, proxyProtocolValues)) { - vinfolog("Error handling the Proxy Protocol received from TCP client %s", state->d_ci.remote.toStringWithPort()); - break; - } - - if (!proxyProtocolValues.empty()) { - state->d_proxyProtocolValues = make_unique>(std::move(proxyProtocolValues)); - } - - state->d_state = IncomingTCPConnectionState::State::readingQuerySize; - state->d_buffer.resize(sizeof(uint16_t)); - state->d_currentPos = 0; - 
state->d_proxyProtocolNeed = 0; - break; - } - } - else { - state->d_lastIOBlocked = true; - } + if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) { + auto status = handleProxyProtocolPayload(); + if (status == ProxyProtocolResult::Done) { + d_state = IncomingTCPConnectionState::State::readingQuerySize; + d_buffer.resize(sizeof(uint16_t)); + d_currentPos = 0; + d_proxyProtocolNeed = 0; + } + else if (status == ProxyProtocolResult::Error) { + iostate = IOState::Done; + } + else { + iostate = IOState::NeedRead; } - while (state->active() && !state->d_lastIOBlocked); } - if (!state->d_lastIOBlocked && (state->d_state == IncomingTCPConnectionState::State::waitingForQuery || - state->d_state == IncomingTCPConnectionState::State::readingQuerySize)) { + if (!d_lastIOBlocked && (d_state == IncomingTCPConnectionState::State::waitingForQuery || + d_state == IncomingTCPConnectionState::State::readingQuerySize)) { DEBUGLOG("reading query size"); - state->d_buffer.resize(sizeof(uint16_t)); - iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t)); - if (state->d_currentPos > 0) { + d_buffer.resize(sizeof(uint16_t)); + iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t)); + if (d_currentPos > 0) { /* if we got at least one byte, we can't go around sending responses */ - state->d_state = IncomingTCPConnectionState::State::readingQuerySize; + d_state = IncomingTCPConnectionState::State::readingQuerySize; } if (iostate == IOState::Done) { DEBUGLOG("query size received"); - state->d_state = IncomingTCPConnectionState::State::readingQuery; - state->d_querySizeReadTime = now; - if (state->d_queriesCount == 0) { - state->d_firstQuerySizeReadTime = now; + d_state = IncomingTCPConnectionState::State::readingQuery; + d_querySizeReadTime = now; + if (d_queriesCount == 0) { + d_firstQuerySizeReadTime = now; } - state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1); - if 
(state->d_querySize < sizeof(dnsheader)) { + d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1); + if (d_querySize < sizeof(dnsheader)) { /* go away */ - state->terminateClientConnection(); + terminateClientConnection(); return; } /* allocate a bit more memory to be able to spoof the content, get an answer from the cache or to add ECS without allocating a new buffer */ - state->d_buffer.resize(std::max(state->d_querySize + static_cast(512), s_maxPacketCacheEntrySize)); - state->d_currentPos = 0; + d_buffer.resize(std::max(d_querySize + static_cast(512), s_maxPacketCacheEntrySize)); + d_currentPos = 0; } else { - state->d_lastIOBlocked = true; + d_lastIOBlocked = true; } } - if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingQuery) { + if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingQuery) { DEBUGLOG("reading query"); - iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize); + iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize); if (iostate == IOState::Done) { DEBUGLOG("query received"); - state->d_buffer.resize(state->d_querySize); + d_buffer.resize(d_querySize); + + d_state = IncomingTCPConnectionState::State::idle; + auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt); + switch (processingResult) { + case QueryProcessingResult::TooSmall: + /* fall-through */ + case QueryProcessingResult::InvalidHeaders: + /* fall-through */ + case QueryProcessingResult::Dropped: + /* fall-through */ + case QueryProcessingResult::NoBackend: + terminateClientConnection(); + break; + default: + break; + } - state->d_state = IncomingTCPConnectionState::State::idle; - handleQuery(state, now); /* the state might have been updated in the meantime, we don't want to override it in that case */ - if (state->active() && state->d_state != IncomingTCPConnectionState::State::idle) { - if (state->d_ioState->isWaitingForRead()) { + if (active() && 
d_state != IncomingTCPConnectionState::State::idle) { + if (d_ioState->isWaitingForRead()) { iostate = IOState::NeedRead; } - else if (state->d_ioState->isWaitingForWrite()) { + else if (d_ioState->isWaitingForWrite()) { iostate = IOState::NeedWrite; } else { @@ -930,55 +1018,56 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptrd_lastIOBlocked = true; + d_lastIOBlocked = true; } } - if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::sendingResponse) { + if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::sendingResponse) { DEBUGLOG("sending response"); - iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size()); + iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size()); if (iostate == IOState::Done) { DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__); - handleResponseSent(state, state->d_currentResponse); - state->d_state = IncomingTCPConnectionState::State::idle; + handleResponseSent(d_currentResponse); + d_state = IncomingTCPConnectionState::State::idle; } else { - state->d_lastIOBlocked = true; + d_lastIOBlocked = true; } } - if (state->active() && - !state->d_lastIOBlocked && + if (active() && + !d_lastIOBlocked && iostate == IOState::Done && - (state->d_state == IncomingTCPConnectionState::State::idle || - state->d_state == IncomingTCPConnectionState::State::waitingForQuery)) + (d_state == IncomingTCPConnectionState::State::idle || + d_state == IncomingTCPConnectionState::State::waitingForQuery)) { // try sending queued responses DEBUGLOG("send responses, if any"); + auto state = shared_from_this(); iostate = sendQueuedResponses(state, now); - if (!state->d_lastIOBlocked && state->active() && iostate == IOState::Done) { + if (!d_lastIOBlocked && active() && iostate == IOState::Done) { // if the query has been passed to a backend, or dropped, and the responses have been 
sent, // we can start reading again - if (state->canAcceptNewQueries(now)) { - state->resetForNewQuery(); + if (canAcceptNewQueries(now)) { + resetForNewQuery(); iostate = IOState::NeedRead; } else { - state->d_state = IncomingTCPConnectionState::State::idle; + d_state = IncomingTCPConnectionState::State::idle; iostate = IOState::Done; } } } - if (state->d_state != IncomingTCPConnectionState::State::idle && - state->d_state != IncomingTCPConnectionState::State::doingHandshake && - state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader && - state->d_state != IncomingTCPConnectionState::State::waitingForQuery && - state->d_state != IncomingTCPConnectionState::State::readingQuerySize && - state->d_state != IncomingTCPConnectionState::State::readingQuery && - state->d_state != IncomingTCPConnectionState::State::sendingResponse) { - vinfolog("Unexpected state %d in handleIOCallback", static_cast(state->d_state)); + if (d_state != IncomingTCPConnectionState::State::idle && + d_state != IncomingTCPConnectionState::State::doingHandshake && + d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader && + d_state != IncomingTCPConnectionState::State::waitingForQuery && + d_state != IncomingTCPConnectionState::State::readingQuerySize && + d_state != IncomingTCPConnectionState::State::readingQuery && + d_state != IncomingTCPConnectionState::State::sendingResponse) { + vinfolog("Unexpected state %d in handleIOCallback", static_cast(d_state)); } } catch (const std::exception& e) { @@ -986,55 +1075,56 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptrd_state == IncomingTCPConnectionState::State::idle || - state->d_state == IncomingTCPConnectionState::State::waitingForQuery) { + if (d_state == IncomingTCPConnectionState::State::idle || + d_state == IncomingTCPConnectionState::State::waitingForQuery) { /* no need to increase any counters in that case, the client is simply done with us */ } - else if (state->d_state == 
IncomingTCPConnectionState::State::doingHandshake || - state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader || - state->d_state == IncomingTCPConnectionState::State::waitingForQuery || - state->d_state == IncomingTCPConnectionState::State::readingQuerySize || - state->d_state == IncomingTCPConnectionState::State::readingQuery) { - ++state->d_ci.cs->tcpDiedReadingQuery; + else if (d_state == IncomingTCPConnectionState::State::doingHandshake || + d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader || + d_state == IncomingTCPConnectionState::State::waitingForQuery || + d_state == IncomingTCPConnectionState::State::readingQuerySize || + d_state == IncomingTCPConnectionState::State::readingQuery) { + ++d_ci.cs->tcpDiedReadingQuery; } - else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) { + else if (d_state == IncomingTCPConnectionState::State::sendingResponse) { /* unlikely to happen here, the exception should be handled in sendResponse() */ - ++state->d_ci.cs->tcpDiedSendingResponse; + ++d_ci.cs->tcpDiedSendingResponse; } - if (state->d_ioState->isWaitingForWrite() || state->d_queriesCount == 0) { + if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) { DEBUGLOG("Got an exception while handling TCP query: "<d_ioState->isWaitingForRead() ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what()); + vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (d_ioState->isWaitingForRead() ? 
"reading" : "writing"), d_ci.remote.toStringWithPort(), e.what()); } else { - vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what()); + vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what()); DEBUGLOG("Closing TCP client connection: "<terminateClientConnection(); + terminateClientConnection(); } - if (!state->active()) { + if (!active()) { DEBUGLOG("state is no longer active"); return; } + auto state = shared_from_this(); if (iostate == IOState::Done) { - state->d_ioState->update(iostate, handleIOCallback, state); + d_ioState->update(iostate, handleIOCallback, state); } else { updateIO(state, iostate, now); } ioGuard.release(); } - while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !state->d_lastIOBlocked); + while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !d_lastIOBlocked); } -void IncomingTCPConnectionState::notifyIOError(InternalQueryState&& query, const struct timeval& now) +void IncomingTCPConnectionState::notifyIOError(const struct timeval& now, TCPResponse&& response) { if (std::this_thread::get_id() != d_creatorThreadID) { /* empty buffer will signal an IO error */ - TCPResponse response(PacketBuffer(), std::move(query), nullptr, nullptr); + response.d_buffer.clear(); handleCrossProtocolResponse(now, std::move(response)); return; } @@ -1115,8 +1205,17 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param struct timeval now; gettimeofday(&now, nullptr); - auto state = std::make_shared(std::move(*citmp), *threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + + if (citmp->cs->dohFrontend) { +#ifdef HAVE_NGHTTP2 + auto state = std::make_shared(std::move(*citmp), *threadData, now); + state->handleIO(); +#endif /* HAVE_NGHTTP2 */ + } + else { + auto state = std::make_shared(std::move(*citmp), *threadData, now); + state->handleIO(); + } } static void handleCrossProtocolQuery(int 
pipefd, FDMultiplexer::funcparam_t& param) @@ -1141,20 +1240,18 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par std::shared_ptr tqs = cpq->getTCPQuerySender(); auto query = std::move(cpq->query); auto downstreamServer = std::move(cpq->downstream); - auto proxyProtocolPayloadSize = cpq->proxyProtocolPayloadSize; try { auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string()); - prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); - query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize; + prependSizeToTCPQuery(query.d_buffer, query.d_idstate.d_proxyProtocolPayloadSize); vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr()); downstream->queueQuery(tqs, std::move(query)); } catch (...) 
{ - tqs->notifyIOError(std::move(query.d_idstate), now); + tqs->notifyIOError(now, std::move(query)); } } @@ -1178,7 +1275,7 @@ static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& try { if (response.d_response.d_buffer.empty()) { - response.d_state->notifyIOError(std::move(response.d_response.d_idstate), response.d_now); + response.d_state->notifyIOError(response.d_now, std::move(response.d_response)); } else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) { response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response)); @@ -1337,7 +1434,8 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa { auto& cs = param.cs; auto& acl = param.acl; - int socket = param.socket; + const bool checkACL = !cs.dohFrontend || (!cs.dohFrontend->d_trustForwardedForHeader && cs.dohFrontend->d_earlyACLDrop); + const int socket = param.socket; bool tcpClientCountIncremented = false; ComboAddress remote; remote.sin4.sin_family = param.local.sin4.sin_family; @@ -1358,7 +1456,7 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str()); } - if (!acl->match(remote)) { + if (checkACL && !acl->match(remote)) { ++dnsdist::metrics::g_stats.aclDrops; vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort()); return; @@ -1395,6 +1493,7 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa vinfolog("Got TCP connection from %s", remote.toStringWithPort()); ci.remote = remote; + if (threadData == nullptr) { if (!g_tcpclientthreads->passConnectionToThread(std::make_unique(std::move(ci)))) { if (tcpClientCountIncremented) { @@ -1405,8 +1504,17 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa else { struct timeval now; gettimeofday(&now, 
nullptr); - auto state = std::make_shared(std::move(ci), *threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + + if (ci.cs->dohFrontend) { +#ifdef HAVE_NGHTTP2 + auto state = std::make_shared(std::move(ci), *threadData, now); + state->handleIO(); +#endif /* HAVE_NGHTTP2 */ + } + else { + auto state = std::make_shared(std::move(ci), *threadData, now); + state->handleIO(); + } } } catch (const std::exception& e) { diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index fdf2797104..19e477ce79 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1469,7 +1469,7 @@ public: return handleResponse(now, std::move(response)); } - void notifyIOError(InternalQueryState&& query, const struct timeval& now) override + void notifyIOError(const struct timeval&, TCPResponse&&) override { // nothing to do } @@ -2573,18 +2573,24 @@ int main(int argc, char** argv) cout<<"gnutls"; #ifdef HAVE_LIBSSL cout<<" "; -#endif /* HAVE_LIBSSL */ +#endif #endif /* HAVE_GNUTLS */ #ifdef HAVE_LIBSSL cout<<"openssl"; -#endif /* HAVE_LIBSSL */ +#endif cout<<") "; #endif /* HAVE_DNS_OVER_TLS */ #ifdef HAVE_DNS_OVER_HTTPS cout<<"dns-over-https("; #ifdef HAVE_LIBH2OEVLOOP cout<<"h2o"; +#ifdef HAVE_NGHTTP2 + cout<<" "; +#endif #endif /* HAVE_LIBH2OEVLOOP */ +#ifdef HAVE_NGHTTP2 + cout<<"nghttp2"; +#endif cout<<") "; #endif /* HAVE_DNS_OVER_HTTPS */ #ifdef HAVE_DNSCRYPT @@ -2608,9 +2614,6 @@ int main(int argc, char** argv) #ifdef HAVE_LMDB cout<<"lmdb "; #endif -#ifdef HAVE_NGHTTP2 - cout<<"outgoing-dns-over-https(nghttp2) "; -#endif #ifndef DISABLE_PROTOBUF cout<<"protobuf "; #endif @@ -2914,8 +2917,8 @@ int main(int argc, char** argv) std::vector tcpStates; std::vector udpStates; - for(auto& cs : g_frontends) { - if (cs->dohFrontend != nullptr) { + for (auto& cs : g_frontends) { + if (cs->dohFrontend != nullptr && cs->dohFrontend->d_library == "h2o") { #ifdef HAVE_DNS_OVER_HTTPS #ifdef HAVE_LIBH2OEVLOOP std::thread t1(dohThread, cs.get()); diff --git a/pdns/dnsdistdist/Makefile.am 
b/pdns/dnsdistdist/Makefile.am index 99d7cdbe64..e4f30eaa83 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -80,6 +80,10 @@ if HAVE_LIBSSL AM_CPPFLAGS += $(LIBSSL_CFLAGS) endif +if HAVE_GNUTLS +AM_CPPFLAGS += $(GNUTLS_CFLAGS) +endif + if HAVE_LIBH2OEVLOOP AM_CPPFLAGS += $(LIBH2OEVLOOP_CFLAGS) endif @@ -178,6 +182,7 @@ dnsdist_SOURCES = \ dnsdist-lua.cc dnsdist-lua.hh \ dnsdist-mac-address.cc dnsdist-mac-address.hh \ dnsdist-metrics.cc dnsdist-metrics.hh \ + dnsdist-nghttp2-in.cc dnsdist-nghttp2-in.hh \ dnsdist-nghttp2.cc dnsdist-nghttp2.hh \ dnsdist-prometheus.hh \ dnsdist-protobuf.cc dnsdist-protobuf.hh \ @@ -274,6 +279,7 @@ testrunner_SOURCES = \ dnsdist-lua-vars.cc \ dnsdist-mac-address.cc dnsdist-mac-address.hh \ dnsdist-metrics.cc dnsdist-metrics.hh \ + dnsdist-nghttp2-in.cc dnsdist-nghttp2-in.hh \ dnsdist-nghttp2.cc dnsdist-nghttp2.hh \ dnsdist-protocols.cc dnsdist-protocols.hh \ dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \ @@ -411,6 +417,10 @@ endif if HAVE_DNS_OVER_HTTPS +if HAVE_GNUTLS +dnsdist_LDADD += -lgnutls +endif + if HAVE_LIBH2OEVLOOP dnsdist_LDADD += $(LIBH2OEVLOOP_LIBS) endif diff --git a/pdns/dnsdistdist/configure.ac b/pdns/dnsdistdist/configure.ac index 5805bbdfe8..af9f8c422a 100644 --- a/pdns/dnsdistdist/configure.ac +++ b/pdns/dnsdistdist/configure.ac @@ -71,6 +71,7 @@ AM_CONDITIONAL([HAVE_GNUTLS], [false]) AM_CONDITIONAL([HAVE_LIBH2OEVLOOP], [false]) AM_CONDITIONAL([HAVE_LIBSSL], [false]) AM_CONDITIONAL([HAVE_LMDB], [false]) +AM_CONDITIONAL([HAVE_NGHTTP2], [false]) PDNS_CHECK_LIBCRYPTO @@ -81,30 +82,28 @@ DNSDIST_ENABLE_DNS_OVER_HTTPS AS_IF([test "x$enable_dns_over_tls" != "xno" -o "x$enable_dns_over_https" != "xno"], [ PDNS_WITH_LIBSSL + PDNS_WITH_GNUTLS ]) AS_IF([test "x$enable_dns_over_tls" != "xno"], [ - PDNS_WITH_GNUTLS - AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [ AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available]) ]) ]) AS_IF([test 
"x$enable_dns_over_https" != "xno"], [ + PDNS_WITH_NGHTTP2 PDNS_WITH_LIBH2OEVLOOP - AS_IF([test "x$HAVE_LIBH2OEVLOOP" != "x1"], [ - AC_MSG_ERROR([DNS over HTTPS support requested but libh2o-evloop was not found]) + AS_IF([test "x$HAVE_LIBH2OEVLOOP" != "x1" -a "x$HAVE_NGHTTP2" != "x1" ], [ + AC_MSG_ERROR([DNS over HTTPS support requested but neither libh2o-evloop nor nghttp2 was not found]) ]) - AS_IF([test "x$HAVE_LIBSSL" != "x1"], [ - AC_MSG_ERROR([DNS over HTTPS support requested but OpenSSL was not found]) + AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [ + AC_MSG_ERROR([DNS over HTTPS support requested but neither GnuTLS nor OpenSSL are available]) ]) ]) -PDNS_WITH_NGHTTP2 - DNSDIST_WITH_CDB PDNS_CHECK_LMDB PDNS_ENABLE_IPCIPHER diff --git a/pdns/dnsdistdist/dnsdist-async.cc b/pdns/dnsdistdist/dnsdist-async.cc index f54b1c0b14..9cb96d83a2 100644 --- a/pdns/dnsdistdist/dnsdist-async.cc +++ b/pdns/dnsdistdist/dnsdist-async.cc @@ -137,7 +137,8 @@ void AsynchronousHolder::mainThread(std::shared_ptr data) vinfolog("Asynchronous query %d has expired at %d.%d, notifying the sender", queryID, now.tv_sec, now.tv_usec); auto sender = query->getTCPQuerySender(); if (sender) { - sender->notifyIOError(std::move(query->query.d_idstate), now); + TCPResponse tresponse(std::move(query->query)); + sender->notifyIOError(now, std::move(tresponse)); } } else { diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.cc b/pdns/dnsdistdist/dnsdist-healthchecks.cc index 66f421479e..67b65796d4 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.cc +++ b/pdns/dnsdistdist/dnsdist-healthchecks.cc @@ -168,7 +168,7 @@ public: throw std::runtime_error("Unexpected XFR reponse to a health check query"); } - void notifyIOError(InternalQueryState&& query, const struct timeval& now) override + void notifyIOError(const struct timeval& now, TCPResponse&&) override { ++d_data->d_ds->d_healthCheckMetrics.d_networkErrors; d_data->d_ds->submitHealthCheckResult(d_data->d_initial, false); diff 
--git a/pdns/dnsdistdist/dnsdist-internal-queries.cc b/pdns/dnsdistdist/dnsdist-internal-queries.cc index 49f95e42b4..ea4c541397 100644 --- a/pdns/dnsdistdist/dnsdist-internal-queries.cc +++ b/pdns/dnsdistdist/dnsdist-internal-queries.cc @@ -20,6 +20,7 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ #include "dnsdist-internal-queries.hh" +#include "dnsdist-nghttp2-in.hh" #include "dnsdist-tcp.hh" #include "doh.hh" @@ -35,7 +36,12 @@ std::unique_ptr getInternalQueryFromDQ(DNSQuestion& dq, bool } #ifdef HAVE_DNS_OVER_HTTPS else if (protocol == dnsdist::Protocol::DoH) { - return getDoHCrossProtocolQueryFromDQ(dq, isResponse); +#ifdef HAVE_LIBH2OEVLOOP + if (dq.ids.cs->dohFrontend->d_library == "h2o") { + return getDoHCrossProtocolQueryFromDQ(dq, isResponse); + } +#endif /* HAVE_LIBH2OEVLOOP */ + return getTCPCrossProtocolQueryFromDQ(dq); } #endif else { diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index 20c866931a..48ce507da8 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -929,7 +929,8 @@ bool dnsdist_ffi_drop_from_async(uint16_t asyncID, uint16_t queryID) struct timeval now; gettimeofday(&now, nullptr); - sender->notifyIOError(std::move(query->query.d_idstate), now); + TCPResponse tresponse(std::move(query->query)); + sender->notifyIOError(now, std::move(tresponse)); return true; } diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc new file mode 100644 index 0000000000..aefa50d777 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc @@ -0,0 +1,1214 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. 
+ * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "base64.hh" +#include "dnsdist-nghttp2-in.hh" +#include "dnsdist-proxy-protocol.hh" +#include "dnsparser.hh" + +#ifdef HAVE_NGHTTP2 + +#if 0 +class IncomingDoHCrossProtocolContext : public CrossProtocolContext +{ +public: + IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, std::shared_ptr connection, IncomingHTTP2Connection::StreamID streamID): CrossProtocolContext(std::move(query.d_buffer)), d_connection(connection), d_query(std::move(query)) + { + } + + std::optional getHTTPPath() const override + { + return d_query.d_path; + } + + std::optional getHTTPScheme() const override + { + return d_query.d_scheme; + } + + std::optional getHTTPHost() const override + { + return d_query.d_host; + } + + std::optional getHTTPQueryString() const override + { + return d_query.d_queryString; + } + + std::optional getHTTPHeaders() const override + { + if (!d_query.d_headers) { + return std::nullopt; + } + return *d_query.d_headers; + } + + void handleResponse(PacketBuffer&& response, InternalQueryState&& state) override + { + auto conn = d_connection.lock(); + if (!conn) { + /* the connection has been closed in the meantime */ + return; + } + } + + void handleTimeout() override + { + auto conn = d_connection.lock(); + if (!conn) { + /* the connection has been closed in the meantime 
*/ + return; + } + } + + ~IncomingDoHCrossProtocolContext() override + { + } + +private: + std::weak_ptr d_connection; + IncomingHTTP2Connection::PendingQuery d_query; + IncomingHTTP2Connection::StreamID d_streamID{-1}; +}; +#endif + +class IncomingDoHCrossProtocolContext : public DOHUnitInterface +{ +public: + IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, std::shared_ptr connection, IncomingHTTP2Connection::StreamID streamID) : + d_connection(connection), d_query(std::move(query)), d_streamID(streamID) + { + } + + std::string getHTTPPath() const override + { + return d_query.d_path; + } + + const std::string& getHTTPScheme() const override + { + return d_query.d_scheme; + } + + const std::string& getHTTPHost() const override + { + return d_query.d_host; + } + + std::string getHTTPQueryString() const override + { + return d_query.d_queryString; + } + + const HeadersMap& getHTTPHeaders() const override + { + if (!d_query.d_headers) { + static const HeadersMap empty{}; + return empty; + } + return *d_query.d_headers; + } + + void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "") override + { + d_query.d_statusCode = statusCode; + d_query.d_response = std::move(body); + d_query.d_contentTypeOut = contentType; + } + + void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr& ds) override + { + std::unique_ptr unit(this); + auto conn = d_connection.lock(); + if (!conn) { + /* the connection has been closed in the meantime */ + return; + } + + state.du = std::move(unit); + TCPResponse resp(std::move(response), std::move(state), nullptr, nullptr); + resp.d_ds = ds; + struct timeval now; + gettimeofday(&now, nullptr); + conn->handleResponse(now, std::move(resp)); + } + + void handleTimeout() override + { + std::unique_ptr unit(this); + auto conn = d_connection.lock(); + if (!conn) { + /* the connection has been closed in the meantime */ + return; + } + 
struct timeval now; + gettimeofday(&now, nullptr); + TCPResponse resp; + resp.d_idstate.d_streamID = d_streamID; + conn->notifyIOError(now, std::move(resp)); + } + + ~IncomingDoHCrossProtocolContext() override + { + } + + std::weak_ptr d_connection; + IncomingHTTP2Connection::PendingQuery d_query; + IncomingHTTP2Connection::StreamID d_streamID{-1}; +}; + +void IncomingHTTP2Connection::handleResponse(const struct timeval& now, TCPResponse&& response) +{ + if (std::this_thread::get_id() != d_creatorThreadID) { + handleCrossProtocolResponse(now, std::move(response)); + return; + } + + auto& state = response.d_idstate; + if (state.forwardedOverUDP) { + dnsheader* responseDH = reinterpret_cast(response.d_buffer.data()); + + if (responseDH->tc && state.d_packet && state.d_packet->size() > state.d_proxyProtocolPayloadSize && state.d_packet->size() - state.d_proxyProtocolPayloadSize > sizeof(dnsheader)) { + auto& query = *state.d_packet; + dnsheader* queryDH = reinterpret_cast(query.data() + state.d_proxyProtocolPayloadSize); + /* restoring the original ID */ + queryDH->id = state.origID; + + state.forwardedOverUDP = false; + auto cpq = getCrossProtocolQuery(std::move(query), std::move(state), response.d_ds); + cpq->query.d_proxyProtocolPayloadAdded = state.d_proxyProtocolPayloadSize > 0; + if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) { + return; + } + else { + vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP"); + notifyIOError(now, std::move(response)); + return; + } + } + } + + IncomingTCPConnectionState::handleResponse(now, std::move(response)); +} + +std::unique_ptr IncomingHTTP2Connection::getDOHUnit(uint32_t streamID) +{ + auto query = std::move(d_currentStreams.at(streamID)); + return std::make_unique(std::move(query), std::dynamic_pointer_cast(shared_from_this()), streamID); +} + +void IncomingHTTP2Connection::restoreDOHUnit(std::unique_ptr&& unit) +{ + auto context = 
std::unique_ptr(dynamic_cast(unit.release())); + d_currentStreams.at(context->d_streamID) = std::move(context->d_query); +} + +void IncomingHTTP2Connection::restoreContext(uint32_t streamID, IncomingHTTP2Connection::PendingQuery&& context) +{ + d_currentStreams.at(streamID) = std::move(context); +} + +IncomingHTTP2Connection::IncomingHTTP2Connection(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now) : + IncomingTCPConnectionState(std::move(ci), threadData, now) +{ + nghttp2_session_callbacks* cbs = nullptr; + if (nghttp2_session_callbacks_new(&cbs) != 0) { + throw std::runtime_error("Unable to create a callback object for a new incoming HTTP/2 session"); + } + std::unique_ptr callbacks(cbs, nghttp2_session_callbacks_del); + cbs = nullptr; + + nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback); + nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback); + nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback); + nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback); + nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback); + nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback); + nghttp2_session_callbacks_set_error_callback2(callbacks.get(), on_error_callback); + + nghttp2_session* sess = nullptr; + if (nghttp2_session_server_new(&sess, callbacks.get(), this) != 0) { + throw std::runtime_error("Coult not allocate a new incoming HTTP/2 session"); + } + + d_session = std::unique_ptr(sess, nghttp2_session_del); + sess = nullptr; +} + +bool IncomingHTTP2Connection::checkALPN() +{ + constexpr std::array h2{'h', '2'}; + auto protocols = d_handler.getNextProtocol(); + if (protocols.size() == h2.size() && memcmp(protocols.data(), h2.data(), h2.size()) == 0) { + return true; + } + vinfolog("DoH connection from %s 
expected ALPN value 'h2', got '%s'", d_ci.remote.toStringWithPort(), std::string(protocols.begin(), protocols.end())); + return false; +} + +void IncomingHTTP2Connection::handleConnectionReady() +{ + constexpr std::array iv{{{NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100U}}}; + auto ret = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, iv.data(), iv.size()); + if (ret != 0) { + throw std::runtime_error("Fatal error: " + std::string(nghttp2_strerror(ret))); + } + ret = nghttp2_session_send(d_session.get()); + if (ret != 0) { + throw std::runtime_error("Fatal error: " + std::string(nghttp2_strerror(ret))); + } +} + +void IncomingHTTP2Connection::handleIO() +{ + IOState iostate = IOState::Done; + struct timeval now; + gettimeofday(&now, nullptr); + + try { + if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { + vinfolog("Terminating DoH connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort()); + stopIO(); + d_connectionDied = true; + return; + } + + if (d_state == State::doingHandshake) { + iostate = d_handler.tryHandshake(); + if (iostate == IOState::Done) { + handleHandshakeDone(now); + if (d_handler.isTLS()) { + if (!checkALPN()) { + d_connectionDied = true; + stopIO(); + return; + } + } + + if (expectProxyProtocolFrom(d_ci.remote)) { + d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader; + d_buffer.resize(s_proxyProtocolMinimumHeaderSize); + d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; + } + else { + d_state = State::waitingForQuery; + handleConnectionReady(); + } + } + } + + if (d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) { + auto status = handleProxyProtocolPayload(); + if (status == ProxyProtocolResult::Done) { + d_currentPos = 0; + d_proxyProtocolNeed = 0; + d_buffer.clear(); + d_state = State::waitingForQuery; + handleConnectionReady(); + } + else if (status == ProxyProtocolResult::Error) { + d_connectionDied = true; + 
stopIO(); + return; + } + } + + if (d_state == State::waitingForQuery) { + readHTTPData(); + } + + if (!d_connectionDied) { + auto shared = std::dynamic_pointer_cast(shared_from_this()); + if (nghttp2_session_want_read(d_session.get())) { + d_ioState->add(IOState::NeedRead, &handleReadableIOCallback, shared, boost::none); + } + if (nghttp2_session_want_write(d_session.get())) { + d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, shared, boost::none); + } + } + } + catch (const std::exception& e) { + vinfolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what()); + d_connectionDied = true; + stopIO(); + } +} + +ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data) +{ + IncomingHTTP2Connection* conn = reinterpret_cast(user_data); + bool bufferWasEmpty = conn->d_out.empty(); + conn->d_out.insert(conn->d_out.end(), data, data + length); + + if (bufferWasEmpty) { + try { + auto state = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size()); + if (state == IOState::Done) { + conn->d_out.clear(); + conn->d_outPos = 0; + if (!conn->isIdle()) { + conn->updateIO(IOState::NeedRead, handleReadableIOCallback); + } + else { + conn->watchForRemoteHostClosingConnection(); + } + } + else { + conn->updateIO(state, handleWritableIOCallback); + } + } + catch (const std::exception& e) { + vinfolog("Exception while trying to write (send) to incoming HTTP connection: %s", e.what()); + conn->handleIOError(); + } + } + + return length; +} + +static const std::unordered_map s_constants{ + {"200-value", "200"}, + {"method-name", ":method"}, + {"method-value", "POST"}, + {"scheme-name", ":scheme"}, + {"scheme-value", "https"}, + {"authority-name", ":authority"}, + {"x-forwarded-for-name", "x-forwarded-for"}, + {"path-name", ":path"}, + {"content-length-name", "content-length"}, + {"status-name", ":status"}, + 
{"location-name", "location"}, + {"accept-name", "accept"}, + {"accept-value", "application/dns-message"}, + {"cache-control-name", "cache-control"}, + {"content-type-name", "content-type"}, + {"content-type-value", "application/dns-message"}, + {"user-agent-name", "user-agent"}, + {"user-agent-value", "nghttp2-" NGHTTP2_VERSION "/dnsdist"}, + {"x-forwarded-port-name", "x-forwarded-port"}, + {"x-forwarded-proto-name", "x-forwarded-proto"}, + {"x-forwarded-proto-value-dns-over-udp", "dns-over-udp"}, + {"x-forwarded-proto-value-dns-over-tcp", "dns-over-tcp"}, + {"x-forwarded-proto-value-dns-over-tls", "dns-over-tls"}, + {"x-forwarded-proto-value-dns-over-http", "dns-over-http"}, + {"x-forwarded-proto-value-dns-over-https", "dns-over-https"}, +}; + +static const std::string s_authorityHeaderName(":authority"); +static const std::string s_pathHeaderName(":path"); +static const std::string s_methodHeaderName(":method"); +static const std::string s_schemeHeaderName(":scheme"); +static const std::string s_xForwardedForHeaderName("x-forwarded-for"); + +void NGHTTP2Headers::addStaticHeader(std::vector& headers, const std::string& nameKey, const std::string& valueKey) +{ + const auto& name = s_constants.at(nameKey); + const auto& value = s_constants.at(valueKey); + + headers.push_back({const_cast(reinterpret_cast(name.c_str())), const_cast(reinterpret_cast(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE}); +} + +void NGHTTP2Headers::addCustomDynamicHeader(std::vector& headers, const std::string& name, const std::string_view& value) +{ + headers.push_back({const_cast(reinterpret_cast(name.data())), const_cast(reinterpret_cast(value.data())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE}); +} + +void NGHTTP2Headers::addDynamicHeader(std::vector& headers, const std::string& nameKey, const std::string_view& value) +{ + const auto& name = s_constants.at(nameKey); + 
NGHTTP2Headers::addCustomDynamicHeader(headers, name, value); +} + +IOState IncomingHTTP2Connection::sendResponse(const struct timeval& now, TCPResponse&& response) +{ + assert(response.d_idstate.d_streamID != -1); + auto& context = d_currentStreams.at(response.d_idstate.d_streamID); + + uint32_t statusCode = 200U; + std::string contentType; + bool sendContentType = true; + auto& responseBuffer = context.d_buffer; + if (context.d_statusCode != 0) { + responseBuffer = std::move(context.d_response); + statusCode = context.d_statusCode; + contentType = std::move(context.d_contentTypeOut); + } + else { + responseBuffer = std::move(response.d_buffer); + } + + sendResponse(response.d_idstate.d_streamID, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType); + handleResponseSent(response); + + return IOState::Done; +} + +void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPResponse&& response) +{ + if (std::this_thread::get_id() != d_creatorThreadID) { + /* empty buffer will signal an IO error */ + response.d_buffer.clear(); + handleCrossProtocolResponse(now, std::move(response)); + return; + } + + assert(response.d_idstate.d_streamID != -1); + d_currentStreams.at(response.d_idstate.d_streamID).d_buffer = std::move(response.d_buffer); + sendResponse(response.d_idstate.d_streamID, 502, d_ci.cs->dohFrontend->d_customResponseHeaders); +} + +bool IncomingHTTP2Connection::sendResponse(IncomingHTTP2Connection::StreamID streamID, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType, bool addContentType) +{ + /* if data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. 
If data_prd is NULL, HEADERS have END_STREAM set. + */ + nghttp2_data_provider data_provider; + + data_provider.source.ptr = this; + data_provider.read_callback = [](nghttp2_session*, IncomingHTTP2Connection::StreamID stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* cb_data) -> ssize_t { + auto connection = reinterpret_cast(cb_data); + auto& obj = connection->d_currentStreams.at(stream_id); + size_t toCopy = 0; + if (obj.d_queryPos < obj.d_buffer.size()) { + size_t remaining = obj.d_buffer.size() - obj.d_queryPos; + toCopy = length > remaining ? remaining : length; + memcpy(buf, &obj.d_buffer.at(obj.d_queryPos), toCopy); + obj.d_queryPos += toCopy; + } + + if (obj.d_queryPos >= obj.d_buffer.size()) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return toCopy; + }; + + const auto& df = d_ci.cs->dohFrontend; + auto& responseBody = d_currentStreams.at(streamID).d_buffer; + + std::vector headers; + std::string responseCodeStr; + std::string cacheControlValue; + std::string location; + /* remember that dynamic header values should be kept alive + until we have called nghttp2_submit_response(), at least */ + + if (responseCode == 200) { + NGHTTP2Headers::addStaticHeader(headers, "status-name", "200-value"); + ++df->d_validresponses; + ++df->d_http2Stats.d_nb200Responses; + + if (addContentType) { + if (contentType.empty()) { + NGHTTP2Headers::addStaticHeader(headers, "content-type-name", "content-type-value"); + } + else { + NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", contentType); + } + } + + if (df->d_sendCacheControlHeaders && responseBody.size() > sizeof(dnsheader)) { + uint32_t minTTL = getDNSPacketMinTTL(reinterpret_cast(responseBody.data()), responseBody.size()); + if (minTTL != std::numeric_limits::max()) { + cacheControlValue = "max-age=" + std::to_string(minTTL); + NGHTTP2Headers::addDynamicHeader(headers, "cache-control-name", cacheControlValue); + } + } + } + else { + responseCodeStr = 
std::to_string(responseCode); + NGHTTP2Headers::addDynamicHeader(headers, "status-name", responseCodeStr); + + if (responseCode >= 300 && responseCode < 400) { + location = std::string(reinterpret_cast(responseBody.data()), responseBody.size()); + NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", "text/html; charset=utf-8"); + NGHTTP2Headers::addDynamicHeader(headers, "location-name", location); + static const std::string s_redirectStart{"Moved

The document has moved here"}; + responseBody.reserve(s_redirectStart.size() + responseBody.size() + s_redirectEnd.size()); + responseBody.insert(responseBody.begin(), s_redirectStart.begin(), s_redirectStart.end()); + responseBody.insert(responseBody.end(), s_redirectEnd.begin(), s_redirectEnd.end()); + ++df->d_redirectresponses; + } + else { + ++df->d_errorresponses; + switch (responseCode) { + case 400: + ++df->d_http2Stats.d_nb400Responses; + break; + case 403: + ++df->d_http2Stats.d_nb403Responses; + break; + case 500: + ++df->d_http2Stats.d_nb500Responses; + break; + case 502: + ++df->d_http2Stats.d_nb502Responses; + break; + default: + ++df->d_http2Stats.d_nbOtherResponses; + break; + } + + if (!responseBody.empty()) { + NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", "text/plain; charset=utf-8"); + } + else { + static const std::string invalid{"invalid DNS query"}; + static const std::string notAllowed{"dns query not allowed"}; + static const std::string noDownstream{"no downstream server available"}; + static const std::string internalServerError{"Internal Server Error"}; + + switch (responseCode) { + case 400: + responseBody.insert(responseBody.begin(), invalid.begin(), invalid.end()); + break; + case 403: + responseBody.insert(responseBody.begin(), notAllowed.begin(), notAllowed.end()); + break; + case 502: + responseBody.insert(responseBody.begin(), noDownstream.begin(), noDownstream.end()); + break; + case 500: + /* fall-through */ + default: + responseBody.insert(responseBody.begin(), internalServerError.begin(), internalServerError.end()); + break; + } + } + } + } + + const std::string contentLength = std::to_string(responseBody.size()); + NGHTTP2Headers::addDynamicHeader(headers, "content-length-name", contentLength); + + for (const auto& [key, value] : customResponseHeaders) { + NGHTTP2Headers::addCustomDynamicHeader(headers, key, value); + } + + auto ret = nghttp2_submit_response(d_session.get(), streamID, headers.data(), 
headers.size(), &data_provider); + if (ret != 0) { + d_currentStreams.erase(streamID); + vinfolog("Error submitting HTTP response for stream %d: %s", streamID, nghttp2_strerror(ret)); + return false; + } + + ret = nghttp2_session_send(d_session.get()); + if (ret != 0) { + d_currentStreams.erase(streamID); + vinfolog("Error flushing HTTP response for stream %d: %s", streamID, nghttp2_strerror(ret)); + return false; + } + + return true; +} + +static void processForwardedForHeader(const std::unique_ptr& headers, ComboAddress& remote) +{ + if (!headers) { + return; + } + + auto it = headers->find(s_xForwardedForHeaderName); + if (it == headers->end()) { + return; + } + + std::string_view value = it->second; + try { + auto pos = value.rfind(','); + if (pos != std::string_view::npos) { + ++pos; + for (; pos < value.size() && value[pos] == ' '; ++pos) { + } + + if (pos < value.size()) { + value = value.substr(pos); + } + } + auto newRemote = ComboAddress(std::string(value)); + remote = newRemote; + } + catch (const std::exception& e) { + vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.what()); + } + catch (const PDNSException& e) { + vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.reason); + } +} + +static std::optional getPayloadFromPath(const std::string_view& path) +{ + std::optional result{std::nullopt}; + + if (path.size() <= 5) { + return result; + } + + auto pos = path.find("?dns="); + if (pos == string::npos) { + pos = path.find("&dns="); + } + + if (pos == string::npos) { + return result; + } + + // need to base64url decode this + string sdns(path.substr(pos + 5)); + boost::replace_all(sdns, "-", "+"); + boost::replace_all(sdns, "_", "/"); + + // re-add padding that may have been missing + switch (sdns.size() % 4) { + case 2: + sdns.append(2, '='); + break; + case 3: + sdns.append(1, '='); + break; + } + + PacketBuffer 
decoded; + /* rough estimate so we hopefully don't need a new allocation later */ + /* We reserve at few additional bytes to be able to add EDNS later */ + const size_t estimate = ((sdns.size() * 3) / 4); + decoded.reserve(estimate); + if (B64Decode(sdns, decoded) < 0) { + return result; + } + + result = std::move(decoded); + return result; +} + +void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::PendingQuery&& query, IncomingHTTP2Connection::StreamID streamID) +{ + const auto handleImmediateResponse = [this, &query, streamID](uint16_t code, const std::string& reason, PacketBuffer&& response = PacketBuffer()) { + if (response.empty()) { + query.d_buffer.clear(); + query.d_buffer.insert(query.d_buffer.begin(), reason.begin(), reason.end()); + } + else { + query.d_buffer = std::move(response); + } + vinfolog("Sending an immediate %d response to incoming DoH query: %s", code, reason); + sendResponse(streamID, code, d_ci.cs->dohFrontend->d_customResponseHeaders); + }; + + ++d_ci.cs->dohFrontend->d_http2Stats.d_nbQueries; + + if (d_ci.cs->dohFrontend->d_trustForwardedForHeader) { + processForwardedForHeader(query.d_headers, d_proxiedRemote); + + /* second ACL lookup based on the updated address */ + auto& holders = d_threadData.holders; + if (!holders.acl->match(d_proxiedRemote)) { + ++dnsdist::metrics::g_stats.aclDrops; + vinfolog("Query from %s (%s) (DoH) dropped because of ACL", d_ci.remote.toStringWithPort(), d_proxiedRemote.toStringWithPort()); + handleImmediateResponse(403, "DoH query not allowed because of ACL"); + return; + } + + if (!d_ci.cs->dohFrontend->d_keepIncomingHeaders) { + query.d_headers.reset(); + } + } + + if (d_ci.cs->dohFrontend->d_exactPathMatching) { + if (d_ci.cs->dohFrontend->d_urls.count(query.d_path) == 0) { + handleImmediateResponse(404, "there is no endpoint configured for this path"); + return; + } + } + else { + bool found = false; + for (const auto& path : d_ci.cs->dohFrontend->d_urls) { + if 
(boost::starts_with(query.d_path, path)) { + found = true; + break; + } + } + if (!found) { + handleImmediateResponse(404, "there is no endpoint configured for this path"); + return; + } + } + + /* the responses map can be updated at runtime, so we need to take a copy of + the shared pointer, increasing the reference counter */ + auto responsesMap = d_ci.cs->dohFrontend->d_responsesMap; + if (responsesMap) { + for (const auto& entry : *responsesMap) { + if (entry->matches(query.d_path)) { + const auto& customHeaders = entry->getHeaders(); + query.d_buffer = entry->getContent(); + if (entry->getStatusCode() >= 400 && query.d_buffer.size() >= 1) { + // legacy trailing 0 from the h2o era + query.d_buffer.pop_back(); + } + + sendResponse(streamID, entry->getStatusCode(), customHeaders ? *customHeaders : d_ci.cs->dohFrontend->d_customResponseHeaders, std::string(), false); + return; + } + } + } + + if (query.d_buffer.empty() && query.d_method == PendingQuery::Method::Get && !query.d_queryString.empty()) { + auto payload = getPayloadFromPath(query.d_queryString); + if (payload) { + query.d_buffer = std::move(*payload); + } + else { + ++d_ci.cs->dohFrontend->d_badrequests; + handleImmediateResponse(400, "DoH unable to decode BASE64-URL"); + return; + } + } + + if (query.d_method == PendingQuery::Method::Get) { + ++d_ci.cs->dohFrontend->d_getqueries; + } + else if (query.d_method == PendingQuery::Method::Post) { + ++d_ci.cs->dohFrontend->d_postqueries; + } + + try { + struct timeval now; + gettimeofday(&now, nullptr); + auto processingResult = handleQuery(std::move(query.d_buffer), now, streamID); + + switch (processingResult) { + case QueryProcessingResult::TooSmall: + handleImmediateResponse(400, "DoH non-compliant query"); + break; + case QueryProcessingResult::InvalidHeaders: + handleImmediateResponse(400, "DoH invalid headers"); + break; + case QueryProcessingResult::Empty: + handleImmediateResponse(200, "DoH empty query", std::move(query.d_buffer)); + break; + case 
QueryProcessingResult::Dropped: + handleImmediateResponse(403, "DoH dropped query"); + break; + case QueryProcessingResult::NoBackend: + handleImmediateResponse(502, "DoH no backend available"); + return; + case QueryProcessingResult::Forwarded: + case QueryProcessingResult::Asynchronous: + case QueryProcessingResult::SelfAnswered: + break; + } + } + catch (const std::exception& e) { + vinfolog("Exception while processing DoH query: %s", e.what()); + handleImmediateResponse(400, "DoH non-compliant query"); + return; + } +} + +int IncomingHTTP2Connection::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) +{ + IncomingHTTP2Connection* conn = reinterpret_cast(user_data); +#if 0 + switch (frame->hd.type) { + case NGHTTP2_HEADERS: + cerr<<"got headers"<headers.cat == NGHTTP2_HCAT_RESPONSE) { + cerr<<"All headers received"<headers.cat == NGHTTP2_HCAT_REQUEST) { + cerr<<"All headers received - query"<settings.niv<settings.niv; idx++) { + cerr<<"- "<settings.iv[idx].settings_id<<" "<settings.iv[idx].value<hd.type == NGHTTP2_GOAWAY) { + conn->stopIO(); + if (conn->isIdle()) { + if (nghttp2_session_want_write(conn->d_session.get())) { + conn->d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, conn, boost::none); + } + } + } + + /* is this the last frame for this stream? 
*/ + else if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { + auto streamID = frame->hd.stream_id; + auto stream = conn->d_currentStreams.find(streamID); + if (stream != conn->d_currentStreams.end()) { + conn->handleIncomingQuery(std::move(stream->second), streamID); + + if (conn->isIdle()) { + conn->watchForRemoteHostClosingConnection(); + } + } + else { + vinfolog("Stream %d NOT FOUND", streamID); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + } + + return 0; +} + +int IncomingHTTP2Connection::on_stream_close_callback(nghttp2_session* session, IncomingHTTP2Connection::StreamID stream_id, uint32_t error_code, void* user_data) +{ + IncomingHTTP2Connection* conn = reinterpret_cast(user_data); + + if (error_code == 0) { + return 0; + } + + auto stream = conn->d_currentStreams.find(stream_id); + if (stream == conn->d_currentStreams.end()) { + /* we don't care, then */ + return 0; + } + + struct timeval now; + gettimeofday(&now, nullptr); + auto request = std::move(stream->second); + conn->d_currentStreams.erase(stream->first); + + if (conn->isIdle()) { + conn->watchForRemoteHostClosingConnection(); + } + + return 0; +} + +int IncomingHTTP2Connection::on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) +{ + if (frame->hd.type != NGHTTP2_HEADERS || frame->headers.cat != NGHTTP2_HCAT_REQUEST) { + return 0; + } + + IncomingHTTP2Connection* conn = reinterpret_cast(user_data); + auto insertPair = conn->d_currentStreams.insert({frame->hd.stream_id, PendingQuery()}); + if (!insertPair.second) { + /* there is a stream ID collision, something is very wrong! 
*/ + vinfolog("Stream ID collision (%d) on connection from %d", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort()); + conn->d_connectionDied = true; + nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR); + auto ret = nghttp2_session_send(conn->d_session.get()); + if (ret != 0) { + vinfolog("Error flushing HTTP response for stream %d from %s: %s", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret)); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + + return 0; + } + + return 0; +} + +static std::string::size_type getLengthOfPathWithoutParameters(const std::string_view& path) +{ + auto pos = path.find("?"); + if (pos == string::npos) { + return path.size(); + } + + return pos; +} + +int IncomingHTTP2Connection::on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t nameLen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data) +{ + IncomingHTTP2Connection* conn = reinterpret_cast(user_data); + + if (frame->hd.type == NGHTTP2_HEADERS && frame->headers.cat == NGHTTP2_HCAT_REQUEST) { + if (nghttp2_check_header_name(name, nameLen) == 0) { + vinfolog("Invalid header name"); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + +#if HAVE_NGHTTP2_CHECK_HEADER_VALUE_RFC9113 + if (nghttp2_check_header_value_rfc9113(value, valuelen) == 0) { + vinfolog("Invalid header value"); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } +#endif /* HAVE_NGHTTP2_CHECK_HEADER_VALUE_RFC9113 */ + + auto headerMatches = [name, nameLen](const std::string& expected) -> bool { + return nameLen == expected.size() && memcmp(name, expected.data(), expected.size()) == 0; + }; + + auto stream = conn->d_currentStreams.find(frame->hd.stream_id); + if (stream == conn->d_currentStreams.end()) { + vinfolog("Unable to match the stream ID %d to a known one!", frame->hd.stream_id); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + auto& query = stream->second; + auto valueView = 
std::string_view(reinterpret_cast(value), valuelen); + if (headerMatches(s_pathHeaderName)) { +#if HAVE_NGHTTP2_CHECK_PATH + if (nghttp2_check_path(value, valuelen) == 0) { + vinfolog("Invalid path value"); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } +#endif /* HAVE_NGHTTP2_CHECK_PATH */ + + auto pathLen = getLengthOfPathWithoutParameters(valueView); + query.d_path = valueView.substr(0, pathLen); + if (pathLen < valueView.size()) { + query.d_queryString = valueView.substr(pathLen); + } + } + else if (headerMatches(s_authorityHeaderName)) { + query.d_host = valueView; + } + else if (headerMatches(s_schemeHeaderName)) { + query.d_scheme = valueView; + } + else if (headerMatches(s_methodHeaderName)) { +#if HAVE_NGHTTP2_CHECK_METHOD + if (nghttp2_check_method(value, valuelen) == 0) { + vinfolog("Invalid method value"); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } +#endif /* HAVE_NGHTTP2_CHECK_METHOD */ + + if (valueView == "GET") { + query.d_method = PendingQuery::Method::Get; + } + else if (valueView == "POST") { + query.d_method = PendingQuery::Method::Post; + } + else { + vinfolog("Unsupported method value"); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + } + + if (conn->d_ci.cs->dohFrontend->d_keepIncomingHeaders || (conn->d_ci.cs->dohFrontend->d_trustForwardedForHeader && headerMatches(s_xForwardedForHeaderName))) { + if (!query.d_headers) { + query.d_headers = std::make_unique(); + } + query.d_headers->insert({std::string(reinterpret_cast(name), nameLen), std::string(valueView)}); + } + } + return 0; +} + +int IncomingHTTP2Connection::on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, IncomingHTTP2Connection::StreamID stream_id, const uint8_t* data, size_t len, void* user_data) +{ + IncomingHTTP2Connection* conn = reinterpret_cast(user_data); + auto stream = conn->d_currentStreams.find(stream_id); + if (stream == conn->d_currentStreams.end()) { + vinfolog("Unable to match the stream ID %d to a known one!", stream_id); + return 
NGHTTP2_ERR_CALLBACK_FAILURE; + } + if (len > std::numeric_limits::max() || (std::numeric_limits::max() - stream->second.d_buffer.size()) < len) { + vinfolog("Data frame of size %d is too large for a DNS query (we already have %d)", len, stream->second.d_buffer.size()); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + + stream->second.d_buffer.insert(stream->second.d_buffer.end(), data, data + len); + + return 0; +} + +int IncomingHTTP2Connection::on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data) +{ + IncomingHTTP2Connection* conn = reinterpret_cast(user_data); + + vinfolog("Error in HTTP/2 connection from %d: %s", conn->d_ci.remote.toStringWithPort(), std::string(msg, len)); + conn->d_connectionDied = true; + nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR); + auto ret = nghttp2_session_send(conn->d_session.get()); + if (ret != 0) { + vinfolog("Error flushing HTTP response on connection from %s: %s", conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret)); + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + + return 0; +} + +void IncomingHTTP2Connection::readHTTPData() +{ + IOStateGuard ioGuard(d_ioState); + do { + size_t got = 0; + d_in.resize(d_in.size() + 512); + try { + IOState newState = d_handler.tryRead(d_in, got, d_in.size(), true); + d_in.resize(got); + + if (got > 0) { + /* we got something */ + auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size()); + /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB, + all data should be consumed before returning */ + if (readlen < 0 || static_cast(readlen) < d_in.size()) { + throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen))); + } + + nghttp2_session_send(d_session.get()); + } + + if (newState == IOState::Done) { + if (isIdle()) { + watchForRemoteHostClosingConnection(); + ioGuard.release(); + 
break; + } + } + else { + if (newState == IOState::NeedWrite) { + updateIO(IOState::NeedWrite, handleReadableIOCallback); + } + ioGuard.release(); + break; + } + } + catch (const std::exception& e) { + vinfolog("Exception while trying to read from HTTP backend connection: %s", e.what()); + handleIOError(); + break; + } + } while (getConcurrentStreamsCount() > 0); +} + +void IncomingHTTP2Connection::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param) +{ + auto conn = boost::any_cast>(param); + conn->handleIO(); +} + +void IncomingHTTP2Connection::handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param) +{ + auto conn = boost::any_cast>(param); + IOStateGuard ioGuard(conn->d_ioState); + + try { + IOState newState = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size()); + if (newState == IOState::NeedRead) { + conn->updateIO(IOState::NeedRead, handleWritableIOCallback); + } + else if (newState == IOState::Done) { + conn->d_out.clear(); + conn->d_outPos = 0; + if (!conn->isIdle()) { + conn->updateIO(IOState::NeedRead, handleReadableIOCallback); + } + else { + conn->watchForRemoteHostClosingConnection(); + } + } + ioGuard.release(); + } + catch (const std::exception& e) { + vinfolog("Exception while trying to write (ready) to HTTP backend connection: %s", e.what()); + conn->handleIOError(); + } +} + +bool IncomingHTTP2Connection::isIdle() const +{ + return getConcurrentStreamsCount() == 0; +} + +void IncomingHTTP2Connection::stopIO() +{ + d_ioState->reset(); +} + +uint32_t IncomingHTTP2Connection::getConcurrentStreamsCount() const +{ + return d_currentStreams.size(); +} + +boost::optional IncomingHTTP2Connection::getIdleClientReadTTD(struct timeval now) const +{ + auto idleTimeout = d_ci.cs->dohFrontend->d_idleTimeout; + if (g_maxTCPConnectionDuration == 0 && idleTimeout == 0) { + return boost::none; + } + + if (g_maxTCPConnectionDuration > 0) { + auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec; + if (elapsed < 0 
|| (static_cast(elapsed) >= g_maxTCPConnectionDuration)) { + return now; + } + auto remaining = g_maxTCPConnectionDuration - elapsed; + if (idleTimeout == 0 || remaining <= static_cast(idleTimeout)) { + now.tv_sec += remaining; + return now; + } + } + + now.tv_sec += idleTimeout; + return now; +} + +void IncomingHTTP2Connection::updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback) +{ + boost::optional ttd{boost::none}; + + auto shared = std::dynamic_pointer_cast(shared_from_this()); + if (shared) { + struct timeval now; + gettimeofday(&now, nullptr); + + if (newState == IOState::NeedRead) { + if (isIdle()) { + ttd = getIdleClientReadTTD(now); + } + else { + ttd = getClientReadTTD(now); + } + d_ioState->update(newState, callback, shared, ttd); + } + else if (newState == IOState::NeedWrite) { + ttd = getClientWriteTTD(now); + d_ioState->update(newState, callback, shared, ttd); + } + } +} + +void IncomingHTTP2Connection::watchForRemoteHostClosingConnection() +{ + updateIO(IOState::NeedRead, handleReadableIOCallback); +} + +void IncomingHTTP2Connection::handleIOError() +{ + d_connectionDied = true; + nghttp2_session_terminate_session(d_session.get(), NGHTTP2_PROTOCOL_ERROR); + d_currentStreams.clear(); + stopIO(); +} +#endif /* HAVE_NGHTTP2 */ diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.hh b/pdns/dnsdistdist/dnsdist-nghttp2-in.hh new file mode 100644 index 0000000000..3ee1c96d10 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.hh @@ -0,0 +1,114 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include "config.h" +#ifdef HAVE_NGHTTP2 +#include + +#include "dnsdist-tcp-upstream.hh" + +class IncomingHTTP2Connection : public IncomingTCPConnectionState +{ +public: + using StreamID = int32_t; + + class PendingQuery + { + public: + enum class Method : uint8_t + { + Unknown, + Get, + Post + }; + + PacketBuffer d_buffer; + PacketBuffer d_response; + std::string d_path; + std::string d_scheme; + std::string d_host; + std::string d_queryString; + std::string d_sni; + std::string d_contentTypeOut; + std::unique_ptr d_headers; + size_t d_queryPos{0}; + uint32_t d_statusCode{0}; + Method d_method{Method::Unknown}; + }; + + IncomingHTTP2Connection(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now); + ~IncomingHTTP2Connection() = default; + void handleIO() override; + void handleResponse(const struct timeval& now, TCPResponse&& response) override; + void notifyIOError(const struct timeval& now, TCPResponse&& response) override; + void restoreContext(uint32_t streamID, PendingQuery&& context); + +private: + static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data); + static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data); + static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, StreamID stream_id, const uint8_t* data, size_t len, void* user_data); + static int on_stream_close_callback(nghttp2_session* session, StreamID 
stream_id, uint32_t error_code, void* user_data); + static int on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data); + static int on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data); + static int on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data); + static void handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param); + static void handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param); + + IOState sendResponse(const struct timeval& now, TCPResponse&& response) override; + bool forwardViaUDPFirst() const override + { + return true; + } + void restoreDOHUnit(std::unique_ptr&&) override; + std::unique_ptr getDOHUnit(uint32_t streamID) override; + + void stopIO(); + bool isIdle() const; + uint32_t getConcurrentStreamsCount() const; + void updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback); + void watchForRemoteHostClosingConnection(); + void handleIOError(); + bool sendResponse(StreamID streamID, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType = "", bool addContentType = true); + void handleIncomingQuery(PendingQuery&& query, StreamID streamID); + bool checkALPN(); + void readHTTPData(); + void handleConnectionReady(); + boost::optional getIdleClientReadTTD(struct timeval now) const; + + std::unique_ptr d_session{nullptr, nghttp2_session_del}; + std::unordered_map d_currentStreams; + PacketBuffer d_out; + PacketBuffer d_in; + size_t d_outPos{0}; + bool d_connectionDied{false}; +}; + +class NGHTTP2Headers +{ +public: + static void addStaticHeader(std::vector& headers, const std::string& nameKey, const std::string& valueKey); + static void addDynamicHeader(std::vector& headers, const std::string& nameKey, const std::string_view& value); + static void 
addCustomDynamicHeader(std::vector& headers, const std::string& name, const std::string_view& value); +}; + +#endif /* HAVE_NGHTTP2 */ diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 39e60009e0..692b732757 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -27,6 +27,7 @@ #endif /* HAVE_NGHTTP2 */ #include "dnsdist-nghttp2.hh" +#include "dnsdist-nghttp2-in.hh" #include "dnsdist-tcp.hh" #include "dnsdist-tcp-downstream.hh" #include "dnsdist-downstream-connection.hh" @@ -153,7 +154,11 @@ void DoHConnectionToBackend::handleResponse(PendingRequest&& request) } } - request.d_sender->handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this(), d_ds)); + TCPResponse response(std::move(request.d_query)); + response.d_buffer = std::move(request.d_buffer); + response.d_connection = shared_from_this(); + response.d_ds = d_ds; + request.d_sender->handleResponse(now, std::move(response)); } catch (const std::exception& e) { vinfolog("Got exception while handling response for cross-protocol DoH: %s", e.what()); @@ -167,7 +172,8 @@ void DoHConnectionToBackend::handleResponseError(PendingRequest&& request, const d_ds->reportTimeoutOrError(); } - request.d_sender->notifyIOError(std::move(request.d_query.d_idstate), now); + TCPResponse response(PacketBuffer(), std::move(request.d_query.d_idstate), nullptr, nullptr); + request.d_sender->notifyIOError(now, std::move(response)); } catch (const std::exception& e) { vinfolog("Got exception while handling response for cross-protocol DoH: %s", e.what()); @@ -230,45 +236,6 @@ bool DoHConnectionToBackend::isIdle() const return getConcurrentStreamsCount() == 0; } -const std::unordered_map DoHConnectionToBackend::s_constants = { - {"method-name", ":method"}, - {"method-value", "POST"}, - {"scheme-name", ":scheme"}, - {"scheme-value", "https"}, - {"accept-name", "accept"}, - {"accept-value", 
"application/dns-message"}, - {"content-type-name", "content-type"}, - {"content-type-value", "application/dns-message"}, - {"user-agent-name", "user-agent"}, - {"user-agent-value", "nghttp2-" NGHTTP2_VERSION "/dnsdist"}, - {"authority-name", ":authority"}, - {"path-name", ":path"}, - {"content-length-name", "content-length"}, - {"x-forwarded-for-name", "x-forwarded-for"}, - {"x-forwarded-port-name", "x-forwarded-port"}, - {"x-forwarded-proto-name", "x-forwarded-proto"}, - {"x-forwarded-proto-value-dns-over-udp", "dns-over-udp"}, - {"x-forwarded-proto-value-dns-over-tcp", "dns-over-tcp"}, - {"x-forwarded-proto-value-dns-over-tls", "dns-over-tls"}, - {"x-forwarded-proto-value-dns-over-http", "dns-over-http"}, - {"x-forwarded-proto-value-dns-over-https", "dns-over-https"}, -}; - -void DoHConnectionToBackend::addStaticHeader(std::vector& headers, const std::string& nameKey, const std::string& valueKey) -{ - const auto& name = s_constants.at(nameKey); - const auto& value = s_constants.at(valueKey); - - headers.push_back({const_cast(reinterpret_cast(name.c_str())), const_cast(reinterpret_cast(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE}); -} - -void DoHConnectionToBackend::addDynamicHeader(std::vector& headers, const std::string& nameKey, const std::string& value) -{ - const auto& name = s_constants.at(nameKey); - - headers.push_back({const_cast(reinterpret_cast(name.c_str())), const_cast(reinterpret_cast(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE}); -} - void DoHConnectionToBackend::queueQuery(std::shared_ptr& sender, TCPQuery&& query) { auto payloadSize = std::to_string(query.d_buffer.size()); @@ -284,37 +251,37 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr& sender, headers.reserve(8 + (addXForwarded ? 
3 : 0)); /* Pseudo-headers need to come first (rfc7540 8.1.2.1) */ - addStaticHeader(headers, "method-name", "method-value"); - addStaticHeader(headers, "scheme-name", "scheme-value"); - addDynamicHeader(headers, "authority-name", d_ds->d_config.d_tlsSubjectName); - addDynamicHeader(headers, "path-name", d_ds->d_config.d_dohPath); - addStaticHeader(headers, "accept-name", "accept-value"); - addStaticHeader(headers, "content-type-name", "content-type-value"); - addStaticHeader(headers, "user-agent-name", "user-agent-value"); - addDynamicHeader(headers, "content-length-name", payloadSize); + NGHTTP2Headers::addStaticHeader(headers, "method-name", "method-value"); + NGHTTP2Headers::addStaticHeader(headers, "scheme-name", "scheme-value"); + NGHTTP2Headers::addDynamicHeader(headers, "authority-name", d_ds->d_config.d_tlsSubjectName); + NGHTTP2Headers::addDynamicHeader(headers, "path-name", d_ds->d_config.d_dohPath); + NGHTTP2Headers::addStaticHeader(headers, "accept-name", "accept-value"); + NGHTTP2Headers::addStaticHeader(headers, "content-type-name", "content-type-value"); + NGHTTP2Headers::addStaticHeader(headers, "user-agent-name", "user-agent-value"); + NGHTTP2Headers::addDynamicHeader(headers, "content-length-name", payloadSize); /* no need to add these headers for health-check queries */ if (addXForwarded && query.d_idstate.origRemote.getPort() != 0) { remote = query.d_idstate.origRemote.toString(); remotePort = std::to_string(query.d_idstate.origRemote.getPort()); - addDynamicHeader(headers, "x-forwarded-for-name", remote); - addDynamicHeader(headers, "x-forwarded-port-name", remotePort); + NGHTTP2Headers::addDynamicHeader(headers, "x-forwarded-for-name", remote); + NGHTTP2Headers::addDynamicHeader(headers, "x-forwarded-port-name", remotePort); if (query.d_idstate.cs != nullptr) { if (query.d_idstate.cs->isUDP()) { - addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-udp"); + NGHTTP2Headers::addStaticHeader(headers, 
"x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-udp"); } else if (query.d_idstate.cs->isDoH()) { if (query.d_idstate.cs->hasTLS()) { - addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-https"); + NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-https"); } else { - addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-http"); + NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-http"); } } else if (query.d_idstate.cs->hasTLS()) { - addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tls"); + NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tls"); } else { - addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tcp"); + NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tcp"); } } } @@ -920,7 +887,8 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par downstream->queueQuery(tqs, std::move(query)); } catch (...) { - tqs->notifyIOError(std::move(query.d_idstate), now); + TCPResponse response(std::move(query)); + tqs->notifyIOError(now, std::move(response)); } } diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index 6c6fcf2229..43de71fc58 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -173,7 +173,7 @@ static uint32_t getSerialFromRawSOAContent(const std::vector& raw) static bool getSerialFromIXFRQuery(TCPQuery& query) { try { - size_t proxyPayloadSize = query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0; + size_t proxyPayloadSize = query.d_proxyProtocolPayloadAdded ? 
query.d_idstate.d_proxyProtocolPayloadSize : 0; if (query.d_buffer.size() <= (proxyPayloadSize + sizeof(uint16_t))) { return false; } @@ -232,24 +232,24 @@ static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState quer if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) { query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end()); query.d_proxyProtocolPayloadAdded = true; - query.d_proxyProtocolPayloadAddedSize = query.d_proxyProtocolPayload.size(); + query.d_idstate.d_proxyProtocolPayloadSize = query.d_proxyProtocolPayload.size(); } } else if (connectionState == ConnectionState::proxySent) { if (query.d_proxyProtocolPayloadAdded) { - if (query.d_buffer.size() < query.d_proxyProtocolPayloadAddedSize) { + if (query.d_buffer.size() < query.d_idstate.d_proxyProtocolPayloadSize) { throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size())); } - query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayloadAddedSize); + query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_idstate.d_proxyProtocolPayloadSize); query.d_proxyProtocolPayloadAdded = false; - query.d_proxyProtocolPayloadAddedSize = 0; + query.d_idstate.d_proxyProtocolPayloadSize = 0; } } if (query.d_idstate.qclass == QClass::IN && query.d_idstate.qtype == QType::IXFR) { getSerialFromIXFRQuery(query); } - editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0, true); + editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? 
query.d_idstate.d_proxyProtocolPayloadSize : 0, true); } IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr& conn) @@ -433,7 +433,8 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c /* this one can't be restarted, sorry */ DEBUGLOG("A XFR for which a response has already been sent cannot be restarted"); try { - pending.second.d_sender->notifyIOError(std::move(pending.second.d_query.d_idstate), now); + TCPResponse response(std::move(pending.second.d_query)); + pending.second.d_sender->notifyIOError(now, std::move(response)); } catch (const std::exception& e) { vinfolog("Got an exception while notifying: %s", e.what()); @@ -608,7 +609,8 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F increaseCounters(d_currentQuery.d_query.d_idstate.cs); auto sender = d_currentQuery.d_sender; if (sender->active()) { - sender->notifyIOError(std::move(d_currentQuery.d_query.d_idstate), now); + TCPResponse response(std::move(d_currentQuery.d_query)); + sender->notifyIOError(now, std::move(response)); } } @@ -616,7 +618,8 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F increaseCounters(query.d_query.d_idstate.cs); auto sender = query.d_sender; if (sender->active()) { - sender->notifyIOError(std::move(query.d_query.d_idstate), now); + TCPResponse response(std::move(query.d_query)); + sender->notifyIOError(now, std::move(response)); } } @@ -624,7 +627,8 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F increaseCounters(response.second.d_query.d_idstate.cs); auto sender = response.second.d_sender; if (sender->active()) { - sender->notifyIOError(std::move(response.second.d_query.d_idstate), now); + TCPResponse tresp(std::move(response.second.d_query)); + sender->notifyIOError(now, std::move(tresp)); } } } @@ -726,7 +730,8 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptractive()) { DEBUGLOG("passing response to client connection for 
"<handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn, conn->d_ds)); + TCPResponse response(std::move(d_responseBuffer), std::move(ids), conn, conn->d_ds); + sender->handleResponse(now, std::move(response)); } if (!d_pendingQueries.empty()) { diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index b668c2f9eb..4318892659 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -2,6 +2,7 @@ #include "dolog.hh" #include "dnsdist-tcp.hh" +#include "dnsdist-tcp-downstream.hh" struct TCPCrossProtocolResponse; @@ -26,7 +27,10 @@ public: class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this { public: - IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id()) + enum class QueryProcessingResult : uint8_t { Forwarded, TooSmall, InvalidHeaders, Empty, Dropped, SelfAnswered, NoBackend, Asynchronous }; + enum class ProxyProtocolResult : uint8_t { Reading, Done, Error }; + + IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? 
d_ci.cs->dohFrontend->d_tlsContext.getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id()) { d_origDest.reset(); d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family; @@ -46,7 +50,7 @@ public: IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete; IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete; - ~IncomingTCPConnectionState(); + virtual ~IncomingTCPConnectionState(); void resetForNewQuery(); @@ -118,24 +122,27 @@ public: static size_t clearAllDownstreamConnections(); - static void handleIO(std::shared_ptr& conn, const struct timeval& now); static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); static void handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param); static void updateIO(std::shared_ptr& state, IOState newState, const struct timeval& now); - static IOState sendResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response); static void queueResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response); -static void handleTimeout(std::shared_ptr& state, bool write); + static void handleTimeout(std::shared_ptr& state, bool write); + + virtual void handleIO(); - /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */ - void handleResponse(const struct timeval& now, TCPResponse&& response) override; + QueryProcessingResult handleQuery(PacketBuffer&& query, const struct timeval& now, std::optional streamID); + virtual void handleResponse(const struct timeval& now, TCPResponse&& response) override; + virtual void notifyIOError(const struct timeval& now, TCPResponse&& response) override; void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override; - void notifyIOError(InternalQueryState&& 
query, const struct timeval& now) override; + virtual IOState sendResponse(const struct timeval& now, TCPResponse&& response); + void handleResponseSent(TCPResponse& currentResponse); + void handleHandshakeDone(const struct timeval& now); + ProxyProtocolResult handleProxyProtocolPayload(); void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response); void terminateClientConnection(); - void queueQuery(TCPQuery&& query); bool canAcceptNewQueries(const struct timeval& now); @@ -143,6 +150,20 @@ static void handleTimeout(std::shared_ptr& state, bo { return d_ioState != nullptr; } + virtual bool forwardViaUDPFirst() const + { + return false; + } + virtual std::unique_ptr getDOHUnit(uint32_t streamID) + { + throw std::runtime_error("Getting a DOHUnit state from a generic TCP/DoT connection is not supported"); + } + virtual void restoreDOHUnit(std::unique_ptr&&) + { + throw std::runtime_error("Restoring a DOHUnit state to a generic TCP/DoT connection is not supported"); + } + + std::unique_ptr getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr& ds); std::string toString() const { @@ -151,6 +172,8 @@ static void handleTimeout(std::shared_ptr& state, bo return o.str(); } + dnsdist::Protocol getProtocol() const; + enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ }; TCPResponse d_currentResponse; diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index d5f2edb0d1..aef6cf6ec3 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -21,6 +21,7 @@ */ #pragma once +#include #include #include "channel.hh" #include "iputils.hh" @@ -100,7 +101,6 @@ public: InternalQueryState d_idstate; std::string d_proxyProtocolPayload; PacketBuffer d_buffer; - uint32_t d_proxyProtocolPayloadAddedSize{0}; uint32_t 
d_ixfrQuerySerial{0}; uint32_t d_xfrMasterSerial{0}; uint32_t d_xfrSerialCount{0}; @@ -133,6 +133,17 @@ struct TCPResponse : public TCPQuery } } + TCPResponse(TCPQuery&& query) : + TCPQuery(std::move(query)) + { + if (d_buffer.size() >= sizeof(dnsheader)) { + memcpy(&d_cleartextDH, reinterpret_cast(d_buffer.data()), sizeof(d_cleartextDH)); + } + else { + memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); + } + } + bool isAsync() const { return d_async; @@ -154,7 +165,7 @@ public: virtual bool active() const = 0; virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0; virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0; - virtual void notifyIOError(InternalQueryState&& query, const struct timeval& now) = 0; + virtual void notifyIOError(const struct timeval& now, TCPResponse&& response) = 0; /* whether the connection should be automatically released to the pool after handleResponse() has been called */ @@ -199,7 +210,6 @@ struct CrossProtocolQuery InternalQuery query; std::shared_ptr downstream{nullptr}; - size_t proxyProtocolPayloadSize{0}; bool d_isResponse{false}; }; diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index 91dcd9ad76..a2747400f4 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -534,8 +534,9 @@ public: return handleResponse(now, std::move(response)); } - void notifyIOError(InternalQueryState&& query, const struct timeval& now) override + void notifyIOError(const struct timeval& now, TCPResponse&& response) override { + auto& query = response.d_idstate; if (!query.du) { return; } @@ -1041,7 +1042,7 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req) if (!holders.acl->match(remote)) { ++dnsdist::metrics::g_stats.aclDrops; vinfolog("Query from %s (DoH) dropped because of ACL", remote.toStringWithPort()); - h2o_send_error_403(req, "Forbidden", "dns query not allowed because of ACL", 0); + h2o_send_error_403(req, "Forbidden", "DoH query not allowed 
because of ACL", 0); return 0; } @@ -1344,6 +1345,13 @@ static void on_accept(h2o_socket_t *listener, const char *err) return; } + if (dsc->df->d_earlyACLDrop && !dsc->df->d_trustForwardedForHeader && !dsc->holders.acl->match(remote)) { + ++dnsdist::metrics::g_stats.aclDrops; + vinfolog("Dropping DoH connection from %s because of ACL", remote.toStringWithPort()); + h2o_socket_close(sock); + return; + } + if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) { vinfolog("Dropping DoH connection from %s because we have too many from this client already", remote.toStringWithPort()); h2o_socket_close(sock); diff --git a/pdns/dnsdistdist/m4/dnsdist_enable_doh.m4 b/pdns/dnsdistdist/m4/dnsdist_enable_doh.m4 index 876a21890f..baf9118e67 100644 --- a/pdns/dnsdistdist/m4/dnsdist_enable_doh.m4 +++ b/pdns/dnsdistdist/m4/dnsdist_enable_doh.m4 @@ -1,7 +1,7 @@ AC_DEFUN([DNSDIST_ENABLE_DNS_OVER_HTTPS], [ AC_MSG_CHECKING([whether to enable incoming DNS over HTTPS (DoH) support]) AC_ARG_ENABLE([dns-over-https], - AS_HELP_STRING([--enable-dns-over-https], [enable incoming DNS over HTTPS (DoH) support (requires libh2o) @<:@default=no@:>@]), + AS_HELP_STRING([--enable-dns-over-https], [enable incoming DNS over HTTPS (DoH) support (requires libh2o or nghttp2) @<:@default=no@:>@]), [enable_dns_over_https=$enableval], [enable_dns_over_https=no] ) diff --git a/pdns/dnsdistdist/m4/pdns_with_nghttp2.m4 b/pdns/dnsdistdist/m4/pdns_with_nghttp2.m4 index 8305b2b906..273385cf24 100644 --- a/pdns/dnsdistdist/m4/pdns_with_nghttp2.m4 +++ b/pdns/dnsdistdist/m4/pdns_with_nghttp2.m4 @@ -13,6 +13,13 @@ AC_DEFUN([PDNS_WITH_NGHTTP2], [ PKG_CHECK_MODULES([NGHTTP2], [libnghttp2], [ [HAVE_NGHTTP2=1] AC_DEFINE([HAVE_NGHTTP2], [1], [Define to 1 if you have nghttp2]) + save_CFLAGS=$CFLAGS + save_LIBS=$LIBS + CFLAGS="$NGHTTP2_CFLAGS $CFLAGS" + LIBS="$NGHTTP2_LIBS $LIBS" + AC_CHECK_FUNCS([nghttp2_check_header_value_rfc9113 nghttp2_check_method nghttp2_check_path]) + 
CFLAGS=$save_CFLAGS + LIBS=$save_LIBS ], [ : ]) ]) ]) diff --git a/pdns/dnsdistdist/test-dnsdistasync.cc b/pdns/dnsdistdist/test-dnsdistasync.cc index b1fbebca05..65a9c4a53f 100644 --- a/pdns/dnsdistdist/test-dnsdistasync.cc +++ b/pdns/dnsdistdist/test-dnsdistasync.cc @@ -44,7 +44,7 @@ public: { } - void notifyIOError(InternalQueryState&&, const struct timeval&) override + void notifyIOError(const struct timeval&, TCPResponse&&) override { errorRaised = true; } diff --git a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc index 9d437578f7..19cafc004c 100644 --- a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc @@ -34,7 +34,6 @@ std::vector> g_frontends; /* add stub implementations, we don't want to include the corresponding object files and their dependencies */ -// NOLINTNEXTLINE(readability-convert-member-functions-to-static): this is a stub, the real one is not that simple.. bool TLSFrontend::setupTLS() { return true; diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc index 41d9992cda..c43297b53e 100644 --- a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc @@ -626,11 +626,11 @@ public: d_valid = true; } - void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override + void handleXFRResponse(const struct timeval&, TCPResponse&&) override { } - void notifyIOError(InternalQueryState&& query, const struct timeval& now) override + void notifyIOError(const struct timeval&, TCPResponse&&) override { d_error = true; } diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index 2aa5adbe20..22e137c24b 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -500,7 +500,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) }; auto state = 
std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); } @@ -523,7 +523,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size()); BOOST_CHECK(s_writeBuffer == query); } @@ -558,7 +558,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) dynamic_cast(threadData.mplexer.get())->setReady(-1); auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { threadData.mplexer->run(&now); } @@ -582,7 +582,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); } @@ -610,7 +610,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * count); #endif } @@ -636,7 +636,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) dynamic_cast(threadData.mplexer.get())->setNotReady(-1); auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); 
BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0); struct timeval later = now; later.tv_sec += g_tcpRecvTimeout + 1; @@ -672,7 +672,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) dynamic_cast(threadData.mplexer.get())->setNotReady(-1); auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0); struct timeval later = now; later.tv_sec += g_tcpRecvTimeout + 1; @@ -705,7 +705,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); } } @@ -766,7 +766,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered) dynamic_cast(threadData.mplexer.get())->setNotReady(-1); auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0); BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * 2U); } @@ -793,7 +793,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); } @@ -823,7 +823,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered) dynamic_cast(threadData.mplexer.get())->setNotReady(-1); auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); 
BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0); struct timeval later = now; later.tv_sec += g_tcpRecvTimeout + 1; @@ -903,7 +903,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size()); BOOST_CHECK(s_writeBuffer == query); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size()); @@ -943,7 +943,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size()); BOOST_CHECK(s_backendWriteBuffer == query); @@ -982,7 +982,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size()); BOOST_CHECK(s_backendWriteBuffer == query); @@ -1025,7 +1025,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size()); BOOST_CHECK(s_backendWriteBuffer == query); @@ -1052,7 +1052,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - 
IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U); BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); @@ -1090,7 +1090,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size()); BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); @@ -1160,7 +1160,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) /* set the incoming descriptor as ready! */ dynamic_cast(threadData.mplexer.get())->setReady(-1); auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { threadData.mplexer->run(&now); } @@ -1221,7 +1221,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U); BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); @@ -1257,7 +1257,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); struct timeval later = now; later.tv_sec += backend->d_config.tcpSendTimeout + 1; auto expiredWriteConns = threadData.mplexer->getTimeouts(later, true); @@ -1303,7 +1303,7 @@ 
BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); struct timeval later = now; later.tv_sec += backend->d_config.tcpRecvTimeout + 1; auto expiredConns = threadData.mplexer->getTimeouts(later, false); @@ -1360,7 +1360,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U); BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); @@ -1416,7 +1416,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size()); BOOST_CHECK(s_writeBuffer == query); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size()); @@ -1475,7 +1475,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U); BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); @@ -1527,7 +1527,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size() 
* backend->d_config.d_retries); BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); @@ -1587,7 +1587,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size()); BOOST_CHECK(s_writeBuffer == query); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size() * backend->d_config.d_retries); @@ -1628,7 +1628,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U); BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size()); BOOST_CHECK(s_backendWriteBuffer == query); @@ -1690,7 +1690,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * count); BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); @@ -1732,7 +1732,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U); /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */ @@ -1916,7 +1916,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + 
state->handleIO(); while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { threadData.mplexer->run(&now); } @@ -2048,7 +2048,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); @@ -2228,7 +2228,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); @@ -2304,7 +2304,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } @@ -2387,7 +2387,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while ((threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } @@ -2504,7 +2504,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, 
getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } @@ -2656,7 +2656,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { threadData.mplexer->run(&now); } @@ -2863,7 +2863,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } @@ -3037,7 +3037,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } @@ -3301,7 +3301,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } @@ -3427,7 +3427,7 @@ 
BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { threadData.mplexer->run(&now); } @@ -3512,7 +3512,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { threadData.mplexer->run(&now); } @@ -3577,7 +3577,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } @@ -3768,7 +3768,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { threadData.mplexer->run(&now); } @@ -3853,7 +3853,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || 
threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } @@ -4085,7 +4085,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) { threadData.mplexer->run(&now); } @@ -4137,7 +4137,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR) }; auto state = std::make_shared(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + state->handleIO(); while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) { threadData.mplexer->run(&now); } diff --git a/pdns/doh.hh b/pdns/doh.hh index 58a26f1691..c482b7a0fa 100644 --- a/pdns/doh.hh +++ b/pdns/doh.hh @@ -51,7 +51,7 @@ public: size_t getTicketsKeysCount() override; }; -void dohThread(ClientState* clientState); +void dohThread(ClientState* cs); #endif /* HAVE_LIBH2OEVLOOP */ #endif /* HAVE_DNS_OVER_HTTPS */ diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 1b7018c028..78f23f9df4 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -1850,6 +1850,7 @@ bool TLSFrontend::setupTLS() newCtx = std::make_shared(*this); } #endif /* HAVE_LIBSSL */ + if (!newCtx) { #ifdef HAVE_LIBSSL newCtx = std::make_shared(*this); @@ -1874,7 +1875,7 @@ bool TLSFrontend::setupTLS() std::shared_ptr getTLSContext([[maybe_unused]] const TLSContextParameters& params) { -#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) +#ifdef HAVE_DNS_OVER_TLS /* get the "best" available provider */ if (!params.d_provider.empty()) { #ifdef HAVE_GNUTLS @@ -1897,6 +1898,6 @@ std::shared_ptr getTLSContext([[maybe_unused]] const TLSContextParameter #endif /* 
HAVE_GNUTLS */ #endif /* HAVE_LIBSSL */ -#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */ +#endif /* HAVE_DNS_OVER_TLS */ return nullptr; } diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index 29b59a01f9..5e1d23e737 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -138,7 +138,7 @@ class TLSFrontend public: enum class ALPN : uint8_t { Unset, DoT, DoH }; - TLSFrontend(ALPN alpn) : d_alpn(alpn) + TLSFrontend(ALPN alpn): d_alpn(alpn) { } @@ -233,7 +233,6 @@ protected: class TCPIOHandler { public: - enum class Type : uint8_t { Client, Server }; TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr ctx): d_socket(socket) { diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index c51a930c04..850273eb8a 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -56,7 +56,7 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query) { - return true; + return false; } namespace dnsdist { diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 156ba19219..6bc56cdb7a 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -624,7 +624,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): return sock @classmethod - def openTLSConnection(cls, port, serverName, caCert=None, timeout=None): + def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) if timeout: @@ -633,6 +633,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): # 2.7.9+ if hasattr(ssl, 'create_default_context'): sslctx = ssl.create_default_context(cafile=caCert) + if len(alpn)> 0 and hasattr(sslctx, 
'set_alpn_protocols'): + sslctx.set_alpn_protocols(alpn) sslsock = sslctx.wrap_socket(sock, server_hostname=serverName) else: sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED) @@ -992,6 +994,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): #conn.setopt(pycurl.VERBOSE, True) conn.setopt(pycurl.URL, url) conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) + # this means "really do HTTP/2, not HTTP/1 with Upgrade headers" + conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE) if useHTTPS: conn.setopt(pycurl.SSL_VERIFYPEER, 1) conn.setopt(pycurl.SSL_VERIFYHOST, 2) @@ -1036,6 +1040,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): #conn.setopt(pycurl.VERBOSE, True) conn.setopt(pycurl.URL, url) conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) + # this means "really do HTTP/2, not HTTP/1 with Upgrade headers" + conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE) if useHTTPS: conn.setopt(pycurl.SSL_VERIFYPEER, 1) conn.setopt(pycurl.SSL_VERIFYHOST, 2) diff --git a/regression-tests.dnsdist/test_DOH.py b/regression-tests.dnsdist/test_DOH.py index 5999021f5d..d4d6606faf 100644 --- a/regression-tests.dnsdist/test_DOH.py +++ b/regression-tests.dnsdist/test_DOH.py @@ -573,7 +573,7 @@ class TestDOHSubPaths(DNSDistDOHTest): # this path is not in the URLs map and should lead to a 404 (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL + "NotPowerDNS", query, caFile=self._caCert, useQueue=False, rawResponse=True) self.assertTrue(receivedResponse) - self.assertEqual(receivedResponse, b'not found') + self.assertIn(receivedResponse, [b'there is no endpoint configured for this path', b'not found']) self.assertEqual(self._rcode, 404) # this path is below one in the URLs map and exactPathMatching is false, so we should be good @@ -1116,7 +1116,7 @@ class 
TestDOHForwardedFor(DNSDistDOHTest): (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 127.0.0.1:42, 127.0.0.1']) self.assertEqual(self._rcode, 403) - self.assertEqual(receivedResponse, b'dns query not allowed because of ACL') + self.assertEqual(receivedResponse, b'DoH query not allowed because of ACL') class TestDOHForwardedForNoTrusted(DNSDistDOHTest): @@ -1130,7 +1130,7 @@ class TestDOHForwardedForNoTrusted(DNSDistDOHTest): newServer{address="127.0.0.1:%s"} setACL('192.0.2.1/32') - addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" }) + addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" }, {earlyACLDrop=true}) """ _config_params = ['_testServerPort', '_dohServerPort', '_serverCert', '_serverKey'] @@ -1151,10 +1151,15 @@ class TestDOHForwardedForNoTrusted(DNSDistDOHTest): '127.0.0.1') response.answer.append(rrset) - (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 192.0.2.1:4200']) + dropped = False + try: + (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 192.0.2.1:4200']) + self.assertEqual(self._rcode, 403) + self.assertEqual(receivedResponse, b'DoH query not allowed because of ACL') + except pycurl.error as e: + dropped = True - self.assertEqual(self._rcode, 403) - self.assertEqual(receivedResponse, b'dns query not allowed because of ACL') + self.assertTrue(dropped) class TestDOHFrontendLimits(DNSDistDOHTest): @@ -1190,7 +1195,7 @@ class TestDOHFrontendLimits(DNSDistDOHTest): for idx in range(self._maxTCPConnsPerDOHFrontend + 1): try: - 
conns.append(self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert)) + conns.append(self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert, alpn=['h2'])) except: conns.append(None) diff --git a/regression-tests.dnsdist/test_Protobuf.py b/regression-tests.dnsdist/test_Protobuf.py index e4f0b0232c..092cdef745 100644 --- a/regression-tests.dnsdist/test_Protobuf.py +++ b/regression-tests.dnsdist/test_Protobuf.py @@ -546,7 +546,6 @@ class TestProtobufMetaDOH(DNSDistProtobufTest): elif method == "sendDOHQueryWrapper": pbMessageType = dnsmessage_pb2.PBDNSMessage.DOH - print(method) self.checkProtobufQuery(msg, pbMessageType, query, dns.rdataclass.IN, dns.rdatatype.A, name) self.assertEqual(len(msg.meta), 5) tags = {}