From: Remi Gacogne Date: Mon, 3 Mar 2025 10:57:54 +0000 (+0100) Subject: dnsdist: Share tickets key between identical frontends created via YAML X-Git-Tag: dnsdist-2.0.0-alpha2~88^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d0cf129d3871c7d30b6bef4b4143b7af52d43355;p=thirdparty%2Fpdns.git dnsdist: Share tickets key between identical frontends created via YAML Using the same Session Ticket Encryption Key on identical frontends allow TLS sessions to be resumed in a much more efficient way, reducing the latency and CPU usage. While it was already possible to do so by manually managing the STEK, the default behaviour was to create and use a different STEK for each frontend, because our Lua configuration makes it almost impossible to ensure that two frontends are identical. This is not an issue with the new YAML configuration format, so let's share the STEK automatically in this case. This needs a regression test. --- diff --git a/pdns/dnsdistdist/dnsdist-carbon.cc b/pdns/dnsdistdist/dnsdist-carbon.cc index 596e0eae10..b8daf13672 100644 --- a/pdns/dnsdistdist/dnsdist-carbon.cc +++ b/pdns/dnsdistdist/dnsdist-carbon.cc @@ -164,7 +164,7 @@ static bool doOneCarbonExport(const Carbon::Endpoint& endpoint) errorCounters = &front->tlsFrontend->d_tlsCounters; } else if (front->dohFrontend != nullptr) { - errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters; + errorCounters = &front->dohFrontend->d_tlsContext->d_tlsCounters; } if (errorCounters != nullptr) { str << base << "tlsdhkeytoosmall" << ' ' << errorCounters->d_dhKeyTooSmall << " " << now << "\r\n"; @@ -227,7 +227,7 @@ static bool doOneCarbonExport(const Carbon::Endpoint& endpoint) std::map dohFrontendDuplicates; const string base = "dnsdist." + hostname + ".main.doh."; for (const auto& doh : dnsdist::getDoHFrontends()) { - string name = doh->d_tlsContext.d_addr.toStringWithPort(); + string name = doh->d_tlsContext->d_addr.toStringWithPort(); std::replace(name.begin(), name.end(), '.', '_'); std::replace(name.begin(), name.end(), ':', '_'); std::replace(name.begin(), name.end(), '[', '_'); diff --git a/pdns/dnsdistdist/dnsdist-configuration-yaml.cc b/pdns/dnsdistdist/dnsdist-configuration-yaml.cc index 7f0e633d94..3fad3db70a 100644 --- a/pdns/dnsdistdist/dnsdist-configuration-yaml.cc +++ b/pdns/dnsdistdist/dnsdist-configuration-yaml.cc @@ -239,7 +239,7 @@ static bool validateTLSConfiguration(const dnsdist::rust::settings::BindConfigur return true; } -static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfiguration& bind, ClientState& state) +static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfiguration& bind, ClientState& state, std::shared_ptr parent) { auto tlsConfig = getTLSConfigFromRustIncomingTLS(bind.tls); if (!validateTLSConfiguration(bind, tlsConfig)) { @@ -249,6 +249,7 @@ static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfigurat auto protocol = boost::to_lower_copy(std::string(bind.protocol)); if (protocol == "dot") { auto frontend = std::make_shared(TLSFrontend::ALPN::DoT); + frontend->setParent(parent); frontend->d_provider = std::string(bind.tls.provider); boost::algorithm::to_lower(frontend->d_provider); frontend->d_proxyProtocolOutsideTLS = bind.tls.proxy_protocol_outside_tls; @@ -286,8 +287,9 @@ static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfigurat #endif /* HAVE_DNS_OVER_HTTP3 */ else if (protocol == "doh") { auto frontend = std::make_shared(); - frontend->d_tlsContext.d_provider = std::string(bind.tls.provider); - boost::algorithm::to_lower(frontend->d_tlsContext.d_provider); + auto& tlsContext = frontend->d_tlsContext; + tlsContext->d_provider = std::string(bind.tls.provider); + boost::algorithm::to_lower(tlsContext->d_provider); frontend->d_library = std::string(bind.doh.provider); if (frontend->d_library == "h2o") { #ifdef HAVE_LIBH2OEVLOOP @@ -343,16 +345,17 @@ static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfigurat } if (!tlsConfig.d_certKeyPairs.empty()) { - frontend->d_tlsContext.d_addr = ComboAddress(std::string(bind.listen_address), 443); + tlsContext->d_addr = ComboAddress(std::string(bind.listen_address), 443); infolog("DNS over HTTPS configured"); } else { - frontend->d_tlsContext.d_addr = ComboAddress(std::string(bind.listen_address), 80); - infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort()); + tlsContext->d_addr = ComboAddress(std::string(bind.listen_address), 80); + infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", tlsContext->d_addr.toStringWithPort()); } - frontend->d_tlsContext.d_proxyProtocolOutsideTLS = bind.tls.proxy_protocol_outside_tls; - frontend->d_tlsContext.d_tlsConfig = std::move(tlsConfig); + tlsContext->d_proxyProtocolOutsideTLS = bind.tls.proxy_protocol_outside_tls; + tlsContext->d_tlsConfig = std::move(tlsConfig); + tlsContext->setParent(parent); state.dohFrontend = std::move(frontend); } else if (protocol != "do53") { @@ -672,6 +675,7 @@ static void loadBinds(const ::rust::Vec tlsFrontendParent; for (size_t idx = 0; idx < bind.threads; idx++) { #if defined(HAVE_DNSCRYPT) std::shared_ptr dnsCryptContext; @@ -710,9 +714,12 @@ static void loadBinds(const ::rust::VecgetTLSFrontend(); + } } config.d_frontends.emplace_back(std::move(state)); diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc index c533cc7e8c..43713d7a26 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.cc +++ b/pdns/dnsdistdist/dnsdist-doh-common.cc @@ -26,17 +26,17 @@ #ifdef HAVE_DNS_OVER_HTTPS void DOHFrontend::rotateTicketsKey(time_t now) { - return d_tlsContext.rotateTicketsKey(now); + return d_tlsContext->rotateTicketsKey(now); } void DOHFrontend::loadTicketsKeys(const std::string& keyFile) { - return d_tlsContext.loadTicketsKeys(keyFile); + return d_tlsContext->loadTicketsKeys(keyFile); } void DOHFrontend::loadTicketsKey(const std::string& key) { - return d_tlsContext.loadTicketsKey(key); + return d_tlsContext->loadTicketsKey(key); } void DOHFrontend::handleTicketsKeyRotation() @@ -45,26 +45,26 @@ void DOHFrontend::handleTicketsKeyRotation() std::string DOHFrontend::getNextTicketsKeyRotation() const { - return d_tlsContext.getNextTicketsKeyRotation(); + return d_tlsContext->getNextTicketsKeyRotation(); } size_t DOHFrontend::getTicketsKeysCount() { - return d_tlsContext.getTicketsKeysCount(); + return d_tlsContext->getTicketsKeysCount(); } void DOHFrontend::reloadCertificates() { if (isHTTPS()) { - d_tlsContext.setupTLS(); + d_tlsContext->setupTLS(); } } void DOHFrontend::setup() { if (isHTTPS()) { - if (!d_tlsContext.setupTLS()) { - throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort()); + if (!d_tlsContext->setupTLS()) { + throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext->d_addr.toStringWithPort()); } } } diff --git a/pdns/dnsdistdist/dnsdist-doh-common.hh b/pdns/dnsdistdist/dnsdist-doh-common.hh index 9d0a466928..6af2f962e3 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.hh +++ b/pdns/dnsdistdist/dnsdist-doh-common.hh @@ -81,11 +81,12 @@ private: struct DOHFrontend { - DOHFrontend() + DOHFrontend() : + d_tlsContext(std::make_shared(TLSFrontend::ALPN::DoH)) { } DOHFrontend(std::shared_ptr tlsCtx) : - d_tlsContext(std::move(tlsCtx)) + d_tlsContext(std::make_shared(std::move(tlsCtx))) { } @@ -95,7 +96,7 @@ struct DOHFrontend std::shared_ptr d_dsc{nullptr}; std::shared_ptr>> d_responsesMap; - TLSFrontend d_tlsContext{TLSFrontend::ALPN::DoH}; + std::shared_ptr d_tlsContext; std::string d_serverTokens{"h2o/dnsdist"}; std::unordered_map d_customResponseHeaders; std::string d_library; @@ -141,12 +142,12 @@ struct DOHFrontend time_t getTicketsKeyRotationDelay() const { - return d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay; + return d_tlsContext->d_tlsConfig.d_ticketsKeyRotationDelay; } bool isHTTPS() const { - return !d_tlsContext.d_tlsConfig.d_certKeyPairs.empty(); + return !d_tlsContext->d_tlsConfig.d_certKeyPairs.empty(); } #ifndef HAVE_DNS_OVER_HTTPS diff --git a/pdns/dnsdistdist/dnsdist-lua-inspection.cc b/pdns/dnsdistdist/dnsdist-lua-inspection.cc index 95be35ad37..b1e6c00554 100644 --- a/pdns/dnsdistdist/dnsdist-lua-inspection.cc +++ b/pdns/dnsdistdist/dnsdist-lua-inspection.cc @@ -778,7 +778,7 @@ void setupLuaInspection(LuaContext& luaCtx) errorCounters = &frontend->tlsFrontend->d_tlsCounters; } else if (frontend->dohFrontend != nullptr) { - errorCounters = &frontend->dohFrontend->d_tlsContext.d_tlsCounters; + errorCounters = &frontend->dohFrontend->d_tlsContext->d_tlsCounters; } if (errorCounters == nullptr) { continue; diff --git a/pdns/dnsdistdist/dnsdist-lua.cc b/pdns/dnsdistdist/dnsdist-lua.cc index f202bb7efe..3504cbd6a4 100644 --- a/pdns/dnsdistdist/dnsdist-lua.cc +++ b/pdns/dnsdistdist/dnsdist-lua.cc @@ -2167,15 +2167,15 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) bool useTLS = true; if (certFiles && !certFiles->empty()) { - if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) { + if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext->d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) { return; } - frontend->d_tlsContext.d_addr = ComboAddress(addr, 443); + frontend->d_tlsContext->d_addr = ComboAddress(addr, 443); } else { - frontend->d_tlsContext.d_addr = ComboAddress(addr, 80); - infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort()); + frontend->d_tlsContext->d_addr = ComboAddress(addr, 80); + infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext->d_addr.toStringWithPort()); useTLS = false; } @@ -2208,9 +2208,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections, enableProxyProtocol); getOptionalValue(vars, "idleTimeout", frontend->d_idleTimeout); getOptionalValue(vars, "serverTokens", frontend->d_serverTokens); - getOptionalValue(vars, "provider", frontend->d_tlsContext.d_provider); - boost::algorithm::to_lower(frontend->d_tlsContext.d_provider); - getOptionalValue(vars, "proxyProtocolOutsideTLS", frontend->d_tlsContext.d_proxyProtocolOutsideTLS); + getOptionalValue(vars, "provider", frontend->d_tlsContext->d_provider); + boost::algorithm::to_lower(frontend->d_tlsContext->d_provider); + getOptionalValue(vars, "proxyProtocolOutsideTLS", frontend->d_tlsContext->d_proxyProtocolOutsideTLS); LuaAssociativeTable customResponseHeaders; if (getOptionalValue(vars, "customResponseHeaders", customResponseHeaders) > 0) { @@ -2241,7 +2241,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } } - parseTLSConfig(frontend->d_tlsContext.d_tlsConfig, "addDOHLocal", vars); + parseTLSConfig(frontend->d_tlsContext->d_tlsConfig, "addDOHLocal", vars); bool ignoreTLSConfigurationErrors = false; if (getOptionalValue(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) { @@ -2249,7 +2249,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) // and properly ignore the frontend before actually launching it try { std::map ocspResponses = {}; - auto ctx = libssl_init_server_context(frontend->d_tlsContext.d_tlsConfig, ocspResponses); + auto ctx = libssl_init_server_context(frontend->d_tlsContext->d_tlsConfig, ocspResponses); } catch (const std::runtime_error& e) { errlog("Ignoring DoH frontend: '%s'", e.what()); @@ -2261,8 +2261,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } if (useTLS && frontend->d_library == "nghttp2") { - if (!frontend->d_tlsContext.d_provider.empty()) { - vinfolog("Loading TLS provider '%s'", frontend->d_tlsContext.d_provider); + if (!frontend->d_tlsContext->d_provider.empty()) { + vinfolog("Loading TLS provider '%s'", frontend->d_tlsContext->d_provider); } else { #ifdef HAVE_LIBSSL @@ -2274,7 +2274,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } } - auto clientState = std::make_shared(frontend->d_tlsContext.d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol); + auto clientState = std::make_shared(frontend->d_tlsContext->d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol); clientState->dohFrontend = std::move(frontend); clientState->d_additionalAddresses = std::move(additionalAddresses); @@ -2515,7 +2515,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) ret << (fmt % "#" % "Address" % "HTTP" % "HTTP/1" % "HTTP/2" % "GET" % "POST" % "Bad" % "Errors" % "Redirects" % "Valid" % "# ticket keys" % "Rotation delay" % "Next rotation") << endl; size_t counter = 0; for (const auto& ctx : dnsdist::getDoHFrontends()) { - ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_httpconnects % ctx->d_http1Stats.d_nbQueries % ctx->d_http2Stats.d_nbQueries % ctx->d_getqueries % ctx->d_postqueries % ctx->d_badrequests % ctx->d_errorresponses % ctx->d_redirectresponses % ctx->d_validresponses % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl; + ret << (fmt % counter % ctx->d_tlsContext->d_addr.toStringWithPort() % ctx->d_httpconnects % ctx->d_http1Stats.d_nbQueries % ctx->d_http2Stats.d_nbQueries % ctx->d_getqueries % ctx->d_postqueries % ctx->d_badrequests % ctx->d_errorresponses % ctx->d_redirectresponses % ctx->d_validresponses % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl; counter++; } g_outputBuffer = ret.str(); @@ -2598,7 +2598,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) ret << (fmt % "#" % "Address" % "200" % "400" % "403" % "500" % "502" % "Others") << endl; size_t counter = 0; for (const auto& ctx : dnsdist::getDoHFrontends()) { - ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_http1Stats.d_nb200Responses % ctx->d_http1Stats.d_nb400Responses % ctx->d_http1Stats.d_nb403Responses % ctx->d_http1Stats.d_nb500Responses % ctx->d_http1Stats.d_nb502Responses % ctx->d_http1Stats.d_nbOtherResponses) << endl; + ret << (fmt % counter % ctx->d_tlsContext->d_addr.toStringWithPort() % ctx->d_http1Stats.d_nb200Responses % ctx->d_http1Stats.d_nb400Responses % ctx->d_http1Stats.d_nb403Responses % ctx->d_http1Stats.d_nb500Responses % ctx->d_http1Stats.d_nb502Responses % ctx->d_http1Stats.d_nbOtherResponses) << endl; counter++; } g_outputBuffer += ret.str(); @@ -2608,7 +2608,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) ret << (fmt % "#" % "Address" % "200" % "400" % "403" % "500" % "502" % "Others") << endl; counter = 0; for (const auto& ctx : dnsdist::getDoHFrontends()) { - ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_http2Stats.d_nb200Responses % ctx->d_http2Stats.d_nb400Responses % ctx->d_http2Stats.d_nb403Responses % ctx->d_http2Stats.d_nb500Responses % ctx->d_http2Stats.d_nb502Responses % ctx->d_http2Stats.d_nbOtherResponses) << endl; + ret << (fmt % counter % ctx->d_tlsContext->d_addr.toStringWithPort() % ctx->d_http2Stats.d_nb200Responses % ctx->d_http2Stats.d_nb400Responses % ctx->d_http2Stats.d_nb403Responses % ctx->d_http2Stats.d_nb500Responses % ctx->d_http2Stats.d_nb502Responses % ctx->d_http2Stats.d_nbOtherResponses) << endl; counter++; } g_outputBuffer += ret.str(); @@ -2663,7 +2663,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) luaCtx.registerFunction::*)(boost::variant, LuaArray, LuaArray>> certFiles, LuaTypeOrArrayOf keyFiles)>("loadNewCertificatesAndKeys", []([[maybe_unused]] const std::shared_ptr& frontend, [[maybe_unused]] const boost::variant, LuaArray, LuaArray>>& certFiles, [[maybe_unused]] const LuaTypeOrArrayOf& keyFiles) { #ifdef HAVE_DNS_OVER_HTTPS if (frontend != nullptr) { - if (loadTLSCertificateAndKeys("DOHFrontend::loadNewCertificatesAndKeys", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) { + if (loadTLSCertificateAndKeys("DOHFrontend::loadNewCertificatesAndKeys", frontend->d_tlsContext->d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) { frontend->reloadCertificates(); } } diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 984f5d03ad..bf4cc48f7e 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -27,7 +27,7 @@ public: enum class QueryProcessingResult : uint8_t { Forwarded, TooSmall, InvalidHeaders, Dropped, SelfAnswered, NoBackend, Asynchronous }; enum class ProxyProtocolResult : uint8_t { Reading, Done, Error }; - IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(sizeof(uint16_t)), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{dnsdist::configuration::getCurrentRuntimeConfiguration().d_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? d_ci.cs->dohFrontend->d_tlsContext.getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id()) + IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(sizeof(uint16_t)), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{dnsdist::configuration::getCurrentRuntimeConfiguration().d_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? d_ci.cs->dohFrontend->d_tlsContext->getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id()) { d_origDest.reset(); d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family; @@ -156,7 +156,7 @@ public: if (!d_ci.cs->hasTLS()) { return false; } - return d_ci.cs->getTLSFrontend().d_proxyProtocolOutsideTLS; + return d_ci.cs->getTLSFrontend()->d_proxyProtocolOutsideTLS; } virtual bool forwardViaUDPFirst() const diff --git a/pdns/dnsdistdist/dnsdist-web.cc b/pdns/dnsdistdist/dnsdist-web.cc index 4eb57f6be5..efa70e72fd 100644 --- a/pdns/dnsdistdist/dnsdist-web.cc +++ b/pdns/dnsdistdist/dnsdist-web.cc @@ -751,7 +751,7 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp) errorCounters = &front->tlsFrontend->d_tlsCounters; } else if (front->dohFrontend != nullptr) { - errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters; + errorCounters = &front->dohFrontend->d_tlsContext->d_tlsCounters; } if (errorCounters != nullptr) { @@ -789,7 +789,7 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp) #ifdef HAVE_DNS_OVER_HTTPS std::map dohFrontendDuplicates; for(const auto& doh : dnsdist::getDoHFrontends()) { - const string frontName = doh->d_tlsContext.d_addr.toStringWithPort(); + const string frontName = doh->d_tlsContext->d_addr.toStringWithPort(); uint64_t threadNumber = 0; auto dupPair = frontendDuplicates.emplace(frontName, 1); if (!dupPair.second) { @@ -1188,7 +1188,7 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp) errorCounters = &front->tlsFrontend->d_tlsCounters; } else if (front->dohFrontend != nullptr) { - errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters; + errorCounters = &front->dohFrontend->d_tlsContext->d_tlsCounters; } if (errorCounters != nullptr) { frontend["tlsHandshakeFailuresDHKeyTooSmall"] = (double)errorCounters->d_dhKeyTooSmall; @@ -1212,7 +1212,7 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp) for (const auto& doh : dohFrontends) { dohs.emplace_back(Json::object{ {"id", num++}, - {"address", doh->d_tlsContext.d_addr.toStringWithPort()}, + {"address", doh->d_tlsContext->d_addr.toStringWithPort()}, {"http-connects", (double)doh->d_httpconnects}, {"http1-queries", (double)doh->d_http1Stats.d_nbQueries}, {"http2-queries", (double)doh->d_http2Stats.d_nbQueries}, diff --git a/pdns/dnsdistdist/dnsdist.hh b/pdns/dnsdistdist/dnsdist.hh index 20fe358d1f..0cf39e4481 100644 --- a/pdns/dnsdistdist/dnsdist.hh +++ b/pdns/dnsdistdist/dnsdist.hh @@ -396,10 +396,10 @@ struct ClientState return tlsFrontend != nullptr || (dohFrontend != nullptr && dohFrontend->isHTTPS()); } - const TLSFrontend& getTLSFrontend() const + const std::shared_ptr getTLSFrontend() const { if (tlsFrontend != nullptr) { - return *tlsFrontend; + return tlsFrontend; } if (dohFrontend) { return dohFrontend->d_tlsContext; diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index 62f564aaac..fc03c86430 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -1530,16 +1530,16 @@ static void setupAcceptContext(DOHAcceptContext& ctx, DOHServerConfig& dsc, bool nativeCtx->ctx = &dsc.h2o_ctx; nativeCtx->hosts = dsc.h2o_config.hosts; auto dohFrontend = std::atomic_load_explicit(&dsc.dohFrontend, std::memory_order_acquire); - ctx.d_ticketsKeyRotationDelay = dohFrontend->d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay; + ctx.d_ticketsKeyRotationDelay = dohFrontend->d_tlsContext->d_tlsConfig.d_ticketsKeyRotationDelay; if (setupTLS && dohFrontend->isHTTPS()) { try { setupTLSContext(ctx, - dohFrontend->d_tlsContext.d_tlsConfig, - dohFrontend->d_tlsContext.d_tlsCounters); + dohFrontend->d_tlsContext->d_tlsConfig, + dohFrontend->d_tlsContext->d_tlsCounters); } catch (const std::runtime_error& e) { - throw std::runtime_error("Error setting up TLS context for DoH listener on '" + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + "': " + e.what()); + throw std::runtime_error("Error setting up TLS context for DoH listener on '" + dohFrontend->d_tlsContext->d_addr.toStringWithPort() + "': " + e.what()); } } ctx.d_cs = dsc.clientState; @@ -1582,7 +1582,7 @@ void dohThread(ClientState* clientState) setThreadName("dnsdist/doh"); // I wonder if this registers an IP address.. I think it does // this may mean we need to actually register a site "name" here and not the IP address - h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(dohFrontend->d_tlsContext.d_addr.toString().c_str(), dohFrontend->d_tlsContext.d_addr.toString().size()), 65535); + h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(dohFrontend->d_tlsContext->d_addr.toString().c_str(), dohFrontend->d_tlsContext->d_addr.toString().size()), 65535); dsc->paths = dohFrontend->d_urls; for (const auto& url : dsc->paths) { @@ -1606,11 +1606,11 @@ void dohThread(ClientState* clientState) setupAcceptContext(*dsc->accept_ctx, *dsc, false); if (create_listener(dsc, clientState->tcpFD) != 0) { - throw std::runtime_error("DOH server failed to listen on " + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + ": " + stringerror(errno)); + throw std::runtime_error("DOH server failed to listen on " + dohFrontend->d_tlsContext->d_addr.toStringWithPort() + ": " + stringerror(errno)); } for (const auto& [addr, descriptor] : clientState->d_additionalAddresses) { if (create_listener(dsc, descriptor) != 0) { - throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + ": " + stringerror(errno)); + throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + dohFrontend->d_tlsContext->d_addr.toStringWithPort() + ": " + stringerror(errno)); } } @@ -1736,11 +1736,11 @@ void H2ODOHFrontend::setup() if (isHTTPS()) { try { setupTLSContext(*d_dsc->accept_ctx, - d_tlsContext.d_tlsConfig, - d_tlsContext.d_tlsCounters); + d_tlsContext->d_tlsConfig, + d_tlsContext->d_tlsCounters); } catch (const std::runtime_error& e) { - throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort() + "': " + e.what()); + throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext->d_addr.toStringWithPort() + "': " + e.what()); } } } diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 0b07569d14..9379576b60 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -1883,6 +1883,14 @@ bool TLSFrontend::setupTLS() { #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) std::shared_ptr newCtx{nullptr}; + if (d_parentFrontend) { + newCtx = d_parentFrontend->getContext(); + if (newCtx) { + std::atomic_store_explicit(&d_ctx, std::move(newCtx), std::memory_order_release); + return true; + } + } + /* get the "best" available provider */ #if defined(HAVE_GNUTLS) if (d_provider == "gnutls") { diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index 782aada4e2..9450b61180 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -155,30 +155,35 @@ public: void rotateTicketsKey(time_t now) { - if (d_ctx != nullptr) { + if (d_ctx != nullptr && d_parentFrontend == nullptr) { d_ctx->rotateTicketsKey(now); } } void loadTicketsKeys(const std::string& file) { - if (d_ctx != nullptr) { + if (d_ctx != nullptr && d_parentFrontend == nullptr) { d_ctx->loadTicketsKeys(file); } } void loadTicketsKey(const std::string& key) { - if (d_ctx != nullptr) { + if (d_ctx != nullptr && d_parentFrontend == nullptr) { d_ctx->loadTicketsKey(key); } } - std::shared_ptr getContext() + std::shared_ptr getContext() const { return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire); } + void setParent(std::shared_ptr parent) + { + std::atomic_store_explicit(&d_parentFrontend, std::move(parent), std::memory_order_release); + } + void cleanup() { d_ctx.reset(); @@ -242,6 +247,7 @@ public: bool d_proxyProtocolOutsideTLS{false}; protected: std::shared_ptr d_ctx{nullptr}; + std::shared_ptr d_parentFrontend{nullptr}; }; class TCPIOHandler