From: Remi Gacogne Date: Mon, 31 Jul 2023 14:18:02 +0000 (+0200) Subject: dnsdist: Refactor the DoH code to be able to have two libraries X-Git-Tag: rec-5.0.0-alpha1~19^2~35 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=9eb152c401dae2604773dd992837e96ddd041a1a;p=thirdparty%2Fpdns.git dnsdist: Refactor the DoH code to be able to have two libraries --- diff --git a/pdns/dnsdist-carbon.cc b/pdns/dnsdist-carbon.cc index 693f498c43..d73d9ff8b7 100644 --- a/pdns/dnsdist-carbon.cc +++ b/pdns/dnsdist-carbon.cc @@ -147,7 +147,7 @@ static bool doOneCarbonExport(const Carbon::Endpoint& endpoint) errorCounters = &front->tlsFrontend->d_tlsCounters; } else if (front->dohFrontend != nullptr) { - errorCounters = &front->dohFrontend->d_tlsCounters; + errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters; } if (errorCounters != nullptr) { str << base << "tlsdhkeytoosmall" << ' ' << errorCounters->d_dhKeyTooSmall << " " << now << "\r\n"; @@ -204,7 +204,7 @@ static bool doOneCarbonExport(const Carbon::Endpoint& endpoint) std::map dohFrontendDuplicates; const string base = "dnsdist." + hostname + ".main.doh."; for (const auto& doh : g_dohlocals) { - string name = doh->d_local.toStringWithPort(); + string name = doh->d_tlsContext.d_addr.toStringWithPort(); boost::replace_all(name, ".", "_"); boost::replace_all(name, ":", "_"); boost::replace_all(name, "[", "_"); diff --git a/pdns/dnsdist-doh-common.hh b/pdns/dnsdist-doh-common.hh new file mode 100644 index 0000000000..44ad826a88 --- /dev/null +++ b/pdns/dnsdist-doh-common.hh @@ -0,0 +1,240 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include +#include + +#include "config.h" +#include "iputils.hh" +#include "libssl.hh" +#include "noinitvector.hh" +#include "stat_t.hh" +#include "tcpiohandler.hh" + +struct DOHServerConfig; + +class DOHResponseMapEntry +{ +public: + DOHResponseMapEntry(const std::string& regex, uint16_t status, const PacketBuffer& content, const boost::optional>& headers) : + d_regex(regex), d_customHeaders(headers), d_content(content), d_status(status) + { + if (status >= 400 && !d_content.empty() && d_content.at(d_content.size() - 1) != 0) { + // we need to make sure it's null-terminated + d_content.push_back(0); + } + } + + bool matches(const std::string& path) const + { + return d_regex.match(path); + } + + uint16_t getStatusCode() const + { + return d_status; + } + + const PacketBuffer& getContent() const + { + return d_content; + } + + const boost::optional>& getHeaders() const + { + return d_customHeaders; + } + +private: + Regex d_regex; + boost::optional> d_customHeaders; + PacketBuffer d_content; + uint16_t d_status; +}; + +struct DOHFrontend +{ + DOHFrontend() + { + } + DOHFrontend(std::shared_ptr tlsCtx) : + d_tlsContext(std::move(tlsCtx)) + { + } + + virtual ~DOHFrontend() + { + } + + std::shared_ptr d_dsc{nullptr}; + std::shared_ptr>> d_responsesMap; + TLSFrontend d_tlsContext{TLSFrontend::ALPN::DoH}; + std::string d_serverTokens{"h2o/dnsdist"}; + std::unordered_map d_customResponseHeaders; + std::string d_library; + + uint32_t d_idleTimeout{30}; // HTTP idle timeout in seconds + std::set> d_urls; + + pdns::stat_t d_httpconnects{0}; // number of TCP/IP connections established + pdns::stat_t d_getqueries{0}; // valid DNS queries received via GET + pdns::stat_t d_postqueries{0}; // valid DNS queries received via POST + pdns::stat_t d_badrequests{0}; // request could not be converted to dns query + pdns::stat_t d_errorresponses{0}; // dnsdist set 'error' on response + pdns::stat_t d_redirectresponses{0}; // dnsdist set 'redirect' on response + pdns::stat_t d_validresponses{0}; // valid responses sent out + + struct HTTPVersionStats + { + pdns::stat_t d_nbQueries{0}; // valid DNS queries received + pdns::stat_t d_nb200Responses{0}; + pdns::stat_t d_nb400Responses{0}; + pdns::stat_t d_nb403Responses{0}; + pdns::stat_t d_nb500Responses{0}; + pdns::stat_t d_nb502Responses{0}; + pdns::stat_t d_nbOtherResponses{0}; + }; + + HTTPVersionStats d_http1Stats; + HTTPVersionStats d_http2Stats; +#ifdef __linux__ + // On Linux this gives us 128k pending queries (default is 8192 queries), + // which should be enough to deal with huge spikes + uint32_t d_internalPipeBufferSize{1024 * 1024}; +#else + uint32_t d_internalPipeBufferSize{0}; +#endif + bool d_sendCacheControlHeaders{true}; + bool d_trustForwardedForHeader{false}; + /* whether we require tue query path to exactly match one of configured ones, + or accept everything below these paths. */ + bool d_exactPathMatching{true}; + bool d_keepIncomingHeaders{false}; + + time_t getTicketsKeyRotationDelay() const + { + return d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay; + } + + bool isHTTPS() const + { + return !d_tlsContext.d_tlsConfig.d_certKeyPairs.empty(); + } + +#ifndef HAVE_DNS_OVER_HTTPS + virtual void setup() + { + } + + virtual void reloadCertificates() + { + } + + virtual void rotateTicketsKey(time_t /* now */) + { + } + + virtual void loadTicketsKeys(const std::string& /* keyFile */) + { + } + + virtual void handleTicketsKeyRotation() + { + } + + virtual std::string getNextTicketsKeyRotation() + { + return std::string(); + } + + virtual size_t getTicketsKeysCount() const + { + size_t res = 0; + return res; + } + +#else + virtual void setup(); + virtual void reloadCertificates(); + + virtual void rotateTicketsKey(time_t now); + virtual void loadTicketsKeys(const std::string& keyFile); + virtual void handleTicketsKeyRotation(); + virtual std::string getNextTicketsKeyRotation() const; + virtual size_t getTicketsKeysCount(); +#endif /* HAVE_DNS_OVER_HTTPS */ +}; + +#include "dnsdist-idstate.hh" + +struct DownstreamState; + +#ifndef HAVE_DNS_OVER_HTTPS +struct DOHUnitInterface +{ + virtual ~DOHUnitInterface() + { + } + static void handleTimeout(std::unique_ptr) + { + } + + static void handleUDPResponse(std::unique_ptr, PacketBuffer&&, InternalQueryState&&, const std::shared_ptr&) + { + } +}; +#else /* HAVE_DNS_OVER_HTTPS */ +struct DOHUnitInterface +{ + virtual ~DOHUnitInterface() + { + } + + virtual std::string getHTTPPath() const = 0; + virtual std::string getHTTPQueryString() const = 0; + virtual const std::string& getHTTPHost() const = 0; + virtual const std::string& getHTTPScheme() const = 0; + virtual const std::unordered_map& getHTTPHeaders() const = 0; + virtual void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "") = 0; + virtual void handleTimeout() = 0; + virtual void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr&) = 0; + + static void handleTimeout(std::unique_ptr unit) + { + if (unit) { + unit->handleTimeout(); + unit.release(); + } + } + + static void handleUDPResponse(std::unique_ptr unit, PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr& ds) + { + if (unit) { + unit->handleUDPResponse(std::move(response), std::move(state), ds); + unit.release(); + } + } + + std::shared_ptr downstream{nullptr}; +}; +#endif /* HAVE_DNS_OVER_HTTPS */ diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index cf5442fea0..456e703fb3 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -22,6 +22,7 @@ #pragma once #include "config.h" +#include "dnscrypt.hh" #include "dnsname.hh" #include "dnsdist-protocols.hh" #include "gettime.hh" @@ -29,11 +30,12 @@ #include "uuid-utils.hh" struct ClientState; -struct DOHUnit; +struct DOHUnitInterface; class DNSCryptQuery; class DNSDistPacketCache; using QTag = std::unordered_map; +using HeadersMap = std::unordered_map; struct StopWatch { @@ -89,6 +91,8 @@ private: bool d_needRealTime; }; +class CrossProtocolContext; + struct InternalQueryState { struct ProtoBufData @@ -125,7 +129,9 @@ struct InternalQueryState std::unique_ptr d_protoBufData{nullptr}; boost::optional tempFailureTTL{boost::none}; // 8 ClientState* cs{nullptr}; // 8 - std::unique_ptr du{nullptr}; // 8 + std::unique_ptr du; // 8 + size_t d_proxyProtocolPayloadSize{0}; // 8 + int32_t d_streamID{-1}; // 4 uint32_t cacheKey{0}; // 4 uint32_t cacheKeyNoECS{0}; // 4 // DoH-only */ diff --git a/pdns/dnsdist-lua-inspection.cc b/pdns/dnsdist-lua-inspection.cc index 66200df0ea..f778a492dd 100644 --- a/pdns/dnsdist-lua-inspection.cc +++ b/pdns/dnsdist-lua-inspection.cc @@ -706,7 +706,7 @@ void setupLuaInspection(LuaContext& luaCtx) errorCounters = &f->tlsFrontend->d_tlsCounters; } else if (f->dohFrontend != nullptr) { - errorCounters = &f->dohFrontend->d_tlsCounters; + errorCounters = &f->dohFrontend->d_tlsContext.d_tlsCounters; } if (errorCounters == nullptr) { continue; diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 3224b8763f..c829c2e1b5 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -57,6 +57,7 @@ #include "dnsdist-web.hh" #include "base64.hh" +#include "doh.hh" #include "dolog.hh" #include "sodcrypto.hh" #include "threadname.hh" @@ -2336,31 +2337,39 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) setLuaSideEffect(); auto frontend = std::make_shared(); +#ifdef HAVE_LIBH2OEVLOOP + frontend = std::make_shared(); + frontend->d_library = "h2o"; +#else /* HAVE_LIBH2OEVLOOP */ + errlog("DOH bind %s is configured to use libh2o but the library is not available", addr); + return; +#endif /* HAVE_LIBH2OEVLOOP */ + if (certFiles && !certFiles->empty()) { - if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) { + if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) { return; } - frontend->d_local = ComboAddress(addr, 443); + frontend->d_tlsContext.d_addr = ComboAddress(addr, 443); } else { - frontend->d_local = ComboAddress(addr, 80); - infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_local.toStringWithPort()); + frontend->d_tlsContext.d_addr = ComboAddress(addr, 80); + infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort()); } if (urls) { if (urls->type() == typeid(std::string)) { - frontend->d_urls.push_back(boost::get(*urls)); + frontend->d_urls.insert(boost::get(*urls)); } else if (urls->type() == typeid(LuaArray)) { auto urlsVect = boost::get>(*urls); for (const auto& p : urlsVect) { - frontend->d_urls.push_back(p.second); + frontend->d_urls.insert(p.second); } } } else { - frontend->d_urls = {"/dns-query"}; + frontend->d_urls.insert("/dns-query"); } bool reusePort = false; @@ -2405,7 +2414,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } } - parseTLSConfig(frontend->d_tlsConfig, "addDOHLocal", vars); + parseTLSConfig(frontend->d_tlsContext.d_tlsConfig, "addDOHLocal", vars); bool ignoreTLSConfigurationErrors = false; if (getOptionalValue(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) { @@ -2413,7 +2422,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) // and properly ignore the frontend before actually launching it try { std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses); + auto ctx = libssl_init_server_context(frontend->d_tlsContext.d_tlsConfig, ocspResponses); } catch (const std::runtime_error& e) { errlog("Ignoring DoH frontend: '%s'", e.what()); @@ -2424,7 +2433,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) checkAllParametersConsumed("addDOHLocal", vars); } g_dohlocals.push_back(frontend); - auto cs = std::make_unique(frontend->d_local, true, reusePort, tcpFastOpenQueueSize, interface, cpus); + auto cs = std::make_unique(frontend->d_tlsContext.d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus); cs->dohFrontend = frontend; cs->d_additionalAddresses = std::move(additionalAddresses); @@ -2435,9 +2444,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) cs->d_tcpConcurrentConnectionsLimit = tcpMaxConcurrentConnections; } g_frontends.push_back(std::move(cs)); -#else +#else /* HAVE_DNS_OVER_HTTPS */ throw std::runtime_error("addDOHLocal() called but DNS over HTTPS support is not present!"); -#endif +#endif /* HAVE_DNS_OVER_HTTPS */ }); luaCtx.writeFunction("showDOHFrontends", []() { @@ -2449,7 +2458,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) ret << (fmt % "#" % "Address" % "HTTP" % "HTTP/1" % "HTTP/2" % "GET" % "POST" % "Bad" % "Errors" % "Redirects" % "Valid" % "# ticket keys" % "Rotation delay" % "Next rotation") << endl; size_t counter = 0; for (const auto& ctx : g_dohlocals) { - ret << (fmt % counter % ctx->d_local.toStringWithPort() % ctx->d_httpconnects % ctx->d_http1Stats.d_nbQueries % ctx->d_http2Stats.d_nbQueries % ctx->d_getqueries % ctx->d_postqueries % ctx->d_badrequests % ctx->d_errorresponses % ctx->d_redirectresponses % ctx->d_validresponses % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl; + ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_httpconnects % ctx->d_http1Stats.d_nbQueries % ctx->d_http2Stats.d_nbQueries % ctx->d_getqueries % ctx->d_postqueries % ctx->d_badrequests % ctx->d_errorresponses % ctx->d_redirectresponses % ctx->d_validresponses % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl; counter++; } g_outputBuffer = ret.str(); @@ -2473,7 +2482,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) ret << (fmt % "#" % "Address" % "200" % "400" % "403" % "500" % "502" % "Others") << endl; size_t counter = 0; for (const auto& ctx : g_dohlocals) { - ret << (fmt % counter % ctx->d_local.toStringWithPort() % ctx->d_http1Stats.d_nb200Responses % ctx->d_http1Stats.d_nb400Responses % ctx->d_http1Stats.d_nb403Responses % ctx->d_http1Stats.d_nb500Responses % ctx->d_http1Stats.d_nb502Responses % ctx->d_http1Stats.d_nbOtherResponses) << endl; + ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_http1Stats.d_nb200Responses % ctx->d_http1Stats.d_nb400Responses % ctx->d_http1Stats.d_nb403Responses % ctx->d_http1Stats.d_nb500Responses % ctx->d_http1Stats.d_nb502Responses % ctx->d_http1Stats.d_nbOtherResponses) << endl; counter++; } g_outputBuffer += ret.str(); @@ -2483,7 +2492,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) ret << (fmt % "#" % "Address" % "200" % "400" % "403" % "500" % "502" % "Others") << endl; counter = 0; for (const auto& ctx : g_dohlocals) { - ret << (fmt % counter % ctx->d_local.toStringWithPort() % ctx->d_http2Stats.d_nb200Responses % ctx->d_http2Stats.d_nb400Responses % ctx->d_http2Stats.d_nb403Responses % ctx->d_http2Stats.d_nb500Responses % ctx->d_http2Stats.d_nb502Responses % ctx->d_http2Stats.d_nbOtherResponses) << endl; + ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_http2Stats.d_nb200Responses % ctx->d_http2Stats.d_nb400Responses % ctx->d_http2Stats.d_nb403Responses % ctx->d_http2Stats.d_nb500Responses % ctx->d_http2Stats.d_nb502Responses % ctx->d_http2Stats.d_nbOtherResponses) << endl; counter++; } g_outputBuffer += ret.str(); @@ -2537,7 +2546,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) luaCtx.registerFunction::*)(boost::variant, LuaArray, LuaArray>> certFiles, boost::variant> keyFiles)>("loadNewCertificatesAndKeys", [](std::shared_ptr frontend, boost::variant, LuaArray, LuaArray>> certFiles, boost::variant> keyFiles) { #ifdef HAVE_DNS_OVER_HTTPS if (frontend != nullptr) { - if (loadTLSCertificateAndKeys("DOHFrontend::loadNewCertificatesAndKeys", frontend->d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) { + if (loadTLSCertificateAndKeys("DOHFrontend::loadNewCertificatesAndKeys", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) { frontend->reloadCertificates(); } } @@ -2579,7 +2588,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } setLuaSideEffect(); - shared_ptr frontend = std::make_shared(TLSFrontend::ALPN::DoT); + auto frontend = std::make_shared(TLSFrontend::ALPN::DoT); if (!loadTLSCertificateAndKeys("addTLSLocal", frontend->d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) { return; } diff --git a/pdns/dnsdist-web.cc b/pdns/dnsdist-web.cc index d1132d3783..5037985972 100644 --- a/pdns/dnsdist-web.cc +++ b/pdns/dnsdist-web.cc @@ -741,7 +741,7 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp) errorCounters = &front->tlsFrontend->d_tlsCounters; } else if (front->dohFrontend != nullptr) { - errorCounters = &front->dohFrontend->d_tlsCounters; + errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters; } if (errorCounters != nullptr) { @@ -779,7 +779,7 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp) #ifdef HAVE_DNS_OVER_HTTPS std::map dohFrontendDuplicates; for(const auto& doh : g_dohlocals) { - const string frontName = doh->d_local.toStringWithPort(); + const string frontName = doh->d_tlsContext.d_addr.toStringWithPort(); uint64_t threadNumber = 0; auto dupPair = frontendDuplicates.emplace(frontName, 1); if (!dupPair.second) { @@ -1149,7 +1149,7 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp) errorCounters = &front->tlsFrontend->d_tlsCounters; } else if (front->dohFrontend != nullptr) { - errorCounters = &front->dohFrontend->d_tlsCounters; + errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters; } if (errorCounters != nullptr) { frontend["tlsHandshakeFailuresDHKeyTooSmall"] = (double)errorCounters->d_dhKeyTooSmall; @@ -1172,7 +1172,7 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp) for (const auto& doh : g_dohlocals) { dohs.emplace_back(Json::object{ { "id", num++ }, - { "address", doh->d_local.toStringWithPort() }, + { "address", doh->d_tlsContext.d_addr.toStringWithPort() }, { "http-connects", (double) doh->d_httpconnects }, { "http1-queries", (double) doh->d_http1Stats.d_nbQueries }, { "http2-queries", (double) doh->d_http2Stats.d_nbQueries }, diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index a673bd6f54..fdf2797104 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -69,6 +69,7 @@ #include "base64.hh" #include "capabilities.hh" #include "delaypipe.hh" +#include "doh.hh" #include "dolog.hh" #include "dnsname.hh" #include "dnsparser.hh" @@ -784,7 +785,7 @@ void responderThread(std::shared_ptr dss) if (du) { #ifdef HAVE_DNS_OVER_HTTPS // DoH query, we cannot touch du after that - handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids)); + DOHUnitInterface::handleUDPResponse(std::move(du), std::move(response), std::move(*ids), dss); #endif continue; } @@ -1539,19 +1540,14 @@ ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::sha return ProcessQueryResult::Drop; } -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest) +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query) { bool doh = dq.ids.du != nullptr; bool failed = false; - size_t proxyPayloadSize = 0; if (ds->d_config.useProxyProtocol) { try { - if (addProxyProtocol(dq, &proxyPayloadSize)) { - if (dq.ids.du) { - dq.ids.du->proxyProtocolPayloadSize = proxyPayloadSize; - } - } + addProxyProtocol(dq, &dq.ids.d_proxyProtocolPayloadSize); } catch (const std::exception& e) { vinfolog("Adding proxy protocol payload to %s query from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what()); @@ -1559,6 +1555,10 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint1 } } + if (doh && !dq.ids.d_packet) { + dq.ids.d_packet = std::make_unique(query); + } + try { int fd = ds->pickSocketForSending(); dq.ids.backendFD = fd; @@ -1569,7 +1569,7 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint1 auto idOffset = ds->saveState(std::move(dq.ids)); /* set the correct ID */ - memcpy(query.data() + proxyPayloadSize, &idOffset, sizeof(idOffset)); + memcpy(query.data() + dq.ids.d_proxyProtocolPayloadSize, &idOffset, sizeof(idOffset)); /* you can't touch ids or du after this line, unless the call returned a non-negative value, because it might already have been freed */ @@ -1585,9 +1585,6 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint1 auto cleared = ds->getState(idOffset); if (cleared) { dq.ids.du = std::move(cleared->du); - if (dq.ids.du) { - dq.ids.du->status_code = 502; - } } ++dnsdist::metrics::g_stats.downstreamSendErrors; ++ds->sendErrors; @@ -1720,7 +1717,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct return; } - assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, dest); + assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query); } catch(const std::exception& e){ vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", ids.origRemote.toStringWithPort(), queryId, e.what()); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 56b7421655..a9ecef0170 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -42,7 +42,7 @@ #include "dnsdist-lbpolicies.hh" #include "dnsdist-protocols.hh" #include "dnsname.hh" -#include "doh.hh" +#include "dnsdist-doh-common.hh" #include "ednsoptions.hh" #include "iputils.hh" #include "misc.hh" @@ -1088,10 +1088,6 @@ struct LocalHolders void tcpAcceptorThread(std::vector states); -#ifdef HAVE_DNS_OVER_HTTPS -void dohThread(ClientState* cs); -#endif /* HAVE_DNS_OVER_HTTPS */ - void setLuaNoSideEffect(); // if nothing has been declared, set that there are no side effects void setLuaSideEffect(); // set to report a side effect, cancelling all _no_ side effect calls bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect @@ -1123,7 +1119,7 @@ bool processResponse(PacketBuffer& response, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted); -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest); +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query); ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const PacketBuffer& request, bool healthCheck = false); bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote); diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 9b951a5866..99d7cdbe64 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -147,6 +147,7 @@ dnsdist_SOURCES = \ dnsdist-discovery.cc dnsdist-discovery.hh \ dnsdist-dnscrypt.cc \ dnsdist-dnsparser.cc dnsdist-dnsparser.hh \ + dnsdist-doh-common.cc dnsdist-doh-common.hh \ dnsdist-downstream-connection.hh \ dnsdist-dynblocks.cc dnsdist-dynblocks.hh \ dnsdist-dynbpf.cc dnsdist-dynbpf.hh \ @@ -256,6 +257,7 @@ testrunner_SOURCES = \ dnsdist-cache.cc dnsdist-cache.hh \ dnsdist-concurrent-connections.hh \ dnsdist-dnsparser.cc dnsdist-dnsparser.hh \ + dnsdist-doh-common.cc dnsdist-doh-common.hh \ dnsdist-downstream-connection.hh \ dnsdist-dynblocks.cc dnsdist-dynblocks.hh \ dnsdist-dynbpf.cc dnsdist-dynbpf.hh \ diff --git a/pdns/dnsdistdist/dnsdist-async.cc b/pdns/dnsdistdist/dnsdist-async.cc index 19426468df..f54b1c0b14 100644 --- a/pdns/dnsdistdist/dnsdist-async.cc +++ b/pdns/dnsdistdist/dnsdist-async.cc @@ -282,7 +282,6 @@ bool resumeQuery(std::unique_ptr&& query) return resumeResponse(std::move(query)); } - auto& ids = query->query.d_idstate; DNSQuestion dnsQuestion = query->getDQ(); LocalHolders holders; @@ -311,7 +310,7 @@ bool resumeQuery(std::unique_ptr&& query) /* at this point 'du', if it is not nullptr, is owned by the DoHCrossProtocolQuery which will stop existing when we return, so we need to increment the reference count */ - return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dnsQuestion, query->query.d_buffer, ids.origDest); + return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dnsQuestion, query->query.d_buffer); } if (result == ProcessQueryResult::SendAnswer) { auto sender = query->getTCPQuerySender(); diff --git a/pdns/dnsdistdist/dnsdist-backend.cc b/pdns/dnsdistdist/dnsdist-backend.cc index 45a50446da..44b3d9c39d 100644 --- a/pdns/dnsdistdist/dnsdist-backend.cc +++ b/pdns/dnsdistdist/dnsdist-backend.cc @@ -360,7 +360,7 @@ void DownstreamState::handleUDPTimeout(IDState& ids) { ids.age = 0; ids.inUse = false; - handleDOHTimeout(std::move(ids.internal.du)); + DOHUnitInterface::handleTimeout(std::move(ids.internal.du)); ++reuseds; --outstanding; ++dnsdist::metrics::g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout @@ -463,7 +463,7 @@ uint16_t DownstreamState::saveState(InternalQueryState&& state) auto oldDU = std::move(it->second.internal.du); ++reuseds; ++dnsdist::metrics::g_stats.downstreamTimeouts; - handleDOHTimeout(std::move(oldDU)); + DOHUnitInterface::handleTimeout(std::move(oldDU)); } else { ++outstanding; @@ -490,7 +490,7 @@ uint16_t DownstreamState::saveState(InternalQueryState&& state) auto oldDU = std::move(ids.internal.du); ++reuseds; ++dnsdist::metrics::g_stats.downstreamTimeouts; - handleDOHTimeout(std::move(oldDU)); + DOHUnitInterface::handleTimeout(std::move(oldDU)); } else { ++outstanding; @@ -513,7 +513,7 @@ void DownstreamState::restoreState(uint16_t id, InternalQueryState&& state) /* already used */ ++reuseds; ++dnsdist::metrics::g_stats.downstreamTimeouts; - handleDOHTimeout(std::move(state.du)); + DOHUnitInterface::handleTimeout(std::move(state.du)); } else { it->second.internal = std::move(state); @@ -528,14 +528,14 @@ void DownstreamState::restoreState(uint16_t id, InternalQueryState&& state) /* already used */ ++reuseds; ++dnsdist::metrics::g_stats.downstreamTimeouts; - handleDOHTimeout(std::move(state.du)); + DOHUnitInterface::handleTimeout(std::move(state.du)); return; } if (ids.isInUse()) { /* already used */ ++reuseds; ++dnsdist::metrics::g_stats.downstreamTimeouts; - handleDOHTimeout(std::move(state.du)); + DOHUnitInterface::handleTimeout(std::move(state.du)); return; } ids.internal = std::move(state); diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc new file mode 100644 index 0000000000..15fcb9672a --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-doh-common.cc @@ -0,0 +1,129 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#include "dnsdist-doh-common.hh" +#include "dnsdist-rules.hh" + +#ifdef HAVE_DNS_OVER_HTTPS + +HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& regex) : + d_header(toLower(header)), d_regex(regex), d_visual("http[" + header + "] ~ " + regex) +{ +} + +bool HTTPHeaderRule::matches(const DNSQuestion* dq) const +{ + if (!dq->ids.du) { + return false; + } + + const auto& headers = dq->ids.du->getHTTPHeaders(); + for (const auto& header : headers) { + if (header.first == d_header) { + return d_regex.match(header.second); + } + } + return false; +} + +string HTTPHeaderRule::toString() const +{ + return d_visual; +} + +HTTPPathRule::HTTPPathRule(const std::string& path) : + d_path(path) +{ +} + +bool HTTPPathRule::matches(const DNSQuestion* dq) const +{ + if (!dq->ids.du) { + return false; + } + + const auto path = dq->ids.du->getHTTPPath(); + return d_path == path; +} + +string HTTPPathRule::toString() const +{ + return "url path == " + d_path; +} + +HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex) : + d_regex(regex), d_visual("http path ~ " + regex) +{ +} + +bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const +{ + if (!dq->ids.du) { + return false; + } + + return d_regex.match(dq->ids.du->getHTTPPath()); +} + +string HTTPPathRegexRule::toString() const +{ + return d_visual; +} + +void DOHFrontend::rotateTicketsKey(time_t now) +{ + return d_tlsContext.rotateTicketsKey(now); +} + +void DOHFrontend::loadTicketsKeys(const std::string& keyFile) +{ + return d_tlsContext.loadTicketsKeys(keyFile); +} + +void DOHFrontend::handleTicketsKeyRotation() +{ +} + +std::string DOHFrontend::getNextTicketsKeyRotation() const +{ + return d_tlsContext.getNextTicketsKeyRotation(); +} + +size_t DOHFrontend::getTicketsKeysCount() +{ + return d_tlsContext.getTicketsKeysCount(); +} + +void DOHFrontend::reloadCertificates() +{ + d_tlsContext.setupTLS(); +} + +void DOHFrontend::setup() +{ + if (isHTTPS()) { + if (!d_tlsContext.setupTLS()) { + throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort()); + } + } +} + +#endif /* HAVE_DNS_OVER_HTTPS */ diff --git a/pdns/dnsdistdist/dnsdist-doh-common.hh b/pdns/dnsdistdist/dnsdist-doh-common.hh new file mode 120000 index 0000000000..5692084494 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-doh-common.hh @@ -0,0 +1 @@ +../dnsdist-doh-common.hh \ No newline at end of file diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index eeb0af4808..91dcd9ad76 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -167,6 +167,8 @@ private: std::atomic_flag d_rotatingTicketsKey; }; +struct DOHUnit; + // we create one of these per thread, and pass around a pointer to it // through the bowels of h2o struct DOHServerConfig @@ -215,6 +217,61 @@ struct DOHServerConfig pdns::channel::Receiver d_responseReceiver; }; +struct DOHUnit : public DOHUnitInterface +{ + DOHUnit(PacketBuffer&& q, std::string&& p, std::string&& h): path(std::move(p)), host(std::move(h)), query(std::move(q)) + { + ids.ednsAdded = false; + } + ~DOHUnit() + { + if (self) { + *self = nullptr; + } + } + + DOHUnit(const DOHUnit&) = delete; + DOHUnit& operator=(const DOHUnit&) = delete; + + InternalQueryState ids; + std::string sni; + std::string path; + std::string scheme; + std::string host; + std::string contentType; + PacketBuffer query; + PacketBuffer response; + std::unique_ptr> headers; + st_h2o_req_t* req{nullptr}; + DOHUnit** self{nullptr}; + DOHServerConfig* dsc{nullptr}; + pdns::channel::Sender* responseSender{nullptr}; + size_t query_at{0}; + int rsock{-1}; + /* the status_code is set from + processDOHQuery() (which is executed in + the DOH client thread) so that the correct + response can be sent in on_dnsdist(), + after the DOHUnit has been passed back to + the main DoH thread. + */ + uint16_t status_code{200}; + /* whether the query was re-sent to the backend over + TCP after receiving a truncated answer over UDP */ + bool tcp{false}; + bool truncated{false}; + + std::string getHTTPPath() const override; + std::string getHTTPQueryString() const override; + const std::string& getHTTPHost() const override; + const std::string& getHTTPScheme() const override; + const std::unordered_map& getHTTPHeaders() const override; + void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType="") override; + virtual void handleTimeout() override; + virtual void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr&) override; +}; +using DOHUnitUniquePtr = std::unique_ptr; + /* This internal function sends back the object to the main thread to send a reply. The caller should NOT release or touch the unit after calling this function */ static void sendDoHUnitToTheMainThread(DOHUnitUniquePtr&& du, const char* description) @@ -233,18 +290,11 @@ static void sendDoHUnitToTheMainThread(DOHUnitUniquePtr&& du, const char* descri } /* This function is called from other threads than the main DoH one, - instructing it to send a 502 error to the client. - It takes ownership of the unit. */ -void handleDOHTimeout(DOHUnitUniquePtr&& oldDU) + instructing it to send a 502 error to the client. */ +void DOHUnit::handleTimeout() { - if (oldDU == nullptr) { - return; - } - - /* we are about to erase an existing DU */ - oldDU->status_code = 502; - - sendDoHUnitToTheMainThread(std::move(oldDU), "DoH timeout"); + status_code = 502; + sendDoHUnitToTheMainThread(std::unique_ptr(this), "DoH timeout"); } struct DOHConnection @@ -385,7 +435,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo h2o_send_error_400(req, getReasonFromStatusCode(statusCode).c_str(), "invalid DNS query" , 0); break; case 403: - h2o_send_error_403(req, getReasonFromStatusCode(statusCode).c_str(), "dns query not allowed", 0); + h2o_send_error_403(req, getReasonFromStatusCode(statusCode).c_str(), "DoH query not allowed", 0); break; case 502: h2o_send_error_502(req, getReasonFromStatusCode(statusCode).c_str(), "no downstream server available", 0); @@ -402,6 +452,12 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo } } +static std::unique_ptr getDUFromIDS(InternalQueryState& ids) +{ + auto du = std::unique_ptr(dynamic_cast(ids.du.release())); + return du; +} + class DoHTCPCrossQuerySender : public TCPQuerySender { public: @@ -420,7 +476,7 @@ public: return; } - auto du = std::move(response.d_idstate.du); + auto du = getDUFromIDS(response.d_idstate); if (du->responseSender == nullptr) { return; } @@ -438,10 +494,11 @@ public: dr.ids.du = std::move(du); - if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { + if (!processResponse(dynamic_cast(dr.ids.du.get())->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { if (dr.ids.du) { - dr.ids.du->status_code = 503; - sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); + du = getDUFromIDS(dr.ids); + du->status_code = 503; + sendDoHUnitToTheMainThread(std::move(du), "Response dropped by rules"); } return; } @@ -450,7 +507,7 @@ public: return; } - du = std::move(dr.ids.du); + du = getDUFromIDS(dr.ids); } if (!du->ids.selfGenerated) { @@ -483,11 +540,11 @@ public: return; } - if (query.du->responseSender == nullptr) { + auto du = getDUFromIDS(query); + if (du->responseSender == nullptr) { return; } - auto du = std::move(query.du); du->ids = std::move(query); du->status_code = 502; sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response"); @@ -519,20 +576,23 @@ public: Leave it for now because we know that the onky case where the payload has been added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT, and we know the TCP/DoT code can handle it. */ - query.d_proxyProtocolPayloadAdded = query.d_idstate.du->proxyProtocolPayloadSize > 0; + query.d_proxyProtocolPayloadAdded = query.d_idstate.d_proxyProtocolPayloadSize > 0; downstream = query.d_idstate.du->downstream; - proxyProtocolPayloadSize = query.d_idstate.du->proxyProtocolPayloadSize; } void handleInternalError() { - query.d_idstate.du->status_code = 502; - sendDoHUnitToTheMainThread(std::move(query.d_idstate.du), "DoH internal error"); + auto du = getDUFromIDS(query.d_idstate); + if (!du) { + return; + } + du->status_code = 502; + sendDoHUnitToTheMainThread(std::move(du), "DoH internal error"); } std::shared_ptr getTCPQuerySender() override { - query.d_idstate.du->downstream = downstream; + dynamic_cast(query.d_idstate.du.get())->downstream = downstream; return s_sender; } @@ -550,9 +610,9 @@ public: return dr; } - DOHUnitUniquePtr&& releaseDU() + DOHUnitUniquePtr releaseDU() { - return std::move(query.d_idstate.du); + return getDUFromIDS(query.d_idstate); } private: @@ -567,7 +627,7 @@ std::unique_ptr getDoHCrossProtocolQueryFromDQ(DNSQuestion& throw std::runtime_error("Trying to create a DoH cross protocol query without a valid DoH unit"); } - auto du = std::move(dq.ids.du); + auto du = getDUFromIDS(dq.ids); if (&dq.ids != &du->ids) { du->ids = std::move(dq.ids); } @@ -606,121 +666,116 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) }; auto& ids = unit->ids; - ids.du = std::move(unit); - auto& du = ids.du; uint16_t queryId = 0; ComboAddress remote; try { - if (!du->req) { + if (!unit->req) { // we got closed meanwhile. XXX small race condition here // but we should be fine as long as we don't touch du->req // outside of the main DoH thread - du->status_code = 500; - handleImmediateResponse(std::move(du), "DoH killed in flight"); + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH killed in flight"); return; } - { - // if there was no EDNS, we add it with a large buffer size - // so we can use UDP to talk to the backend. - auto dh = const_cast(reinterpret_cast(du->query.data())); - - if (!dh->arcount) { - if (generateOptRR(std::string(), du->query, 4096, 4096, 0, false)) { - dh = const_cast(reinterpret_cast(du->query.data())); // may have reallocated - dh->arcount = htons(1); - du->ids.ednsAdded = true; - } - } - else { - // we leave existing EDNS in place - } - } - - remote = du->ids.origRemote; - DOHServerConfig* dsc = du->dsc; + remote = ids.origRemote; + DOHServerConfig* dsc = unit->dsc; auto& holders = dsc->holders; ClientState& cs = *dsc->cs; - if (du->query.size() < sizeof(dnsheader)) { + if (unit->query.size() < sizeof(dnsheader)) { ++dnsdist::metrics::g_stats.nonCompliantQueries; ++cs.nonCompliantQueries; - du->status_code = 400; - handleImmediateResponse(std::move(du), "DoH non-compliant query"); + unit->status_code = 400; + handleImmediateResponse(std::move(unit), "DoH non-compliant query"); return; } ++cs.queries; ++dnsdist::metrics::g_stats.queries; - du->ids.queryRealTime.start(); + ids.queryRealTime.start(); { /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */ - struct dnsheader* dh = reinterpret_cast(du->query.data()); + struct dnsheader* dh = reinterpret_cast(unit->query.data()); if (!checkQueryHeaders(dh, cs)) { - du->status_code = 400; - handleImmediateResponse(std::move(du), "DoH invalid headers"); + unit->status_code = 400; + handleImmediateResponse(std::move(unit), "DoH invalid headers"); return; } if (dh->qdcount == 0) { dh->rcode = RCode::NotImp; dh->qr = true; - du->response = std::move(du->query); + unit->response = std::move(unit->query); - handleImmediateResponse(std::move(du), "DoH empty query"); + handleImmediateResponse(std::move(unit), "DoH empty query"); return; } queryId = ntohs(dh->id); } - auto downstream = du->downstream; - du->ids.qname = DNSName(reinterpret_cast(du->query.data()), du->query.size(), sizeof(dnsheader), false, &du->ids.qtype, &du->ids.qclass); - DNSQuestion dq(du->ids, du->query); + { + // if there was no EDNS, we add it with a large buffer size + // so we can use UDP to talk to the backend. + auto dh = const_cast(reinterpret_cast(unit->query.data())); + if (!dh->arcount) { + if (addEDNS(unit->query, 4096, false, 4096, 0)) { + ids.ednsAdded = true; + } + } + } + + auto downstream = unit->downstream; + ids.qname = DNSName(reinterpret_cast(unit->query.data()), unit->query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + DNSQuestion dq(ids, unit->query); const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); ids.origFlags = *flags; - du->ids.cs = &cs; - dq.sni = std::move(du->sni); - + ids.cs = &cs; + dq.sni = std::move(unit->sni); + ids.du = std::move(unit); auto result = processQuery(dq, holders, downstream); if (result == ProcessQueryResult::Drop) { - du->status_code = 403; - handleImmediateResponse(std::move(du), "DoH dropped query"); + unit = getDUFromIDS(ids); + unit->status_code = 403; + handleImmediateResponse(std::move(unit), "DoH dropped query"); return; } else if (result == ProcessQueryResult::Asynchronous) { return; } else if (result == ProcessQueryResult::SendAnswer) { - if (du->response.empty()) { - du->response = std::move(du->query); + unit = getDUFromIDS(ids); + if (unit->response.empty()) { + unit->response = std::move(unit->query); } - if (du->response.size() >= sizeof(dnsheader) && du->contentType.empty()) { - auto dh = reinterpret_cast(du->response.data()); + if (unit->response.size() >= sizeof(dnsheader) && unit->contentType.empty()) { + auto dh = reinterpret_cast(unit->response.data()); - handleResponseSent(du->ids.qname, QType(du->ids.qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH, false); + handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH, false); } - handleImmediateResponse(std::move(du), "DoH self-answered response"); + handleImmediateResponse(std::move(unit), "DoH self-answered response"); return; } + unit = getDUFromIDS(ids); if (result != ProcessQueryResult::PassToBackend) { - du->status_code = 500; - handleImmediateResponse(std::move(du), "DoH no backend available"); + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH no backend available"); return; } if (downstream == nullptr) { - du->status_code = 502; - handleImmediateResponse(std::move(du), "DoH no backend available"); + unit->status_code = 502; + handleImmediateResponse(std::move(unit), "DoH no backend available"); return; } - du->downstream = downstream; + unit->downstream = downstream; if (downstream->isTCPOnly()) { std::string proxyProtocolPayload; @@ -730,11 +785,11 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) proxyProtocolPayload = getProxyProtocolPayload(dq); } - du->ids.origID = htons(queryId); - du->tcp = true; + unit->ids.origID = htons(queryId); + unit->tcp = true; /* this moves du->ids, careful! */ - auto cpq = std::make_unique(std::move(du), false); + auto cpq = std::make_unique(std::move(unit), false); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); if (downstream->passCrossProtocolQuery(std::move(cpq))) { @@ -742,9 +797,9 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) } else { if (inMainThread) { - du = cpq->releaseDU(); - du->status_code = 502; - handleImmediateResponse(std::move(du), "DoH internal error"); + unit = cpq->releaseDU(); + unit->status_code = 502; + handleImmediateResponse(std::move(unit), "DoH internal error"); } else { cpq->handleInternalError(); @@ -753,17 +808,19 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) } } - ComboAddress dest = dq.ids.origDest; - if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, du->query, dest)) { - du->status_code = 502; - handleImmediateResponse(std::move(du), "DoH internal error"); + auto& query = unit->query; + ids.du = std::move(unit); + if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, query)) { + unit = getDUFromIDS(ids); + unit->status_code = 502; + handleImmediateResponse(std::move(unit), "DoH internal error"); return; } } catch (const std::exception& e) { vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); - du->status_code = 500; - handleImmediateResponse(std::move(du), "DoH internal error"); + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH internal error"); return; } @@ -838,7 +895,7 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re /* we are doing quite some copies here, sorry about that, but we can't keep accessing the req object once we are in a different thread because the request might get killed by h2o at pretty much any time */ - auto du = std::make_unique(std::move(query), std::move(path), std::string(req->authority.base, req->authority.len)); + auto du = DOHUnitUniquePtr(new DOHUnit(std::move(query), std::move(path), std::string(req->authority.base, req->authority.len))); du->dsc = dsc; du->req = req; du->ids.origDest = local; @@ -869,7 +926,7 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re *(du->self) = du.get(); #ifdef USE_SINGLE_ACCEPTOR_THREAD - processDOHQuery(du, true); + processDOHQuery(std::move(du), true); #else /* USE_SINGLE_ACCEPTOR_THREAD */ try { if (!dsc->d_querySender.send(std::move(du))) { @@ -1102,85 +1159,13 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req) } } -HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& regex) - : d_header(toLower(header)), d_regex(regex), d_visual("http[" + header+ "] ~ " + regex) -{ -} - -bool HTTPHeaderRule::matches(const DNSQuestion* dq) const -{ - if (!dq->ids.du || !dq->ids.du->headers) { - return false; - } - - for (const auto& header : *dq->ids.du->headers) { - if (header.first == d_header) { - return d_regex.match(header.second); - } - } - return false; -} - -string HTTPHeaderRule::toString() const -{ - return d_visual; -} - -HTTPPathRule::HTTPPathRule(const std::string& path) - : d_path(path) -{ - -} - -bool HTTPPathRule::matches(const DNSQuestion* dq) const -{ - if (!dq->ids.du) { - return false; - } - - if (dq->ids.du->query_at == SIZE_MAX) { - return dq->ids.du->path == d_path; - } - else { - return d_path.compare(0, d_path.size(), dq->ids.du->path, 0, dq->ids.du->query_at) == 0; - } -} - -string HTTPPathRule::toString() const -{ - return "url path == " + d_path; -} - -HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex): d_regex(regex), d_visual("http path ~ " + regex) -{ -} - -bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const -{ - if (!dq->ids.du) { - return false; - } - - return d_regex.match(dq->ids.du->getHTTPPath()); -} - -string HTTPPathRegexRule::toString() const -{ - return d_visual; -} - -std::unordered_map DOHUnit::getHTTPHeaders() const +const std::unordered_map& DOHUnit::getHTTPHeaders() const { - std::unordered_map results; - if (headers) { - results.reserve(headers->size()); - - for (const auto& header : *headers) { - results.insert(header); - } + if (!headers) { + static const HeadersMap empty{}; + return empty; } - - return results; + return *headers; } std::string DOHUnit::getHTTPPath() const @@ -1193,12 +1178,12 @@ std::string DOHUnit::getHTTPPath() const } } -std::string DOHUnit::getHTTPHost() const +const std::string& DOHUnit::getHTTPHost() const { return host; } -std::string DOHUnit::getHTTPScheme() const +const std::string& DOHUnit::getHTTPScheme() const { return scheme; } @@ -1280,7 +1265,7 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) memory and likely coming up too late after the client has gone away */ auto* dsc = static_cast(listener->data); while (true) { - std::unique_ptr du{nullptr}; + DOHUnitUniquePtr du{nullptr}; try { auto tmp = dsc->d_responseReceiver.receive(); if (!tmp) { @@ -1300,10 +1285,10 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) if (!du->tcp && du->truncated && - du->query.size() > du->proxyProtocolPayloadSize && - (du->query.size() - du->proxyProtocolPayloadSize) > sizeof(dnsheader)) { + du->query.size() > du->ids.d_proxyProtocolPayloadSize && + (du->query.size() - du->ids.d_proxyProtocolPayloadSize) > sizeof(dnsheader)) { /* restoring the original ID */ - dnsheader* queryDH = reinterpret_cast(du->query.data() + du->proxyProtocolPayloadSize); + dnsheader* queryDH = reinterpret_cast(du->query.data() + du->ids.d_proxyProtocolPayloadSize); queryDH->id = du->ids.origID; du->ids.forwardedOverUDP = false; du->tcp = true; @@ -1494,84 +1479,22 @@ static void setupAcceptContext(DOHAcceptContext& ctx, DOHServerConfig& dsc, bool auto nativeCtx = ctx.get(); nativeCtx->ctx = &dsc.h2o_ctx; nativeCtx->hosts = dsc.h2o_config.hosts; - ctx.d_ticketsKeyRotationDelay = dsc.df->d_tlsConfig.d_ticketsKeyRotationDelay; + auto df = std::atomic_load_explicit(&dsc.df, std::memory_order_acquire); + ctx.d_ticketsKeyRotationDelay = df->d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay; - if (setupTLS && dsc.df->isHTTPS()) { + if (setupTLS && df->isHTTPS()) { try { setupTLSContext(ctx, - dsc.df->d_tlsConfig, - dsc.df->d_tlsCounters); + df->d_tlsContext.d_tlsConfig, + df->d_tlsContext.d_tlsCounters); } catch (const std::runtime_error& e) { - throw std::runtime_error("Error setting up TLS context for DoH listener on '" + dsc.df->d_local.toStringWithPort() + "': " + e.what()); + throw std::runtime_error("Error setting up TLS context for DoH listener on '" + df->d_tlsContext.d_addr.toStringWithPort() + "': " + e.what()); } } ctx.d_cs = dsc.cs; } -void DOHFrontend::rotateTicketsKey(time_t now) -{ - if (d_dsc && d_dsc->accept_ctx) { - d_dsc->accept_ctx->rotateTicketsKey(now); - } -} - -void DOHFrontend::loadTicketsKeys(const std::string& keyFile) -{ - if (d_dsc && d_dsc->accept_ctx) { - d_dsc->accept_ctx->loadTicketsKeys(keyFile); - } -} - -void DOHFrontend::handleTicketsKeyRotation() -{ - if (d_dsc && d_dsc->accept_ctx) { - d_dsc->accept_ctx->handleTicketsKeyRotation(); - } -} - -time_t DOHFrontend::getNextTicketsKeyRotation() const -{ - if (d_dsc && d_dsc->accept_ctx) { - return d_dsc->accept_ctx->getNextTicketsKeyRotation(); - } - return 0; -} - -size_t DOHFrontend::getTicketsKeysCount() const -{ - size_t res = 0; - if (d_dsc && d_dsc->accept_ctx) { - res = d_dsc->accept_ctx->getTicketsKeysCount(); - } - return res; -} - -void DOHFrontend::reloadCertificates() -{ - auto newAcceptContext = std::make_shared(); - setupAcceptContext(*newAcceptContext, *d_dsc, true); - std::atomic_store_explicit(&d_dsc->accept_ctx, newAcceptContext, std::memory_order_release); -} - -void DOHFrontend::setup() -{ - registerOpenSSLUser(); - - d_dsc = std::make_shared(d_idleTimeout, d_internalPipeBufferSize); - - if (isHTTPS()) { - try { - setupTLSContext(*d_dsc->accept_ctx, - d_tlsConfig, - d_tlsCounters); - } - catch (const std::runtime_error& e) { - throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_local.toStringWithPort() + "': " + e.what()); - } - } -} - static h2o_pathconf_t *register_handler(h2o_hostconf_t *hostconf, const char *path, int (*on_req)(h2o_handler_t *, h2o_req_t *)) { h2o_pathconf_t *pathconf = h2o_config_register_path(hostconf, path, 0); @@ -1598,7 +1521,7 @@ void dohThread(ClientState* cs) std::shared_ptr& df = cs->dohFrontend; auto& dsc = df->d_dsc; dsc->cs = cs; - dsc->df = cs->dohFrontend; + std::atomic_store_explicit(&dsc->df, cs->dohFrontend, std::memory_order_release); dsc->h2o_config.server_name = h2o_iovec_init(df->d_serverTokens.c_str(), df->d_serverTokens.size()); #ifndef USE_SINGLE_ACCEPTOR_THREAD @@ -1609,11 +1532,11 @@ void dohThread(ClientState* cs) setThreadName("dnsdist/doh"); // I wonder if this registers an IP address.. I think it does // this may mean we need to actually register a site "name" here and not the IP address - h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(df->d_local.toString().c_str(), df->d_local.toString().size()), 65535); + h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(df->d_tlsContext.d_addr.toString().c_str(), df->d_tlsContext.d_addr.toString().size()), 65535); - for(const auto& url : df->d_urls) { + dsc->paths = df->d_urls; + for (const auto& url : dsc->paths) { register_handler(hostconf, url.c_str(), doh_handler); - dsc->paths.insert(url); } h2o_context_init(&dsc->h2o_ctx, h2o_evloop_create(), &dsc->h2o_config); @@ -1632,11 +1555,11 @@ void dohThread(ClientState* cs) setupAcceptContext(*dsc->accept_ctx, *dsc, false); if (create_listener(dsc, cs->tcpFD) != 0) { - throw std::runtime_error("DOH server failed to listen on " + df->d_local.toStringWithPort() + ": " + strerror(errno)); + throw std::runtime_error("DOH server failed to listen on " + df->d_tlsContext.d_addr.toStringWithPort() + ": " + strerror(errno)); } for (const auto& [addr, fd] : cs->d_additionalAddresses) { if (create_listener(dsc, fd) != 0) { - throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + df->d_local.toStringWithPort() + ": " + strerror(errno)); + throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + df->d_tlsContext.d_addr.toStringWithPort() + ": " + strerror(errno)); } } @@ -1661,25 +1584,31 @@ void dohThread(ClientState* cs) } } -void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, InternalQueryState&& state) +void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, InternalQueryState&& state, const std::shared_ptr&) { - du->response = std::move(udpResponse); + auto du = std::unique_ptr(this); du->ids = std::move(state); - const dnsheader* dh = reinterpret_cast(du->response.data()); - if (!dh->tc) { + { + const dnsheader* dh = reinterpret_cast(udpResponse.data()); + if (dh->tc) { + du->truncated = true; + } + } + if (!du->truncated) { static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - DNSResponse dr(du->ids, du->response, du->downstream); + DNSResponse dr(du->ids, udpResponse, du->downstream); dnsheader cleartextDH; memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); dr.ids.du = std::move(du); - if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { + if (!processResponse(udpResponse, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { if (dr.ids.du) { - dr.ids.du->status_code = 503; - sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); + du = getDUFromIDS(dr.ids); + du->status_code = 503; + sendDoHUnitToTheMainThread(std::move(du), "Response dropped by rules"); } return; } @@ -1688,7 +1617,8 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, return; } - du = std::move(dr.ids.du); + du = getDUFromIDS(dr.ids); + du->response = std::move(udpResponse); double udiff = du->ids.queryRealTime.udiff(); vinfolog("Got answer from %s, relayed to %s (https), took %f us", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); @@ -1699,17 +1629,72 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, ++du->ids.cs->responses; } } - else { - du->truncated = true; - } sendDoHUnitToTheMainThread(std::move(du), "DoH response"); } -#endif /* HAVE_LIBH2OEVLOOP */ -#else /* HAVE_DNS_OVER_HTTPS */ -void handleDOHTimeout(DOHUnitUniquePtr&& oldDU) +void H2ODOHFrontend::rotateTicketsKey(time_t now) +{ + if (d_dsc && d_dsc->accept_ctx) { + d_dsc->accept_ctx->rotateTicketsKey(now); + } +} + +void H2ODOHFrontend::loadTicketsKeys(const std::string& keyFile) +{ + if (d_dsc && d_dsc->accept_ctx) { + d_dsc->accept_ctx->loadTicketsKeys(keyFile); + } +} + +void H2ODOHFrontend::handleTicketsKeyRotation() +{ + if (d_dsc && d_dsc->accept_ctx) { + d_dsc->accept_ctx->handleTicketsKeyRotation(); + } +} + +std::string H2ODOHFrontend::getNextTicketsKeyRotation() const +{ + if (d_dsc && d_dsc->accept_ctx) { + return std::to_string(d_dsc->accept_ctx->getNextTicketsKeyRotation()); + } + return 0; +} + +size_t H2ODOHFrontend::getTicketsKeysCount() +{ + size_t res = 0; + if (d_dsc && d_dsc->accept_ctx) { + res = d_dsc->accept_ctx->getTicketsKeysCount(); + } + return res; +} + +void H2ODOHFrontend::reloadCertificates() +{ + auto newAcceptContext = std::make_shared(); + setupAcceptContext(*newAcceptContext, *d_dsc, true); + std::atomic_store_explicit(&d_dsc->accept_ctx, newAcceptContext, std::memory_order_release); +} + +void H2ODOHFrontend::setup() { + registerOpenSSLUser(); + + d_dsc = std::make_shared(d_idleTimeout, d_internalPipeBufferSize); + + if (isHTTPS()) { + try { + setupTLSContext(*d_dsc->accept_ctx, + d_tlsContext.d_tlsConfig, + d_tlsContext.d_tlsCounters); + } + catch (const std::runtime_error& e) { + throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort() + "': " + e.what()); + } + } } -#endif /* HAVE_DNS_OVER_HTTPS */ +#endif /* HAVE_LIBH2OEVLOOP */ +#endif /* HAVE_LIBH2OEVLOOP */ diff --git a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc index c7e638b219..9d437578f7 100644 --- a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc @@ -34,39 +34,10 @@ std::vector> g_frontends; /* add stub implementations, we don't want to include the corresponding object files and their dependencies */ -#ifdef HAVE_DNS_OVER_HTTPS -std::unordered_map DOHUnit::getHTTPHeaders() const -{ - return {}; -} - -std::string DOHUnit::getHTTPPath() const -{ - return ""; -} - -std::string DOHUnit::getHTTPHost() const -{ - return ""; -} - -std::string DOHUnit::getHTTPScheme() const -{ - return ""; -} - -std::string DOHUnit::getHTTPQueryString() const -{ - return ""; -} - -void DOHUnit::setHTTPResponse(uint16_t statusCode, PacketBuffer&& body_, const std::string& contentType_) -{ -} -#endif /* HAVE_DNS_OVER_HTTPS */ - -void handleDOHTimeout(DOHUnitUniquePtr&& oldDU) +// NOLINTNEXTLINE(readability-convert-member-functions-to-static): this is a stub, the real one is not that simple.. +bool TLSFrontend::setupTLS() { + return true; } std::string DNSQuestion::getTrailingData() const diff --git a/pdns/doh.hh b/pdns/doh.hh index 6f3816c300..58a26f1691 100644 --- a/pdns/doh.hh +++ b/pdns/doh.hh @@ -21,236 +21,37 @@ */ #pragma once -#pragma once - -#include - -#include "channel.hh" -#include "iputils.hh" -#include "libssl.hh" -#include "noinitvector.hh" -#include "stat_t.hh" - -struct DOHServerConfig; - -class DOHResponseMapEntry -{ -public: - DOHResponseMapEntry(const std::string& regex, uint16_t status, const PacketBuffer& content, const boost::optional>& headers): d_regex(regex), d_customHeaders(headers), d_content(content), d_status(status) - { - if (status >= 400 && !d_content.empty() && d_content.at(d_content.size() -1) != 0) { - // we need to make sure it's null-terminated - d_content.push_back(0); - } - } - - bool matches(const std::string& path) const - { - return d_regex.match(path); - } - - uint16_t getStatusCode() const - { - return d_status; - } - - const PacketBuffer& getContent() const - { - return d_content; - } - - const boost::optional>& getHeaders() const - { - return d_customHeaders; - } - -private: - Regex d_regex; - boost::optional> d_customHeaders; - PacketBuffer d_content; - uint16_t d_status; -}; - -struct DOHFrontend -{ - DOHFrontend() - { - } - - std::shared_ptr d_dsc{nullptr}; - std::shared_ptr>> d_responsesMap; - TLSConfig d_tlsConfig; - TLSErrorCounters d_tlsCounters; - std::string d_serverTokens{"h2o/dnsdist"}; - std::unordered_map d_customResponseHeaders; - ComboAddress d_local; - - uint32_t d_idleTimeout{30}; // HTTP idle timeout in seconds - std::vector d_urls; - - pdns::stat_t d_httpconnects{0}; // number of TCP/IP connections established - pdns::stat_t d_getqueries{0}; // valid DNS queries received via GET - pdns::stat_t d_postqueries{0}; // valid DNS queries received via POST - pdns::stat_t d_badrequests{0}; // request could not be converted to dns query - pdns::stat_t d_errorresponses{0}; // dnsdist set 'error' on response - pdns::stat_t d_redirectresponses{0}; // dnsdist set 'redirect' on response - pdns::stat_t d_validresponses{0}; // valid responses sent out - - struct HTTPVersionStats - { - pdns::stat_t d_nbQueries{0}; // valid DNS queries received - pdns::stat_t d_nb200Responses{0}; - pdns::stat_t d_nb400Responses{0}; - pdns::stat_t d_nb403Responses{0}; - pdns::stat_t d_nb500Responses{0}; - pdns::stat_t d_nb502Responses{0}; - pdns::stat_t d_nbOtherResponses{0}; - }; - - HTTPVersionStats d_http1Stats; - HTTPVersionStats d_http2Stats; -#ifdef __linux__ - // On Linux this gives us 128k pending queries (default is 8192 queries), - // which should be enough to deal with huge spikes - uint32_t d_internalPipeBufferSize{1024*1024}; -#else - uint32_t d_internalPipeBufferSize{0}; -#endif - bool d_sendCacheControlHeaders{true}; - bool d_trustForwardedForHeader{false}; - /* whether we require tue query path to exactly match one of configured ones, - or accept everything below these paths. */ - bool d_exactPathMatching{true}; - bool d_keepIncomingHeaders{false}; - - time_t getTicketsKeyRotationDelay() const - { - return d_tlsConfig.d_ticketsKeyRotationDelay; - } - - bool isHTTPS() const - { - return !d_tlsConfig.d_certKeyPairs.empty(); - } - -#ifndef HAVE_DNS_OVER_HTTPS - void setup() - { - } - - void reloadCertificates() - { - } +#include "config.h" - void rotateTicketsKey(time_t /* now */) - { - } - - void loadTicketsKeys(const std::string& /* keyFile */) - { - } - - void handleTicketsKeyRotation() - { - } - - time_t getNextTicketsKeyRotation() const - { - return 0; - } - - size_t getTicketsKeysCount() const - { - size_t res = 0; - return res; - } - -#else - void setup(); - void reloadCertificates(); - - void rotateTicketsKey(time_t now); - void loadTicketsKeys(const std::string& keyFile); - void handleTicketsKeyRotation(); - time_t getNextTicketsKeyRotation() const; - size_t getTicketsKeysCount() const; -#endif /* HAVE_DNS_OVER_HTTPS */ -}; +#ifdef HAVE_DNS_OVER_HTTPS +#ifdef HAVE_LIBH2OEVLOOP -#ifndef HAVE_DNS_OVER_HTTPS -struct DOHUnit -{ - size_t proxyProtocolPayloadSize{0}; - uint16_t status_code{200}; -}; +#include +#include +#include -#else /* HAVE_DNS_OVER_HTTPS */ -#ifdef HAVE_LIBH2OEVLOOP -#include +struct CrossProtocolQuery; +struct DNSQuestion; -#include "dnsdist-idstate.hh" +std::unique_ptr getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse); -struct st_h2o_req_t; -struct DownstreamState; +#include "dnsdist-doh-common.hh" -struct DOHUnit +struct H2ODOHFrontend : public DOHFrontend { - DOHUnit(PacketBuffer&& q, std::string&& p, std::string&& h): path(std::move(p)), host(std::move(h)), query(std::move(q)) - { - ids.ednsAdded = false; - } - - DOHUnit(const DOHUnit&) = delete; - DOHUnit& operator=(const DOHUnit&) = delete; +public: - InternalQueryState ids; - std::string sni; - std::string path; - std::string scheme; - std::string host; - std::string contentType; - PacketBuffer query; - PacketBuffer response; - std::shared_ptr downstream{nullptr}; - std::unique_ptr> headers; - st_h2o_req_t* req{nullptr}; - DOHUnit** self{nullptr}; - DOHServerConfig* dsc{nullptr}; - pdns::channel::Sender* responseSender{nullptr}; - size_t query_at{0}; - size_t proxyProtocolPayloadSize{0}; - int rsock{-1}; - /* the status_code is set from - processDOHQuery() (which is executed in - the DOH client thread) so that the correct - response can be sent in on_dnsdist(), - after the DOHUnit has been passed back to - the main DoH thread. - */ - uint16_t status_code{200}; - /* whether the query was re-sent to the backend over - TCP after receiving a truncated answer over UDP */ - bool tcp{false}; - bool truncated{false}; + void setup() override; + void reloadCertificates() override; - std::string getHTTPPath() const; - std::string getHTTPHost() const; - std::string getHTTPScheme() const; - std::string getHTTPQueryString() const; - std::unordered_map getHTTPHeaders() const; - void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType=""); + void rotateTicketsKey(time_t now) override; + void loadTicketsKeys(const std::string& keyFile) override; + void handleTicketsKeyRotation() override; + std::string getNextTicketsKeyRotation() const override; + size_t getTicketsKeysCount() override; }; -void handleUDPResponseForDoH(std::unique_ptr&&, PacketBuffer&& response, InternalQueryState&& state); - -struct CrossProtocolQuery; -struct DNSQuestion; - -std::unique_ptr getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse); +void dohThread(ClientState* clientState); #endif /* HAVE_LIBH2OEVLOOP */ #endif /* HAVE_DNS_OVER_HTTPS */ - -using DOHUnitUniquePtr = std::unique_ptr; - -void handleDOHTimeout(DOHUnitUniquePtr&& oldDU); diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index c4fe42b8aa..c51a930c04 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -54,9 +54,9 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs return false; } -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest) +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query) { - return false; + return true; } namespace dnsdist {