From: Charles-Henri Bruyand Date: Tue, 21 Nov 2023 22:49:35 +0000 (+0100) Subject: dnsdist: add beta support for incoming DNS over HTTP/3 X-Git-Tag: dnsdist-1.9.0-alpha4~15^2~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e891e0b9fde7c50ce5f3f11fc18b5f50920bd3e6;p=thirdparty%2Fpdns.git dnsdist: add beta support for incoming DNS over HTTP/3 --- diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index b3c2443456..38d5cb8cf9 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -470,6 +470,7 @@ const std::vector g_consoleKeywords{ { "addConsoleACL", true, "netmask", "add a netmask to the console ACL" }, { "addDNSCryptBind", true, "\"127.0.0.1:8443\", \"provider name\", \"/path/to/resolver.cert\", \"/path/to/resolver.key\", {reusePort=false, tcpFastOpenQueueSize=0, interface=\"\", cpus={}}", "listen to incoming DNSCrypt queries on 127.0.0.1 port 8443, with a provider name of `provider name`, using a resolver certificate and associated key stored respectively in the `resolver.cert` and `resolver.key` files. The fifth optional parameter is a table of parameters" }, { "addDOHLocal", true, "addr, certFile, keyFile [, urls [, vars]]", "listen to incoming DNS over HTTPS queries on the specified address using the specified certificate and key. The last two parameters are tables" }, + { "addDOH3Local", true, "addr, certFile, keyFile [, vars]", "listen to incoming DNS over HTTP/3 queries on the specified address using the specified certificate and key. The last parameter is a table" }, { "addDOQLocal", true, "addr, certFile, keyFile [, vars]", "listen to incoming DNS over QUIC queries on the specified address using the specified certificate and key. The last parameter is a table" }, { "addDynamicBlock", true, "address, message[, action [, seconds [, clientIPMask [, clientIPPortMask]]]]", "block the supplied address with message `msg`, for `seconds` seconds (10 by default), applying `action` (default to the one set with `setDynBlocksAction()`)" }, { "addDynBlocks", true, "addresses, message[, seconds[, action]]", "block the set of addresses with message `msg`, for `seconds` seconds (10 by default), applying `action` (default to the one set with `setDynBlocksAction()`)" }, diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index e1fb8d38e1..73d5f6e5e3 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -36,6 +36,7 @@ struct ClientState; struct DOHUnitInterface; struct DOQUnit; +struct DOH3Unit; class DNSCryptQuery; class DNSDistPacketCache; @@ -139,6 +140,7 @@ struct InternalQueryState size_t d_proxyProtocolPayloadSize{0}; // 8 int32_t d_streamID{-1}; // 4 std::unique_ptr doqu{nullptr}; // 8 + std::unique_ptr doh3u{nullptr}; // 8 uint32_t cacheKey{0}; // 4 uint32_t cacheKeyNoECS{0}; // 4 // DoH-only */ diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index faca89bbae..86afffe5c0 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -2585,6 +2585,79 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) #endif /* HAVE_DNS_OVER_HTTPS */ }); + luaCtx.writeFunction("addDOH3Local", [client](const std::string& addr, const boost::variant, LuaArray, LuaArray>>& certFiles, const boost::variant>& keyFiles, boost::optional vars) { + if (client) { + return; + } +#ifdef HAVE_DNS_OVER_HTTP3 + if (!checkConfigurationTime("addDOH3Local")) { + return; + } + setLuaSideEffect(); + + auto frontend = std::make_shared(); + if (!loadTLSCertificateAndKeys("addDOH3Local", frontend->d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) { + return; + } + frontend->d_local = ComboAddress(addr, 853); + + bool reusePort = false; + int tcpFastOpenQueueSize = 0; + int tcpListenQueueSize = 0; + uint64_t maxInFlightQueriesPerConn = 0; + uint64_t tcpMaxConcurrentConnections = 0; + std::string interface; + std::set cpus; + std::vector> additionalAddresses; + + if (vars) { + parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections); + if (maxInFlightQueriesPerConn > 0) { + frontend->d_maxInFlight = maxInFlightQueriesPerConn; + } + getOptionalValue(vars, "internalPipeBufferSize", frontend->d_internalPipeBufferSize); + getOptionalValue(vars, "idleTimeout", frontend->d_idleTimeout); + getOptionalValue(vars, "keyLogFile", frontend->d_keyLogFile); + { + std::string valueStr; + if (getOptionalValue(vars, "congestionControlAlgo", valueStr) > 0) { + if (DOH3Frontend::s_available_cc_algorithms.count(valueStr) > 0) { + frontend->d_ccAlgo = valueStr; + } + else { + warnlog("Ignoring unknown value '%s' for 'congestionControlAlgo' on 'addDOH3Local'", valueStr); + } + } + } + parseTLSConfig(frontend->d_tlsConfig, "addDOH3Local", vars); + + bool ignoreTLSConfigurationErrors = false; + if (getOptionalValue(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) { + // we are asked to try to load the certificates so we can return a potential error + // and properly ignore the frontend before actually launching it + try { + std::map ocspResponses = {}; + auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses); + } + catch (const std::runtime_error& e) { + errlog("Ignoring DoH3 frontend: '%s'", e.what()); + return; + } + } + + checkAllParametersConsumed("addDOH3Local", vars); + } + g_doh3locals.push_back(frontend); + auto cs = std::make_unique(frontend->d_local, false, reusePort, tcpFastOpenQueueSize, interface, cpus); + cs->doh3Frontend = frontend; + cs->d_additionalAddresses = std::move(additionalAddresses); + + g_frontends.push_back(std::move(cs)); +#else + throw std::runtime_error("addDOH3Local() called but DNS over HTTP/3 support is not present!"); +#endif + }); + // NOLINTNEXTLINE(performance-unnecessary-value-param): somehow clang-tidy gets confused about the fact vars could be const while it cannot luaCtx.writeFunction("addDOQLocal", [client](const std::string& addr, const boost::variant, LuaArray, LuaArray>>& certFiles, const boost::variant>& keyFiles, boost::optional vars) { if (client) { diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 77c7f8d3a9..4d3105a120 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -109,6 +109,7 @@ string g_outputBuffer; std::vector> g_tlslocals; std::vector> g_dohlocals; std::vector> g_doqlocals; +std::vector> g_doh3locals; std::vector> g_dnsCryptLocals; shared_ptr g_defaultBPFFilter{nullptr}; @@ -323,6 +324,12 @@ static void doLatencyStats(dnsdist::Protocol protocol, double udiff) doAvg(dnsdist::metrics::g_stats.latencyDoQAvg10000, udiff, 10000); doAvg(dnsdist::metrics::g_stats.latencyDoQAvg1000000, udiff, 1000000); } + else if (protocol == dnsdist::Protocol::DoH3) { + doAvg(dnsdist::metrics::g_stats.latencyDoH3Avg100, udiff, 100); + doAvg(dnsdist::metrics::g_stats.latencyDoH3Avg1000, udiff, 1000); + doAvg(dnsdist::metrics::g_stats.latencyDoH3Avg10000, udiff, 10000); + doAvg(dnsdist::metrics::g_stats.latencyDoH3Avg1000000, udiff, 1000000); + } } bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote, unsigned int& qnameWireLength) @@ -2414,6 +2421,8 @@ static void setupLocalSocket(ClientState& clientState, const ComboAddress& addr, } else { if (clientState.doqFrontend != nullptr) { infolog("Listening on %s for DoQ", addr.toStringWithPort()); + } else if (clientState.doh3Frontend != nullptr) { + infolog("Listening on %s for DoH3", addr.toStringWithPort()); } } } @@ -2444,6 +2453,9 @@ static void setUpLocalBind(std::unique_ptr& cstate) if (cstate->doqFrontend != nullptr) { cstate->doqFrontend->setup(); } + if (cstate->doh3Frontend != nullptr) { + cstate->doh3Frontend->setup(); + } cstate->ready = true; } @@ -2830,7 +2842,7 @@ static void initFrontends() if (!g_cmdLine.locals.empty()) { for (auto it = g_frontends.begin(); it != g_frontends.end(); ) { /* DoH, DoT and DNSCrypt frontends are separate */ - if ((*it)->dohFrontend == nullptr && (*it)->tlsFrontend == nullptr && (*it)->dnscryptCtx == nullptr && (*it)->doqFrontend == nullptr) { + if ((*it)->dohFrontend == nullptr && (*it)->tlsFrontend == nullptr && (*it)->dnscryptCtx == nullptr && (*it)->doqFrontend == nullptr && (*it)->doh3Frontend == nullptr) { it = g_frontends.erase(it); } else { @@ -3085,6 +3097,16 @@ int main(int argc, char** argv) #endif /* HAVE_DNS_OVER_QUIC */ continue; } + if (cs->doh3Frontend != nullptr) { +#ifdef HAVE_DNS_OVER_HTTP3 + std::thread t1(doh3Thread, cs.get()); + if (!cs->cpus.empty()) { + mapThreadToCPUList(t1.native_handle(), cs->cpus); + } + t1.detach(); +#endif /* HAVE_DNS_OVER_HTTP3 */ + continue; + } if (cs->udpFD >= 0) { #ifdef USE_SINGLE_ACCEPTOR_THREAD udpStates.push_back(cs.get()); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index aa64c3f0c3..0232b2b962 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -45,6 +45,7 @@ #include "dnsname.hh" #include "dnsdist-doh-common.hh" #include "doq.hh" +#include "doh3.hh" #include "ednsoptions.hh" #include "iputils.hh" #include "misc.hh" @@ -501,6 +502,7 @@ struct ClientState std::shared_ptr tlsFrontend{nullptr}; std::shared_ptr dohFrontend{nullptr}; std::shared_ptr doqFrontend{nullptr}; + std::shared_ptr doh3Frontend{nullptr}; std::shared_ptr d_filter{nullptr}; size_t d_maxInFlightQueriesPerConn{1}; size_t d_tcpConcurrentConnectionsLimit{0}; @@ -578,6 +580,9 @@ struct ClientState if (doqFrontend) { result += " (DNS over QUIC)"; } + else if (doh3Frontend) { + result += " (DNS over HTTP/3)"; + } else if (dohFrontend) { if (dohFrontend->isHTTPS()) { result += " (DNS over HTTPS)"; @@ -1067,6 +1072,7 @@ extern ComboAddress g_serverControl; // not changed during runtime extern std::vector> g_tlslocals; extern std::vector> g_dohlocals; extern std::vector> g_doqlocals; +extern std::vector> g_doh3locals; extern std::vector> g_frontends; extern bool g_truncateTC; extern bool g_fixupCase; diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index bd276cc03c..be0f936c54 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -215,6 +215,7 @@ dnsdist_SOURCES = \ doh.hh \ dolog.cc dolog.hh \ doq.hh \ + doh3.hh \ ednscookies.cc ednscookies.hh \ ednsextendederror.cc ednsextendederror.hh \ ednsoptions.cc ednsoptions.hh \ @@ -454,11 +455,16 @@ endif if HAVE_DNS_OVER_QUIC dnsdist_SOURCES += doq.cc +endif + +if HAVE_DNS_OVER_HTTP3 +dnsdist_SOURCES += doh3.cc +endif + if HAVE_QUICHE AM_CPPFLAGS += $(QUICHE_CFLAGS) dnsdist_LDADD += $(QUICHE_LDFLAGS) $(QUICHE_LIBS) endif -endif if !HAVE_LUA_HPP BUILT_SOURCES += lua.hpp diff --git a/pdns/dnsdistdist/configure.ac b/pdns/dnsdistdist/configure.ac index 9b3055c496..ec972af114 100644 --- a/pdns/dnsdistdist/configure.ac +++ b/pdns/dnsdistdist/configure.ac @@ -82,6 +82,7 @@ DNSDIST_ENABLE_TLS_PROVIDERS PDNS_ENABLE_DNS_OVER_TLS DNSDIST_ENABLE_DNS_OVER_HTTPS DNSDIST_ENABLE_DNS_OVER_QUIC +DNSDIST_ENABLE_DNS_OVER_HTTP3 AS_IF([test "x$enable_dns_over_tls" != "xno" -o "x$enable_dns_over_https" != "xno" -o "x$enable_dns_over_quic" != "xno" ], [ PDNS_WITH_LIBSSL @@ -118,6 +119,12 @@ AS_IF([test "x$enable_dns_over_quic" != "xno"], [ ]) ]) +AS_IF([test "x$enable_dns_over_http3" != "xno"], [ + AS_IF([test "x$HAVE_QUICHE" != "x1"], [ + AC_MSG_ERROR([DNS over HTTP/3 support requested but quiche was not found]) + ]) +]) + DNSDIST_WITH_CDB PDNS_CHECK_LMDB PDNS_ENABLE_IPCIPHER diff --git a/pdns/dnsdistdist/dnsdist-internal-queries.cc b/pdns/dnsdistdist/dnsdist-internal-queries.cc index 9f6a3c40d3..b707fefb9b 100644 --- a/pdns/dnsdistdist/dnsdist-internal-queries.cc +++ b/pdns/dnsdistdist/dnsdist-internal-queries.cc @@ -49,6 +49,11 @@ std::unique_ptr getInternalQueryFromDQ(DNSQuestion& dq, bool else if (protocol == dnsdist::Protocol::DoQ) { return getDOQCrossProtocolQueryFromDQ(dq, isResponse); } +#endif +#ifdef HAVE_DNS_OVER_HTTP3 + else if (protocol == dnsdist::Protocol::DoH3) { + return getDOH3CrossProtocolQueryFromDQ(dq, isResponse); + } #endif else { return getTCPCrossProtocolQueryFromDQ(dq); diff --git a/pdns/dnsdistdist/dnsdist-metrics.cc b/pdns/dnsdistdist/dnsdist-metrics.cc index adf961eb8b..d47236ea8b 100644 --- a/pdns/dnsdistdist/dnsdist-metrics.cc +++ b/pdns/dnsdistdist/dnsdist-metrics.cc @@ -114,6 +114,10 @@ Stats::Stats() : {"latency-doq-avg1000", &latencyDoQAvg1000}, {"latency-doq-avg10000", &latencyDoQAvg10000}, {"latency-doq-avg1000000", &latencyDoQAvg1000000}, + {"latency-doh3-avg100", &latencyDoH3Avg100}, + {"latency-doh3-avg1000", &latencyDoH3Avg1000}, + {"latency-doh3-avg10000", &latencyDoH3Avg10000}, + {"latency-doh3-avg1000000", &latencyDoH3Avg1000000}, {"uptime", uptimeOfProcess}, {"real-memory-usage", getRealMemoryUsage}, {"special-memory-usage", getSpecialMemoryUsage}, @@ -146,6 +150,7 @@ Stats::Stats() : {"doh-query-pipe-full", &dohQueryPipeFull}, {"doh-response-pipe-full", &dohResponsePipeFull}, {"doq-response-pipe-full", &doqResponsePipeFull}, + {"doh3-response-pipe-full", &doh3ResponsePipeFull}, {"outgoing-doh-query-pipe-full", &outgoingDoHQueryPipeFull}, {"tcp-query-pipe-full", &tcpQueryPipeFull}, {"tcp-cross-protocol-query-pipe-full", &tcpCrossProtocolQueryPipeFull}, diff --git a/pdns/dnsdistdist/dnsdist-metrics.hh b/pdns/dnsdistdist/dnsdist-metrics.hh index 264054f672..8e899cedea 100644 --- a/pdns/dnsdistdist/dnsdist-metrics.hh +++ b/pdns/dnsdistdist/dnsdist-metrics.hh @@ -75,6 +75,7 @@ struct Stats stat_t dohQueryPipeFull{0}; stat_t dohResponsePipeFull{0}; stat_t doqResponsePipeFull{0}; + stat_t doh3ResponsePipeFull{0}; stat_t outgoingDoHQueryPipeFull{0}; stat_t proxyProtocolInvalid{0}; stat_t tcpQueryPipeFull{0}; @@ -85,6 +86,7 @@ struct Stats double latencyDoTAvg100{0}, latencyDoTAvg1000{0}, latencyDoTAvg10000{0}, latencyDoTAvg1000000{0}; double latencyDoHAvg100{0}, latencyDoHAvg1000{0}, latencyDoHAvg10000{0}, latencyDoHAvg1000000{0}; double latencyDoQAvg100{0}, latencyDoQAvg1000{0}, latencyDoQAvg10000{0}, latencyDoQAvg1000000{0}; + double latencyDoH3Avg100{0}, latencyDoH3Avg1000{0}, latencyDoH3Avg10000{0}, latencyDoH3Avg1000000{0}; using statfunction_t = std::function; using entry_t = std::variant*, double*, statfunction_t>; struct EntryPair diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc new file mode 100644 index 0000000000..784de899d4 --- /dev/null +++ b/pdns/dnsdistdist/doh3.cc @@ -0,0 +1,1091 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "doh3.hh" + +#ifdef HAVE_DNS_OVER_HTTP3 +#include + +#include "dnsparser.hh" +#include "dolog.hh" +#include "iputils.hh" +#include "misc.hh" +#include "sodcrypto.hh" +#include "sstuff.hh" +#include "threadname.hh" +#include "base64.hh" + +#include "dnsdist-ecs.hh" +#include "dnsdist-dnsparser.hh" +#include "dnsdist-proxy-protocol.hh" +#include "dnsdist-tcp.hh" +#include "dnsdist-random.hh" + +// FIXME : to be renamed ? +static std::string s_quicRetryTokenKey = newKey(false); + +std::map DOH3Frontend::s_available_cc_algorithms = { + {"reno", QUICHE_CC_RENO}, + {"cubic", QUICHE_CC_CUBIC}, + {"bbr", QUICHE_CC_BBR}, +}; + +using QuicheConnection = std::unique_ptr; +using QuicheHTTP3Connection = std::unique_ptr; +using QuicheConfig = std::unique_ptr; +using QuicheHTTP3Config = std::unique_ptr; + +class H3Connection +{ +public: + H3Connection(const ComboAddress& peer, QuicheConnection&& conn) : + d_peer(peer), d_conn(std::move(conn)) + { + } + H3Connection(const H3Connection&) = delete; + H3Connection(H3Connection&&) = default; + H3Connection& operator=(const H3Connection&) = delete; + H3Connection& operator=(H3Connection&&) = default; + ~H3Connection() = default; + + ComboAddress d_peer; + QuicheConnection d_conn; + QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free}; + std::unordered_map d_streamBuffers; +}; + +static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description); + +struct DOH3ServerConfig +{ + DOH3ServerConfig(QuicheConfig&& config_, QuicheHTTP3Config&& http3config_, uint32_t internalPipeBufferSize) : + config(std::move(config_)), http3config(std::move(http3config_)) + { + { + auto [sender, receiver] = pdns::channel::createObjectQueue(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize); + d_responseSender = std::move(sender); + d_responseReceiver = std::move(receiver); + } + } + DOH3ServerConfig(const DOH3ServerConfig&) = delete; + DOH3ServerConfig(DOH3ServerConfig&&) = default; + DOH3ServerConfig& operator=(const DOH3ServerConfig&) = delete; + DOH3ServerConfig& operator=(DOH3ServerConfig&&) = default; + ~DOH3ServerConfig() = default; + + using ConnectionsMap = std::map; + + LocalHolders holders; + ConnectionsMap d_connections; + QuicheConfig config; + QuicheHTTP3Config http3config; + ClientState* clientState{nullptr}; + std::shared_ptr df{nullptr}; + pdns::channel::Sender d_responseSender; + pdns::channel::Receiver d_responseReceiver; +}; + +/* these might seem useless, but they are needed because + they need to be declared _after_ the definition of DOH3ServerConfig + so that we can use a unique_ptr in DOH3Frontend */ +DOH3Frontend::DOH3Frontend() = default; +DOH3Frontend::~DOH3Frontend() = default; + +#if 0 +#define DEBUGLOG_ENABLED +#define DEBUGLOG(x) std::cerr << x << std::endl; +#else +#define DEBUGLOG(x) +#endif + +static constexpr size_t MAX_DATAGRAM_SIZE = 1200; +static constexpr size_t LOCAL_CONN_ID_LEN = 16; + +class DOH3TCPCrossQuerySender final : public TCPQuerySender +{ +public: + DOH3TCPCrossQuerySender() = default; + + [[nodiscard]] bool active() const override + { + return true; + } + + void handleResponse([[maybe_unused]] const struct timeval& now, TCPResponse&& response) override + { + if (!response.d_idstate.doh3u) { + return; + } + + auto unit = std::move(response.d_idstate.doh3u); + if (unit->dsc == nullptr) { + return; + } + + unit->response = std::move(response.d_buffer); + unit->ids = std::move(response.d_idstate); + DNSResponse dnsResponse(unit->ids, unit->response, unit->downstream); + + dnsheader cleartextDH{}; + memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH)); + + if (!response.isAsync()) { + + static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); + static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); + + dnsResponse.ids.doh3u = std::move(unit); + + if (!processResponse(dnsResponse.ids.doh3u->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) { + if (dnsResponse.ids.doh3u) { + + sendBackDOH3Unit(std::move(dnsResponse.ids.doh3u), "Response dropped by rules"); + } + return; + } + + if (dnsResponse.isAsynchronous()) { + return; + } + + unit = std::move(dnsResponse.ids.doh3u); + } + + if (!unit->ids.selfGenerated) { + double udiff = unit->ids.queryRealTime.udiff(); + vinfolog("Got answer from %s, relayed to %s (http/3), took %f us", unit->downstream->d_config.remote.toStringWithPort(), unit->ids.origRemote.toStringWithPort(), udiff); + + auto backendProtocol = unit->downstream->getProtocol(); + if (backendProtocol == dnsdist::Protocol::DoUDP && unit->tcp) { + backendProtocol = dnsdist::Protocol::DoTCP; + } + handleResponseSent(unit->ids, udiff, unit->ids.origRemote, unit->downstream->d_config.remote, unit->response.size(), cleartextDH, backendProtocol, true); + } + + ++dnsdist::metrics::g_stats.responses; + if (unit->ids.cs != nullptr) { + ++unit->ids.cs->responses; + } + + sendBackDOH3Unit(std::move(unit), "Cross-protocol response"); + } + + void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override + { + return handleResponse(now, std::move(response)); + } + + void notifyIOError([[maybe_unused]] const struct timeval& now, TCPResponse&& response) override + { + if (!response.d_idstate.doh3u) { + return; + } + + auto unit = std::move(response.d_idstate.doh3u); + if (unit->dsc == nullptr) { + return; + } + + /* this will signal an error */ + unit->response.clear(); + unit->ids = std::move(response.d_idstate); + sendBackDOH3Unit(std::move(unit), "Cross-protocol error"); + } +}; + +class DOH3CrossProtocolQuery : public CrossProtocolQuery +{ +public: + DOH3CrossProtocolQuery(DOH3UnitUniquePtr&& unit, bool isResponse) + { + if (isResponse) { + /* happens when a response becomes async */ + query = InternalQuery(std::move(unit->response), std::move(unit->ids)); + } + else { + /* we need to duplicate the query here because we might need + the existing query later if we get a truncated answer */ + query = InternalQuery(PacketBuffer(unit->query), std::move(unit->ids)); + } + + /* it might have been moved when we moved unit->ids */ + if (unit) { + query.d_idstate.doh3u = std::move(unit); + } + + /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload, + clearing query.d_proxyProtocolPayloadAdded and unit->proxyProtocolPayloadSize. + Leave it for now because we know that the onky case where the payload has been + added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT, + and we know the TCP/DoT code can handle it. */ + query.d_proxyProtocolPayloadAdded = query.d_idstate.doh3u->proxyProtocolPayloadSize > 0; + downstream = query.d_idstate.doh3u->downstream; + } + + void handleInternalError() + { + sendBackDOH3Unit(std::move(query.d_idstate.doh3u), "DOH3 internal error"); + } + + std::shared_ptr getTCPQuerySender() override + { + query.d_idstate.doh3u->downstream = downstream; + return s_sender; + } + + DNSQuestion getDQ() override + { + auto& ids = query.d_idstate; + DNSQuestion dnsQuestion(ids, query.d_buffer); + return dnsQuestion; + } + + DNSResponse getDR() override + { + auto& ids = query.d_idstate; + DNSResponse dnsResponse(ids, query.d_buffer, downstream); + return dnsResponse; + } + + DOH3UnitUniquePtr&& releaseDU() + { + return std::move(query.d_idstate.doh3u); + } + +private: + static std::shared_ptr s_sender; +}; + +std::shared_ptr DOH3CrossProtocolQuery::s_sender = std::make_shared(); + +static void h3_send_response(quiche_conn *quic_conn, quiche_h3_conn *conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len) +{ + std::string status = std::to_string(statusCode); + std::string lenStr = std::to_string(len); + quiche_h3_header headers[] = { + { + .name = reinterpret_cast(":status"), + .name_len = sizeof(":status") - 1, + + .value = reinterpret_cast(status.data()), + .value_len = status.size(), + }, + { + .name = reinterpret_cast("content-length"), + .name_len = sizeof("content-length") - 1, + + .value = reinterpret_cast(lenStr.data()), + .value_len = lenStr.size(), + }, + }; + quiche_h3_send_response(conn, quic_conn, + streamID, headers, 2, false); + + size_t pos = 0; + while (pos < len) { + auto res = quiche_h3_send_body(conn, quic_conn, + streamID, const_cast(body) + pos, len - pos, true); + if (res < 0) { + // Shutdown with internal error code + quiche_conn_stream_shutdown(quic_conn, streamID, QUICHE_SHUTDOWN_WRITE, static_cast(1)); + return; + } + pos += res; + } +} + +static void h3_send_response(quiche_conn *quic_conn, quiche_h3_conn *conn, const uint64_t streamID, uint16_t statusCode, const std::string& content) +{ + h3_send_response(quic_conn, conn, streamID, statusCode, reinterpret_cast(content.data()), content.size()); +} +static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len) +{ + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, statusCode, body, len); +} + +static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response) +{ + if (statusCode == 200) { + ++frontend.d_validResponses; + } else { + ++frontend.d_errorResponses; + } + h3_send_response(conn, streamID, statusCode, &response.at(0), response.size()); +} + +static void fillRandom(PacketBuffer& buffer, size_t size) +{ + buffer.reserve(size); + while (size > 0) { + buffer.insert(buffer.end(), dnsdist::getRandomValue(std::numeric_limits::max())); + --size; + } +} + +void DOH3Frontend::setup() +{ + auto config = QuicheConfig(quiche_config_new(QUICHE_PROTOCOL_VERSION), quiche_config_free); + for (const auto& pair : d_tlsConfig.d_certKeyPairs) { + auto res = quiche_config_load_cert_chain_from_pem_file(config.get(), pair.d_cert.c_str()); + if (res != 0) { + throw std::runtime_error("Error loading the server certificate: " + std::to_string(res)); + } + if (pair.d_key) { + res = quiche_config_load_priv_key_from_pem_file(config.get(), pair.d_key->c_str()); + if (res != 0) { + throw std::runtime_error("Error loading the server key: " + std::to_string(res)); + } + } + } + + { + auto res = quiche_config_set_application_protos(config.get(), + (uint8_t *) QUICHE_H3_APPLICATION_PROTOCOL, + sizeof(QUICHE_H3_APPLICATION_PROTOCOL) - 1); + if (res != 0) { + throw std::runtime_error("Error setting ALPN: " + std::to_string(res)); + } + } + + quiche_config_set_max_idle_timeout(config.get(), d_idleTimeout * 1000); + /* maximum size of an outgoing packet, which means the buffer we pass to quiche_conn_send() should be at least that big */ + quiche_config_set_max_send_udp_payload_size(config.get(), MAX_DATAGRAM_SIZE); + quiche_config_set_max_recv_udp_payload_size(config.get(), MAX_DATAGRAM_SIZE); + + // The number of concurrent remotely-initiated bidirectional streams to be open at any given time + // https://docs.rs/quiche/latest/quiche/struct.Config.html#method.set_initial_max_streams_bidi + // 0 means none will get accepted, that's why we have a default value of 65535 + quiche_config_set_initial_max_streams_bidi(config.get(), d_maxInFlight); + quiche_config_set_initial_max_streams_uni(config.get(), d_maxInFlight); + + // The number of bytes of incoming stream data to be buffered for each localy or remotely-initiated bidirectional stream + quiche_config_set_initial_max_stream_data_bidi_local(config.get(), 1000000); + quiche_config_set_initial_max_stream_data_bidi_remote(config.get(), 1000000); + quiche_config_set_initial_max_stream_data_uni(config.get(), 1000000); + + quiche_config_set_disable_active_migration(config.get(), true); + + // The number of total bytes of incoming stream data to be buffered for the whole connection + // https://docs.rs/quiche/latest/quiche/struct.Config.html#method.set_initial_max_data + quiche_config_set_initial_max_data(config.get(), 8192 * d_maxInFlight); + if (!d_keyLogFile.empty()) { + quiche_config_log_keys(config.get()); + } + + auto algo = DOH3Frontend::s_available_cc_algorithms.find(d_ccAlgo); + if (algo != DOH3Frontend::s_available_cc_algorithms.end()) { + quiche_config_set_cc_algorithm(config.get(), static_cast(algo->second)); + } + + { + PacketBuffer resetToken; + fillRandom(resetToken, 16); + quiche_config_set_stateless_reset_token(config.get(), resetToken.data()); + } + + // quiche_h3_config_new + auto http3config = QuicheHTTP3Config(quiche_h3_config_new(), quiche_h3_config_free); + + d_server_config = std::make_unique(std::move(config), std::move(http3config), d_internalPipeBufferSize); +} + +static std::optional getCID() +{ + PacketBuffer buffer; + + fillRandom(buffer, LOCAL_CONN_ID_LEN); + + return buffer; +} + +static constexpr size_t MAX_TOKEN_LEN = dnsdist::crypto::authenticated::getEncryptedSize(std::tuple_size{} /* nonce */ + sizeof(uint64_t) /* TTD */ + 16 /* IPv6 */ + QUICHE_MAX_CONN_ID_LEN); + +static PacketBuffer mintToken(const PacketBuffer& dcid, const ComboAddress& peer) +{ + try { + SodiumNonce nonce; + nonce.init(); + + const auto addrBytes = peer.toByteString(); + // this token will be valid for 60s + const uint64_t ttd = time(nullptr) + 60U; + PacketBuffer plainTextToken; + plainTextToken.reserve(sizeof(ttd) + addrBytes.size() + dcid.size()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic) + plainTextToken.insert(plainTextToken.end(), reinterpret_cast(&ttd), reinterpret_cast(&ttd) + sizeof(ttd)); + plainTextToken.insert(plainTextToken.end(), addrBytes.begin(), addrBytes.end()); + plainTextToken.insert(plainTextToken.end(), dcid.begin(), dcid.end()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + const auto encryptedToken = sodEncryptSym(std::string_view(reinterpret_cast(plainTextToken.data()), plainTextToken.size()), s_quicRetryTokenKey, nonce, false); + // a bit sad, let's see if we can do better later + auto encryptedTokenPacket = PacketBuffer(encryptedToken.begin(), encryptedToken.end()); + encryptedTokenPacket.insert(encryptedTokenPacket.begin(), nonce.value.begin(), nonce.value.end()); + return encryptedTokenPacket; + } + catch (const std::exception& exp) { + vinfolog("Error while minting DoH3 token: %s", exp.what()); + throw; + } +} + +// returns the original destination ID if the token is valid, nothing otherwise +static std::optional validateToken(const PacketBuffer& token, const ComboAddress& peer) +{ + try { + SodiumNonce nonce; + auto addrBytes = peer.toByteString(); + const uint64_t now = time(nullptr); + const auto minimumSize = nonce.value.size() + sizeof(now) + addrBytes.size(); + if (token.size() <= minimumSize) { + return std::nullopt; + } + + memcpy(nonce.value.data(), token.data(), nonce.value.size()); + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto cipher = std::string_view(reinterpret_cast(&token.at(nonce.value.size())), token.size() - nonce.value.size()); + auto plainText = sodDecryptSym(cipher, s_quicRetryTokenKey, nonce, false); + + if (plainText.size() <= sizeof(now) + addrBytes.size()) { + return std::nullopt; + } + + uint64_t ttd{0}; + memcpy(&ttd, plainText.data(), sizeof(ttd)); + if (ttd < now) { + return std::nullopt; + } + + if (std::memcmp(&plainText.at(sizeof(ttd)), &*addrBytes.begin(), addrBytes.size()) != 0) { + return std::nullopt; + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return PacketBuffer(plainText.begin() + (sizeof(ttd) + addrBytes.size()), plainText.end()); + } + catch (const std::exception& exp) { + vinfolog("Error while validating DoH3 token: %s", exp.what()); + return std::nullopt; + } +} + +static void handleStatelessRetry(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer, uint32_t version) +{ + auto newServerConnID = getCID(); + if (!newServerConnID) { + return; + } + + auto token = mintToken(serverConnID, peer); + + PacketBuffer out(MAX_DATAGRAM_SIZE); + auto written = quiche_retry(clientConnID.data(), clientConnID.size(), + serverConnID.data(), serverConnID.size(), + newServerConnID->data(), newServerConnID->size(), + token.data(), token.size(), + version, + out.data(), out.size()); + + if (written < 0) { + DEBUGLOG("failed to create retry packet " << written); + return; + } + + out.resize(written); + sock.sendTo(std::string(out.begin(), out.end()), peer); +} + +static void handleVersionNegociation(Socket& sock, const PacketBuffer& clientConnID, const PacketBuffer& serverConnID, const ComboAddress& peer) +{ + PacketBuffer out(MAX_DATAGRAM_SIZE); + + auto written = quiche_negotiate_version(clientConnID.data(), clientConnID.size(), + serverConnID.data(), serverConnID.size(), + out.data(), out.size()); + + if (written < 0) { + DEBUGLOG("failed to create vneg packet " << written); + return; + } + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + sock.sendTo(reinterpret_cast(out.data()), written, peer); +} + +static std::optional> getConnection(DOH3ServerConfig::ConnectionsMap& connMap, const PacketBuffer& connID) +{ + auto iter = connMap.find(connID); + if (iter == connMap.end()) { + return std::nullopt; + } + return iter->second; +} + +static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description) +{ + if (unit->dsc == nullptr) { + return; + } + try { + if (!unit->dsc->d_responseSender.send(std::move(unit))) { + ++dnsdist::metrics::g_stats.doh3ResponsePipeFull; + vinfolog("Unable to pass a %s to the DoH3 worker thread because the pipe is full", description); + } + } + catch (const std::exception& e) { + vinfolog("Unable to pass a %s to the DoH3 worker thread because we couldn't write to the pipe: %s", description, e.what()); + } +} + +static std::optional> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& local, const ComboAddress& peer) +{ + auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(), + originalDestinationID.data(), originalDestinationID.size(), + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + reinterpret_cast(&local), + local.getSocklen(), + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + reinterpret_cast(&peer), + peer.getSocklen(), + config.config.get()), + quiche_conn_free); + + if (config.df && !config.df->d_keyLogFile.empty()) { + quiche_conn_set_keylog_path(quicheConn.get(), config.df->d_keyLogFile.c_str()); + } + + auto conn = H3Connection(peer, std::move(quicheConn)); + auto pair = config.d_connections.emplace(serverSideID, std::move(conn)); + return pair.first->second; +} + +static void flushEgress(Socket& sock, H3Connection& conn) +{ + std::array out{}; + quiche_send_info send_info; + + while (true) { + auto written = quiche_conn_send(conn.d_conn.get(), out.data(), out.size(), &send_info); + if (written == QUICHE_ERR_DONE) { + return; + } + + if (written < 0) { + return; + } + // FIXME pacing (as send_info.at should tell us when to send the packet) ? + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + sock.sendTo(reinterpret_cast(out.data()), written, conn.d_peer); + } +} + +std::unique_ptr getDOH3CrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion, bool isResponse) +{ + if (!dnsQuestion.ids.doh3u) { + throw std::runtime_error("Trying to create a DoH3 cross protocol query without a valid DoH3 unit"); + } + + auto unit = std::move(dnsQuestion.ids.doh3u); + if (&dnsQuestion.ids != &unit->ids) { + unit->ids = std::move(dnsQuestion.ids); + } + + unit->ids.origID = dnsQuestion.getHeader()->id; + + if (!isResponse) { + if (unit->query.data() != dnsQuestion.getMutableData().data()) { + unit->query = std::move(dnsQuestion.getMutableData()); + } + } + else { + if (unit->response.data() != dnsQuestion.getMutableData().data()) { + unit->response = std::move(dnsQuestion.getMutableData()); + } + } + + return std::make_unique(std::move(unit), isResponse); +} + +/* + We are not in the main DoH3 thread but in the DoH3 'client' thread. +*/ +static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit) +{ + const auto handleImmediateResponse = [](DOH3UnitUniquePtr&& unit, [[maybe_unused]] const char* reason) { + DEBUGLOG("handleImmediateResponse() reason=" << reason); + auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID); + handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response); + unit->ids.doh3u.reset(); + }; + + auto& ids = doh3Unit->ids; + ids.doh3u = std::move(doh3Unit); + auto& unit = ids.doh3u; + uint16_t queryId = 0; + ComboAddress remote; + + try { + + remote = unit->ids.origRemote; + DOH3ServerConfig* dsc = unit->dsc; + auto& holders = dsc->holders; + ClientState& clientState = *dsc->clientState; + + if (unit->query.size() < sizeof(dnsheader)) { + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + unit->response.clear(); + + unit->status_code = 400; + handleImmediateResponse(std::move(unit), "DoH3 non-compliant query"); + return; + } + + ++clientState.queries; + ++dnsdist::metrics::g_stats.queries; + unit->ids.queryRealTime.start(); + + { + /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */ + dnsheader_aligned dnsHeader(unit->query.data()); + + if (!checkQueryHeaders(dnsHeader.get(), clientState)) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) { + header.rcode = RCode::ServFail; + header.qr = true; + return true; + }); + unit->response = std::move(unit->query); + + unit->status_code = 400; + handleImmediateResponse(std::move(unit), "DoH3 invalid headers"); + return; + } + + if (dnsHeader->qdcount == 0) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); + unit->response = std::move(unit->query); + + unit->status_code = 400; + handleImmediateResponse(std::move(unit), "DoH3 empty query"); + return; + } + + queryId = ntohs(dnsHeader->id); + } + + auto downstream = unit->downstream; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + unit->ids.qname = DNSName(reinterpret_cast(unit->query.data()), static_cast(unit->query.size()), sizeof(dnsheader), false, &unit->ids.qtype, &unit->ids.qclass); + DNSQuestion dnsQuestion(unit->ids, unit->query); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) { + const uint16_t* flags = getFlagsFromDNSHeader(&header); + ids.origFlags = *flags; + return true; + }); + unit->ids.cs = &clientState; + + auto result = processQuery(dnsQuestion, holders, downstream); + if (result == ProcessQueryResult::Drop) { + unit->status_code = 403; + handleImmediateResponse(std::move(unit), "DoH3 dropped query"); + return; + } + if (result == ProcessQueryResult::Asynchronous) { + return; + } + if (result == ProcessQueryResult::SendAnswer) { + if (unit->response.empty()) { + unit->response = std::move(unit->query); + } + if (unit->response.size() >= sizeof(dnsheader)) { + const dnsheader_aligned dnsHeader(unit->response.data()); + + handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dnsHeader, dnsdist::Protocol::DoH3, dnsdist::Protocol::DoH3, false); + } + handleImmediateResponse(std::move(unit), "DoH3 self-answered response"); + return; + } + + ++dnsdist::metrics::g_stats.responses; + if (unit->ids.cs != nullptr) { + ++unit->ids.cs->responses; + } + + if (result != ProcessQueryResult::PassToBackend) { + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH3 no backend available"); + return; + } + + if (downstream == nullptr) { + unit->status_code = 502; + handleImmediateResponse(std::move(unit), "DoH3 no backend available"); + return; + } + + unit->downstream = downstream; + + std::string proxyProtocolPayload; + /* we need to do this _before_ creating the cross protocol query because + after that the buffer will have been moved */ + if (downstream->d_config.useProxyProtocol) { + proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion); + } + + unit->ids.origID = htons(queryId); + unit->tcp = true; + + /* this moves unit->ids, careful! */ + auto cpq = std::make_unique(std::move(unit), false); + cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); + + if (downstream->passCrossProtocolQuery(std::move(cpq))) { + return; + } + // NOLINTNEXTLINE(bugprone-use-after-move): it was only moved if the call succeeded + unit = cpq->releaseDU(); + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH3 internal error"); + return; + } + catch (const std::exception& e) { + vinfolog("Got an error in DOH3 question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH3 internal error"); + return; + } +} + +static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID) +{ + try { + /* we only parse it there as a sanity check, we will parse it again later */ + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + DNSPacketMangler mangler(reinterpret_cast(query.data()), query.size()); + mangler.skipDomainName(); + mangler.skipBytes(4); + + auto unit = std::make_unique(std::move(query)); + unit->dsc = &dsc; + unit->ids.origDest = local; + unit->ids.origRemote = remote; + unit->ids.protocol = dnsdist::Protocol::DoH3; + unit->serverConnID = serverConnID; + unit->streamID = streamID; + + processDOH3Query(std::move(unit)); + } + catch (const std::exception& exp) { + vinfolog("Had error parsing DoH3 DNS packet from %s: %s", remote.toStringWithPort(), exp.what()); + } +} + +static void flushResponses(pdns::channel::Receiver& receiver) +{ + for (;;) { + try { + auto tmp = receiver.receive(); + if (!tmp) { + return; + } + + auto unit = std::move(*tmp); + auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID); + if (conn) { + handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response); + } + } + catch (const std::exception& e) { + errlog("Error while processing response received over DoH3: %s", e.what()); + } + catch (...) { + errlog("Unspecified error while processing response received over DoH3"); + } + } +} + +// this is the entrypoint from dnsdist.cc +void doh3Thread(ClientState* clientState) +{ + try { + std::shared_ptr& frontend = clientState->doh3Frontend; + + frontend->d_server_config->clientState = clientState; + frontend->d_server_config->df = clientState->doh3Frontend; + + setThreadName("dnsdist/doh3"); + + Socket sock(clientState->udpFD); + + PacketBuffer buffer(std::numeric_limits::max()); + auto mplexer = std::unique_ptr(FDMultiplexer::getMultiplexerSilent()); + + auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor(); + mplexer->addReadFD(sock.getHandle(), [](int, FDMultiplexer::funcparam_t&) {}); + mplexer->addReadFD(responseReceiverFD, [](int, FDMultiplexer::funcparam_t&) {}); + while (true) { + std::vector readyFDs; + mplexer->getAvailableFDs(readyFDs, 500); + + if (std::find(readyFDs.begin(), readyFDs.end(), sock.getHandle()) != readyFDs.end()) { + DEBUGLOG("Received datagram"); + std::string bufferStr; + ComboAddress client; + sock.recvFrom(bufferStr, client); + + uint32_t version{0}; + uint8_t type{0}; + std::array scid{}; + size_t scid_len = scid.size(); + std::array dcid{}; + size_t dcid_len = dcid.size(); + std::array token{}; + size_t token_len = token.size(); + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto res = quiche_header_info(reinterpret_cast(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN, + &version, &type, + scid.data(), &scid_len, + dcid.data(), &dcid_len, + token.data(), &token_len); + if (res != 0) { + DEBUGLOG("Error in quiche_header_info: " << res); + continue; + } + + // destination connection ID, will have to be sent as original destination connection ID + PacketBuffer serverConnID(dcid.begin(), dcid.begin() + dcid_len); + // source connection ID, will have to be sent as destination connection ID + PacketBuffer clientConnID(scid.begin(), scid.begin() + scid_len); + auto conn = getConnection(frontend->d_server_config->d_connections, serverConnID); + + if (!conn) { + DEBUGLOG("Connection not found"); + if (!quiche_version_is_supported(version)) { + DEBUGLOG("Unsupported version"); + ++frontend->d_doh3UnsupportedVersionErrors; + handleVersionNegociation(sock, clientConnID, serverConnID, client); + continue; + } + + if (token_len == 0) { + /* stateless retry */ + DEBUGLOG("No token received"); + handleStatelessRetry(sock, clientConnID, serverConnID, client, version); + continue; + } + + PacketBuffer tokenBuf(token.begin(), token.begin() + token_len); + auto originalDestinationID = validateToken(tokenBuf, client); + if (!originalDestinationID) { + ++frontend->d_doh3InvalidTokensReceived; + DEBUGLOG("Discarding invalid token"); + continue; + } + + DEBUGLOG("Creating a new connection"); + conn = createConnection(*frontend->d_server_config, serverConnID, *originalDestinationID, clientState->local, client); + if (!conn) { + continue; + } + } + DEBUGLOG("Connection found"); + quiche_recv_info recv_info = { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + reinterpret_cast(&client), + client.getSocklen(), + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + reinterpret_cast(&clientState->local), + clientState->local.getSocklen(), + }; + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto done = quiche_conn_recv(conn->get().d_conn.get(), reinterpret_cast(bufferStr.data()), bufferStr.size(), &recv_info); + if (done < 0) { + continue; + } + + if (quiche_conn_is_established(conn->get().d_conn.get())) { + DEBUGLOG("Connection is established"); + + if (!conn->get().d_http3) { + conn->get().d_http3 = QuicheHTTP3Connection(quiche_h3_conn_new_with_transport(conn->get().d_conn.get(), frontend->d_server_config->http3config.get()), + quiche_h3_conn_free); + if (!conn->get().d_http3) { + continue ; + } + DEBUGLOG("Successfully created HTTP/3 connection"); + } + + while (1) { + quiche_h3_event *ev; + // Processes HTTP/3 data received from the peer + int64_t streamID = quiche_h3_conn_poll(conn->get().d_http3.get(), + conn->get().d_conn.get(), + &ev); + + if (streamID < 0) { + break; + } + + switch (quiche_h3_event_type(ev)) { + case QUICHE_H3_EVENT_HEADERS: { + std::string path; + int rc = quiche_h3_event_for_each_header(ev, + [](uint8_t *name, size_t name_len, uint8_t *value, size_t value_len, void *argp) -> int { + std::string_view key(reinterpret_cast(name), name_len); + std::string_view content(reinterpret_cast(value), value_len); + if (key == ":path") { + auto pathptr = reinterpret_cast(argp); + *pathptr = content; + } + return 0; + }, &path); + if (rc != 0) { + DEBUGLOG("Failed to process headers"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to process query headers"); + break ; + } + if (path.empty()) { + DEBUGLOG("Path not found"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Path not found"); + break; + } + { + auto pos = path.find("?dns="); + if (pos == string::npos) { + pos = path.find("&dns="); + } + if (pos != string::npos) { + // need to base64url decode this + string sdns(path.substr(pos + 5)); + boost::replace_all(sdns, "-", "+"); + boost::replace_all(sdns, "_", "/"); + // re-add padding that may have been missing + switch (sdns.size() % 4) { + case 2: + sdns.append(2, '='); + break; + case 3: + sdns.append(1, '='); + break; + } + + PacketBuffer decoded; + + /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */ + const size_t maxAdditionalSizeForEDNS = 35U; + /* rough estimate so we hopefully don't need a new allocation later */ + /* We reserve at few additional bytes to be able to add EDNS later */ + const size_t estimate = ((sdns.size() * 3) / 4); + decoded.reserve(estimate + maxAdditionalSizeForEDNS); + if (B64Decode(sdns, decoded) < 0) { + DEBUGLOG("Unable to base64 decode()"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to decode BASE64-URL"); + break ; + } + + if (decoded.size() < sizeof(dnsheader)) { + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query"); + break; + } + DEBUGLOG("Dispatching query"); + doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID); + conn->get().d_streamBuffers.erase(streamID); + } + else { + DEBUGLOG("User error, unable to find the DNS parameter"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState->nonCompliantQueries; + ++frontend->d_errorResponses; + h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to find the DNS parameter"); + break ; + } + } + break; + } + + case QUICHE_H3_EVENT_DATA: + case QUICHE_H3_EVENT_FINISHED: + case QUICHE_H3_EVENT_RESET: + case QUICHE_H3_EVENT_PRIORITY_UPDATE: + case QUICHE_H3_EVENT_GOAWAY: + break; + } + + quiche_h3_event_free(ev); + } + + } + else { + DEBUGLOG("Connection not established"); + } + } + + if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) { + flushResponses(frontend->d_server_config->d_responseReceiver); + } + + for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) { + quiche_conn_on_timeout(conn->second.d_conn.get()); + + flushEgress(sock, conn->second); + + if (quiche_conn_is_closed(conn->second.d_conn.get())) { +#ifdef DEBUGLOG_ENABLED + quiche_stats stats; + quiche_path_stats path_stats; + + quiche_conn_stats(conn->second.d_conn.get(), &stats); + quiche_conn_path_stats(conn->second.d_conn.get(), 0, &path_stats); + + DEBUGLOG("Connection closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd); +#endif + conn = frontend->d_server_config->d_connections.erase(conn); + } + else { + ++conn; + } + } + } + } + catch (const std::exception& e) { + DEBUGLOG("Caught fatal error: " << e.what()); + } +} + +#endif /* HAVE_DNS_OVER_HTTP3 */ diff --git a/pdns/dnsdistdist/doh3.hh b/pdns/dnsdistdist/doh3.hh new file mode 100644 index 0000000000..40d44b9289 --- /dev/null +++ b/pdns/dnsdistdist/doh3.hh @@ -0,0 +1,122 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include + +#include "config.h" +#include "channel.hh" +#include "iputils.hh" +#include "libssl.hh" +#include "noinitvector.hh" +#include "stat_t.hh" +#include "dnsdist-idstate.hh" + +struct DOH3ServerConfig; +struct DownstreamState; + +#ifdef HAVE_DNS_OVER_HTTP3 + +struct DOH3Frontend +{ + DOH3Frontend(); + DOH3Frontend(const DOH3Frontend&) = delete; + DOH3Frontend(DOH3Frontend&&) = delete; + DOH3Frontend& operator=(const DOH3Frontend&) = delete; + DOH3Frontend& operator=(DOH3Frontend&&) = delete; + ~DOH3Frontend(); + + void setup(); + + std::unique_ptr d_server_config; + TLSConfig d_tlsConfig; + ComboAddress d_local; + std::string d_keyLogFile; + +#ifdef __linux__ + // On Linux this gives us 128k pending queries (default is 8192 queries), + // which should be enough to deal with huge spikes + uint32_t d_internalPipeBufferSize{1024 * 1024}; +#else + uint32_t d_internalPipeBufferSize{0}; +#endif + uint64_t d_idleTimeout{5}; + uint64_t d_maxInFlight{65535}; + std::string d_ccAlgo{"reno"}; + + pdns::stat_t d_doh3UnsupportedVersionErrors{0}; // Unsupported protocol version errors + pdns::stat_t d_doh3InvalidTokensReceived{0}; // Discarded received tokens + pdns::stat_t d_validResponses{0}; // Valid responses sent + pdns::stat_t d_errorResponses{0}; // Empty responses (no backend, drops, invalid queries, etc.) + + static std::map s_available_cc_algorithms; +}; + +struct DOH3Unit +{ + DOH3Unit(PacketBuffer&& query_) : + query(std::move(query_)) + { + } + + DOH3Unit(const DOH3Unit&) = delete; + DOH3Unit& operator=(const DOH3Unit&) = delete; + + InternalQueryState ids; + PacketBuffer query; + PacketBuffer response; + PacketBuffer serverConnID; + std::shared_ptr downstream{nullptr}; + DOH3ServerConfig* dsc{nullptr}; + uint64_t streamID{0}; + size_t proxyProtocolPayloadSize{0}; + uint16_t status_code{200}; + /* whether the query was re-sent to the backend over + TCP after receiving a truncated answer over UDP */ + bool tcp{false}; +}; + +using DOH3UnitUniquePtr = std::unique_ptr; + +struct CrossProtocolQuery; +struct DNSQuestion; +std::unique_ptr getDOH3CrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion, bool isResponse); + +void doh3Thread(ClientState* clientState); + +#else + +struct DOH3Unit +{ +}; + +struct DOH3Frontend +{ + DOH3Frontend() + { + } + void setup() + { + } +}; + +#endif diff --git a/pdns/dnsdistdist/m4/dnsdist_enable_doh3.m4 b/pdns/dnsdistdist/m4/dnsdist_enable_doh3.m4 new file mode 100644 index 0000000000..ffac6f0efa --- /dev/null +++ b/pdns/dnsdistdist/m4/dnsdist_enable_doh3.m4 @@ -0,0 +1,14 @@ +AC_DEFUN([DNSDIST_ENABLE_DNS_OVER_HTTP3], [ + AC_MSG_CHECKING([whether to enable incoming DNS over HTTP3 (DoH3) support]) + AC_ARG_ENABLE([dns-over-http3], + AS_HELP_STRING([--enable-dns-over-http3], [enable incoming DNS over HTTP3 (DoH3) support (requires quiche) @<:@default=no@:>@]), + [enable_dns_over_http3=$enableval], + [enable_dns_over_http3=no] + ) + AC_MSG_RESULT([$enable_dns_over_http3]) + AM_CONDITIONAL([HAVE_DNS_OVER_HTTP3], [test "x$enable_dns_over_http3" != "xno"]) + + AM_COND_IF([HAVE_DNS_OVER_HTTP3], [ + AC_DEFINE([HAVE_DNS_OVER_HTTP3], [1], [Define to 1 if you enable DNS over HTTP/3 support]) + ]) +])