]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Share tickets key between identical frontends created via YAML
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 3 Mar 2025 10:57:54 +0000 (11:57 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 2 Apr 2025 10:28:57 +0000 (12:28 +0200)
Using the same Session Ticket Encryption Key on identical frontends
allow TLS sessions to be resumed in a much more efficient way, reducing
the latency and CPU usage. While it was already possible to do so by
manually managing the STEK, the default behaviour was to create and use
a different STEK for each frontend, because our Lua configuration makes
it almost impossible to ensure that two frontends are identical.
This is not an issue with the new YAML configuration format, so let's
share the STEK automatically in this case.

This needs a regression test.

12 files changed:
pdns/dnsdistdist/dnsdist-carbon.cc
pdns/dnsdistdist/dnsdist-configuration-yaml.cc
pdns/dnsdistdist/dnsdist-doh-common.cc
pdns/dnsdistdist/dnsdist-doh-common.hh
pdns/dnsdistdist/dnsdist-lua-inspection.cc
pdns/dnsdistdist/dnsdist-lua.cc
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/dnsdist-web.cc
pdns/dnsdistdist/dnsdist.hh
pdns/dnsdistdist/doh.cc
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh

index 596e0eae1091f38cc9d618348dc0b08a4b79c1ca..b8daf13672b469824b166854466836e7d51092c8 100644 (file)
@@ -164,7 +164,7 @@ static bool doOneCarbonExport(const Carbon::Endpoint& endpoint)
         errorCounters = &front->tlsFrontend->d_tlsCounters;
       }
       else if (front->dohFrontend != nullptr) {
-        errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters;
+        errorCounters = &front->dohFrontend->d_tlsContext->d_tlsCounters;
       }
       if (errorCounters != nullptr) {
         str << base << "tlsdhkeytoosmall" << ' ' << errorCounters->d_dhKeyTooSmall << " " << now << "\r\n";
@@ -227,7 +227,7 @@ static bool doOneCarbonExport(const Carbon::Endpoint& endpoint)
       std::map<std::string, uint64_t> dohFrontendDuplicates;
       const string base = "dnsdist." + hostname + ".main.doh.";
       for (const auto& doh : dnsdist::getDoHFrontends()) {
-        string name = doh->d_tlsContext.d_addr.toStringWithPort();
+        string name = doh->d_tlsContext->d_addr.toStringWithPort();
         std::replace(name.begin(), name.end(), '.', '_');
         std::replace(name.begin(), name.end(), ':', '_');
         std::replace(name.begin(), name.end(), '[', '_');
index 7f0e633d9420872e4d64291b26b39f28eaeca6bd..3fad3db70af9e20f03f562b52b15039745c1d4ae 100644 (file)
@@ -239,7 +239,7 @@ static bool validateTLSConfiguration(const dnsdist::rust::settings::BindConfigur
   return true;
 }
 
-static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfiguration& bind, ClientState& state)
+static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfiguration& bind, ClientState& state, std::shared_ptr<const TLSFrontend> parent)
 {
   auto tlsConfig = getTLSConfigFromRustIncomingTLS(bind.tls);
   if (!validateTLSConfiguration(bind, tlsConfig)) {
@@ -249,6 +249,7 @@ static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfigurat
   auto protocol = boost::to_lower_copy(std::string(bind.protocol));
   if (protocol == "dot") {
     auto frontend = std::make_shared<TLSFrontend>(TLSFrontend::ALPN::DoT);
+    frontend->setParent(parent);
     frontend->d_provider = std::string(bind.tls.provider);
     boost::algorithm::to_lower(frontend->d_provider);
     frontend->d_proxyProtocolOutsideTLS = bind.tls.proxy_protocol_outside_tls;
@@ -286,8 +287,9 @@ static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfigurat
 #endif /* HAVE_DNS_OVER_HTTP3 */
   else if (protocol == "doh") {
     auto frontend = std::make_shared<DOHFrontend>();
-    frontend->d_tlsContext.d_provider = std::string(bind.tls.provider);
-    boost::algorithm::to_lower(frontend->d_tlsContext.d_provider);
+    auto& tlsContext = frontend->d_tlsContext;
+    tlsContext->d_provider = std::string(bind.tls.provider);
+    boost::algorithm::to_lower(tlsContext->d_provider);
     frontend->d_library = std::string(bind.doh.provider);
     if (frontend->d_library == "h2o") {
 #ifdef HAVE_LIBH2OEVLOOP
@@ -343,16 +345,17 @@ static bool handleTLSConfiguration(const dnsdist::rust::settings::BindConfigurat
     }
 
     if (!tlsConfig.d_certKeyPairs.empty()) {
-      frontend->d_tlsContext.d_addr = ComboAddress(std::string(bind.listen_address), 443);
+      tlsContext->d_addr = ComboAddress(std::string(bind.listen_address), 443);
       infolog("DNS over HTTPS configured");
     }
     else {
-      frontend->d_tlsContext.d_addr = ComboAddress(std::string(bind.listen_address), 80);
-      infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort());
+      tlsContext->d_addr = ComboAddress(std::string(bind.listen_address), 80);
+      infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", tlsContext->d_addr.toStringWithPort());
     }
 
-    frontend->d_tlsContext.d_proxyProtocolOutsideTLS = bind.tls.proxy_protocol_outside_tls;
-    frontend->d_tlsContext.d_tlsConfig = std::move(tlsConfig);
+    tlsContext->d_proxyProtocolOutsideTLS = bind.tls.proxy_protocol_outside_tls;
+    tlsContext->d_tlsConfig = std::move(tlsConfig);
+    tlsContext->setParent(parent);
     state.dohFrontend = std::move(frontend);
   }
   else if (protocol != "do53") {
@@ -672,6 +675,7 @@ static void loadBinds(const ::rust::Vec<dnsdist::rust::settings::BindConfigurati
         }
       }
 
+      std::shared_ptr<const TLSFrontend> tlsFrontendParent;
       for (size_t idx = 0; idx < bind.threads; idx++) {
 #if defined(HAVE_DNSCRYPT)
         std::shared_ptr<DNSCryptContext> dnsCryptContext;
@@ -710,9 +714,12 @@ static void loadBinds(const ::rust::Vec<dnsdist::rust::settings::BindConfigurati
 #endif /* defined(HAVE_DNSCRYPT) */
         }
         else if (protocol != "do53") {
-          if (!handleTLSConfiguration(bind, *state)) {
+          if (!handleTLSConfiguration(bind, *state, tlsFrontendParent)) {
             continue;
           }
+          if (tlsFrontendParent == nullptr) {
+            tlsFrontendParent = state->getTLSFrontend();
+          }
         }
 
         config.d_frontends.emplace_back(std::move(state));
index c533cc7e8cd648fb5610cc66fc8ccaaf5ea64b4d..43713d7a2653c07c0f8ca423759f9453cb29306b 100644 (file)
 #ifdef HAVE_DNS_OVER_HTTPS
 void DOHFrontend::rotateTicketsKey(time_t now)
 {
-  return d_tlsContext.rotateTicketsKey(now);
+  return d_tlsContext->rotateTicketsKey(now);
 }
 
 void DOHFrontend::loadTicketsKeys(const std::string& keyFile)
 {
-  return d_tlsContext.loadTicketsKeys(keyFile);
+  return d_tlsContext->loadTicketsKeys(keyFile);
 }
 
 void DOHFrontend::loadTicketsKey(const std::string& key)
 {
-  return d_tlsContext.loadTicketsKey(key);
+  return d_tlsContext->loadTicketsKey(key);
 }
 
 void DOHFrontend::handleTicketsKeyRotation()
@@ -45,26 +45,26 @@ void DOHFrontend::handleTicketsKeyRotation()
 
 std::string DOHFrontend::getNextTicketsKeyRotation() const
 {
-  return d_tlsContext.getNextTicketsKeyRotation();
+  return d_tlsContext->getNextTicketsKeyRotation();
 }
 
 size_t DOHFrontend::getTicketsKeysCount()
 {
-  return d_tlsContext.getTicketsKeysCount();
+  return d_tlsContext->getTicketsKeysCount();
 }
 
 void DOHFrontend::reloadCertificates()
 {
   if (isHTTPS()) {
-    d_tlsContext.setupTLS();
+    d_tlsContext->setupTLS();
   }
 }
 
 void DOHFrontend::setup()
 {
   if (isHTTPS()) {
-    if (!d_tlsContext.setupTLS()) {
-      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort());
+    if (!d_tlsContext->setupTLS()) {
+      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext->d_addr.toStringWithPort());
     }
   }
 }
index 9d0a4669288ffe3db4aa939b76eaa1ee6b83ade7..6af2f962e329c62c72b7c80ca8615c4028a66965 100644 (file)
@@ -81,11 +81,12 @@ private:
 
 struct DOHFrontend
 {
-  DOHFrontend()
+  DOHFrontend() :
+    d_tlsContext(std::make_shared<TLSFrontend>(TLSFrontend::ALPN::DoH))
   {
   }
   DOHFrontend(std::shared_ptr<TLSCtx> tlsCtx) :
-    d_tlsContext(std::move(tlsCtx))
+    d_tlsContext(std::make_shared<TLSFrontend>(std::move(tlsCtx)))
   {
   }
 
@@ -95,7 +96,7 @@ struct DOHFrontend
 
   std::shared_ptr<DOHServerConfig> d_dsc{nullptr};
   std::shared_ptr<std::vector<std::shared_ptr<DOHResponseMapEntry>>> d_responsesMap;
-  TLSFrontend d_tlsContext{TLSFrontend::ALPN::DoH};
+  std::shared_ptr<TLSFrontend> d_tlsContext;
   std::string d_serverTokens{"h2o/dnsdist"};
   std::unordered_map<std::string, std::string> d_customResponseHeaders;
   std::string d_library;
@@ -141,12 +142,12 @@ struct DOHFrontend
 
   time_t getTicketsKeyRotationDelay() const
   {
-    return d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay;
+    return d_tlsContext->d_tlsConfig.d_ticketsKeyRotationDelay;
   }
 
   bool isHTTPS() const
   {
-    return !d_tlsContext.d_tlsConfig.d_certKeyPairs.empty();
+    return !d_tlsContext->d_tlsConfig.d_certKeyPairs.empty();
   }
 
 #ifndef HAVE_DNS_OVER_HTTPS
index 95be35ad37cbf77a25d660ece66668deee1cbdf9..b1e6c005543474bc4ea16958a6e6fb5057bcafae 100644 (file)
@@ -778,7 +778,7 @@ void setupLuaInspection(LuaContext& luaCtx)
         errorCounters = &frontend->tlsFrontend->d_tlsCounters;
       }
       else if (frontend->dohFrontend != nullptr) {
-        errorCounters = &frontend->dohFrontend->d_tlsContext.d_tlsCounters;
+        errorCounters = &frontend->dohFrontend->d_tlsContext->d_tlsCounters;
       }
       if (errorCounters == nullptr) {
         continue;
index f202bb7efe72d8e1a49186ea2ed851f006b35552..3504cbd6a408277776f5cf0f6a5b2ed2ebeb21cc 100644 (file)
@@ -2167,15 +2167,15 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 
     bool useTLS = true;
     if (certFiles && !certFiles->empty()) {
-      if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) {
+      if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext->d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) {
         return;
       }
 
-      frontend->d_tlsContext.d_addr = ComboAddress(addr, 443);
+      frontend->d_tlsContext->d_addr = ComboAddress(addr, 443);
     }
     else {
-      frontend->d_tlsContext.d_addr = ComboAddress(addr, 80);
-      infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort());
+      frontend->d_tlsContext->d_addr = ComboAddress(addr, 80);
+      infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext->d_addr.toStringWithPort());
       useTLS = false;
     }
 
@@ -2208,9 +2208,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections, enableProxyProtocol);
       getOptionalValue<int>(vars, "idleTimeout", frontend->d_idleTimeout);
       getOptionalValue<std::string>(vars, "serverTokens", frontend->d_serverTokens);
-      getOptionalValue<std::string>(vars, "provider", frontend->d_tlsContext.d_provider);
-      boost::algorithm::to_lower(frontend->d_tlsContext.d_provider);
-      getOptionalValue<bool>(vars, "proxyProtocolOutsideTLS", frontend->d_tlsContext.d_proxyProtocolOutsideTLS);
+      getOptionalValue<std::string>(vars, "provider", frontend->d_tlsContext->d_provider);
+      boost::algorithm::to_lower(frontend->d_tlsContext->d_provider);
+      getOptionalValue<bool>(vars, "proxyProtocolOutsideTLS", frontend->d_tlsContext->d_proxyProtocolOutsideTLS);
 
       LuaAssociativeTable<std::string> customResponseHeaders;
       if (getOptionalValue<decltype(customResponseHeaders)>(vars, "customResponseHeaders", customResponseHeaders) > 0) {
@@ -2241,7 +2241,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         }
       }
 
-      parseTLSConfig(frontend->d_tlsContext.d_tlsConfig, "addDOHLocal", vars);
+      parseTLSConfig(frontend->d_tlsContext->d_tlsConfig, "addDOHLocal", vars);
 
       bool ignoreTLSConfigurationErrors = false;
       if (getOptionalValue<bool>(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) {
@@ -2249,7 +2249,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         // and properly ignore the frontend before actually launching it
         try {
           std::map<int, std::string> ocspResponses = {};
-          auto ctx = libssl_init_server_context(frontend->d_tlsContext.d_tlsConfig, ocspResponses);
+          auto ctx = libssl_init_server_context(frontend->d_tlsContext->d_tlsConfig, ocspResponses);
         }
         catch (const std::runtime_error& e) {
           errlog("Ignoring DoH frontend: '%s'", e.what());
@@ -2261,8 +2261,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
 
     if (useTLS && frontend->d_library == "nghttp2") {
-      if (!frontend->d_tlsContext.d_provider.empty()) {
-        vinfolog("Loading TLS provider '%s'", frontend->d_tlsContext.d_provider);
+      if (!frontend->d_tlsContext->d_provider.empty()) {
+        vinfolog("Loading TLS provider '%s'", frontend->d_tlsContext->d_provider);
       }
       else {
 #ifdef HAVE_LIBSSL
@@ -2274,7 +2274,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       }
     }
 
-    auto clientState = std::make_shared<ClientState>(frontend->d_tlsContext.d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
+    auto clientState = std::make_shared<ClientState>(frontend->d_tlsContext->d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
     clientState->dohFrontend = std::move(frontend);
     clientState->d_additionalAddresses = std::move(additionalAddresses);
 
@@ -2515,7 +2515,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       ret << (fmt % "#" % "Address" % "HTTP" % "HTTP/1" % "HTTP/2" % "GET" % "POST" % "Bad" % "Errors" % "Redirects" % "Valid" % "# ticket keys" % "Rotation delay" % "Next rotation") << endl;
       size_t counter = 0;
       for (const auto& ctx : dnsdist::getDoHFrontends()) {
-        ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_httpconnects % ctx->d_http1Stats.d_nbQueries % ctx->d_http2Stats.d_nbQueries % ctx->d_getqueries % ctx->d_postqueries % ctx->d_badrequests % ctx->d_errorresponses % ctx->d_redirectresponses % ctx->d_validresponses % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl;
+        ret << (fmt % counter % ctx->d_tlsContext->d_addr.toStringWithPort() % ctx->d_httpconnects % ctx->d_http1Stats.d_nbQueries % ctx->d_http2Stats.d_nbQueries % ctx->d_getqueries % ctx->d_postqueries % ctx->d_badrequests % ctx->d_errorresponses % ctx->d_redirectresponses % ctx->d_validresponses % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl;
         counter++;
       }
       g_outputBuffer = ret.str();
@@ -2598,7 +2598,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       ret << (fmt % "#" % "Address" % "200" % "400" % "403" % "500" % "502" % "Others") << endl;
       size_t counter = 0;
       for (const auto& ctx : dnsdist::getDoHFrontends()) {
-        ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_http1Stats.d_nb200Responses % ctx->d_http1Stats.d_nb400Responses % ctx->d_http1Stats.d_nb403Responses % ctx->d_http1Stats.d_nb500Responses % ctx->d_http1Stats.d_nb502Responses % ctx->d_http1Stats.d_nbOtherResponses) << endl;
+        ret << (fmt % counter % ctx->d_tlsContext->d_addr.toStringWithPort() % ctx->d_http1Stats.d_nb200Responses % ctx->d_http1Stats.d_nb400Responses % ctx->d_http1Stats.d_nb403Responses % ctx->d_http1Stats.d_nb500Responses % ctx->d_http1Stats.d_nb502Responses % ctx->d_http1Stats.d_nbOtherResponses) << endl;
         counter++;
       }
       g_outputBuffer += ret.str();
@@ -2608,7 +2608,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       ret << (fmt % "#" % "Address" % "200" % "400" % "403" % "500" % "502" % "Others") << endl;
       counter = 0;
       for (const auto& ctx : dnsdist::getDoHFrontends()) {
-        ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_http2Stats.d_nb200Responses % ctx->d_http2Stats.d_nb400Responses % ctx->d_http2Stats.d_nb403Responses % ctx->d_http2Stats.d_nb500Responses % ctx->d_http2Stats.d_nb502Responses % ctx->d_http2Stats.d_nbOtherResponses) << endl;
+        ret << (fmt % counter % ctx->d_tlsContext->d_addr.toStringWithPort() % ctx->d_http2Stats.d_nb200Responses % ctx->d_http2Stats.d_nb400Responses % ctx->d_http2Stats.d_nb403Responses % ctx->d_http2Stats.d_nb500Responses % ctx->d_http2Stats.d_nb502Responses % ctx->d_http2Stats.d_nbOtherResponses) << endl;
         counter++;
       }
       g_outputBuffer += ret.str();
@@ -2663,7 +2663,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
   luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(boost::variant<std::string, std::shared_ptr<TLSCertKeyPair>, LuaArray<std::string>, LuaArray<std::shared_ptr<TLSCertKeyPair>>> certFiles, LuaTypeOrArrayOf<std::string> keyFiles)>("loadNewCertificatesAndKeys", []([[maybe_unused]] const std::shared_ptr<DOHFrontend>& frontend, [[maybe_unused]] const boost::variant<std::string, std::shared_ptr<TLSCertKeyPair>, LuaArray<std::string>, LuaArray<std::shared_ptr<TLSCertKeyPair>>>& certFiles, [[maybe_unused]] const LuaTypeOrArrayOf<std::string>& keyFiles) {
 #ifdef HAVE_DNS_OVER_HTTPS
     if (frontend != nullptr) {
-      if (loadTLSCertificateAndKeys("DOHFrontend::loadNewCertificatesAndKeys", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) {
+      if (loadTLSCertificateAndKeys("DOHFrontend::loadNewCertificatesAndKeys", frontend->d_tlsContext->d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) {
         frontend->reloadCertificates();
       }
     }
index 984f5d03adfa111e466de67439d649f3dcb79715..bf4cc48f7e802714a9effae69511cbdb97d351d7 100644 (file)
@@ -27,7 +27,7 @@ public:
   enum class QueryProcessingResult : uint8_t { Forwarded, TooSmall, InvalidHeaders, Dropped, SelfAnswered, NoBackend, Asynchronous };
   enum class ProxyProtocolResult : uint8_t { Reading, Done, Error };
 
-  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(sizeof(uint16_t)), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{dnsdist::configuration::getCurrentRuntimeConfiguration().d_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? d_ci.cs->dohFrontend->d_tlsContext.getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
+  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(sizeof(uint16_t)), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{dnsdist::configuration::getCurrentRuntimeConfiguration().d_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? d_ci.cs->dohFrontend->d_tlsContext->getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
   {
     d_origDest.reset();
     d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
@@ -156,7 +156,7 @@ public:
     if (!d_ci.cs->hasTLS()) {
       return false;
     }
-    return d_ci.cs->getTLSFrontend().d_proxyProtocolOutsideTLS;
+    return d_ci.cs->getTLSFrontend()->d_proxyProtocolOutsideTLS;
   }
 
   virtual bool forwardViaUDPFirst() const
index 4eb57f6be51a1debfe0841d8e55563a9181a7021..efa70e72fd1d6aa4cb0fe435a1a55335385d4b38 100644 (file)
@@ -751,7 +751,7 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp)
           errorCounters = &front->tlsFrontend->d_tlsCounters;
         }
         else if (front->dohFrontend != nullptr) {
-          errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters;
+          errorCounters = &front->dohFrontend->d_tlsContext->d_tlsCounters;
         }
 
         if (errorCounters != nullptr) {
@@ -789,7 +789,7 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp)
 #ifdef HAVE_DNS_OVER_HTTPS
   std::map<std::string,uint64_t> dohFrontendDuplicates;
   for(const auto& doh : dnsdist::getDoHFrontends()) {
-    const string frontName = doh->d_tlsContext.d_addr.toStringWithPort();
+    const string frontName = doh->d_tlsContext->d_addr.toStringWithPort();
     uint64_t threadNumber = 0;
     auto dupPair = frontendDuplicates.emplace(frontName, 1);
     if (!dupPair.second) {
@@ -1188,7 +1188,7 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp)
       errorCounters = &front->tlsFrontend->d_tlsCounters;
     }
     else if (front->dohFrontend != nullptr) {
-      errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters;
+      errorCounters = &front->dohFrontend->d_tlsContext->d_tlsCounters;
     }
     if (errorCounters != nullptr) {
       frontend["tlsHandshakeFailuresDHKeyTooSmall"] = (double)errorCounters->d_dhKeyTooSmall;
@@ -1212,7 +1212,7 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp)
     for (const auto& doh : dohFrontends) {
       dohs.emplace_back(Json::object{
         {"id", num++},
-        {"address", doh->d_tlsContext.d_addr.toStringWithPort()},
+        {"address", doh->d_tlsContext->d_addr.toStringWithPort()},
         {"http-connects", (double)doh->d_httpconnects},
         {"http1-queries", (double)doh->d_http1Stats.d_nbQueries},
         {"http2-queries", (double)doh->d_http2Stats.d_nbQueries},
index 20fe358d1fbebd8dbfa041d6233d5073c6047021..0cf39e44810158b3b20222c710785aa6734a6d77 100644 (file)
@@ -396,10 +396,10 @@ struct ClientState
     return tlsFrontend != nullptr || (dohFrontend != nullptr && dohFrontend->isHTTPS());
   }
 
-  const TLSFrontend& getTLSFrontend() const
+  const std::shared_ptr<const TLSFrontend> getTLSFrontend() const
   {
     if (tlsFrontend != nullptr) {
-      return *tlsFrontend;
+      return tlsFrontend;
     }
     if (dohFrontend) {
       return dohFrontend->d_tlsContext;
index 62f564aaacc766758c84900cd053d055cde03f12..fc03c86430431b5196b0150779b61b46c7337677 100644 (file)
@@ -1530,16 +1530,16 @@ static void setupAcceptContext(DOHAcceptContext& ctx, DOHServerConfig& dsc, bool
   nativeCtx->ctx = &dsc.h2o_ctx;
   nativeCtx->hosts = dsc.h2o_config.hosts;
   auto dohFrontend = std::atomic_load_explicit(&dsc.dohFrontend, std::memory_order_acquire);
-  ctx.d_ticketsKeyRotationDelay = dohFrontend->d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay;
+  ctx.d_ticketsKeyRotationDelay = dohFrontend->d_tlsContext->d_tlsConfig.d_ticketsKeyRotationDelay;
 
   if (setupTLS && dohFrontend->isHTTPS()) {
     try {
       setupTLSContext(ctx,
-                      dohFrontend->d_tlsContext.d_tlsConfig,
-                      dohFrontend->d_tlsContext.d_tlsCounters);
+                      dohFrontend->d_tlsContext->d_tlsConfig,
+                      dohFrontend->d_tlsContext->d_tlsCounters);
     }
     catch (const std::runtime_error& e) {
-      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + "': " + e.what());
+      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + dohFrontend->d_tlsContext->d_addr.toStringWithPort() + "': " + e.what());
     }
   }
   ctx.d_cs = dsc.clientState;
@@ -1582,7 +1582,7 @@ void dohThread(ClientState* clientState)
     setThreadName("dnsdist/doh");
     // I wonder if this registers an IP address.. I think it does
     // this may mean we need to actually register a site "name" here and not the IP address
-    h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(dohFrontend->d_tlsContext.d_addr.toString().c_str(), dohFrontend->d_tlsContext.d_addr.toString().size()), 65535);
+    h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(dohFrontend->d_tlsContext->d_addr.toString().c_str(), dohFrontend->d_tlsContext->d_addr.toString().size()), 65535);
 
     dsc->paths = dohFrontend->d_urls;
     for (const auto& url : dsc->paths) {
@@ -1606,11 +1606,11 @@ void dohThread(ClientState* clientState)
     setupAcceptContext(*dsc->accept_ctx, *dsc, false);
 
     if (create_listener(dsc, clientState->tcpFD) != 0) {
-      throw std::runtime_error("DOH server failed to listen on " + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + ": " + stringerror(errno));
+      throw std::runtime_error("DOH server failed to listen on " + dohFrontend->d_tlsContext->d_addr.toStringWithPort() + ": " + stringerror(errno));
     }
     for (const auto& [addr, descriptor] : clientState->d_additionalAddresses) {
       if (create_listener(dsc, descriptor) != 0) {
-        throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + ": " + stringerror(errno));
+        throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + dohFrontend->d_tlsContext->d_addr.toStringWithPort() + ": " + stringerror(errno));
       }
     }
 
@@ -1736,11 +1736,11 @@ void H2ODOHFrontend::setup()
   if  (isHTTPS()) {
     try {
       setupTLSContext(*d_dsc->accept_ctx,
-                      d_tlsContext.d_tlsConfig,
-                      d_tlsContext.d_tlsCounters);
+                      d_tlsContext->d_tlsConfig,
+                      d_tlsContext->d_tlsCounters);
     }
     catch (const std::runtime_error& e) {
-      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort() + "': " + e.what());
+      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext->d_addr.toStringWithPort() + "': " + e.what());
     }
   }
 }
index 0b07569d14bdc8463a40506d6707e91284864132..9379576b6088706a7283f2f1bbc413add184b2c2 100644 (file)
@@ -1883,6 +1883,14 @@ bool TLSFrontend::setupTLS()
 {
 #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
   std::shared_ptr<TLSCtx> newCtx{nullptr};
+  if (d_parentFrontend) {
+    newCtx = d_parentFrontend->getContext();
+    if (newCtx) {
+      std::atomic_store_explicit(&d_ctx, std::move(newCtx), std::memory_order_release);
+      return true;
+    }
+  }
+
   /* get the "best" available provider */
 #if defined(HAVE_GNUTLS)
   if (d_provider == "gnutls") {
index 782aada4e2a28de9fb0154bdacc351e6afa8c0a8..9450b611807144f875f88c42472a5b92ff1dcc43 100644 (file)
@@ -155,30 +155,35 @@ public:
 
   void rotateTicketsKey(time_t now)
   {
-    if (d_ctx != nullptr) {
+    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
       d_ctx->rotateTicketsKey(now);
     }
   }
 
   void loadTicketsKeys(const std::string& file)
   {
-    if (d_ctx != nullptr) {
+    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
       d_ctx->loadTicketsKeys(file);
     }
   }
 
   void loadTicketsKey(const std::string& key)
   {
-    if (d_ctx != nullptr) {
+    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
       d_ctx->loadTicketsKey(key);
     }
   }
 
-  std::shared_ptr<TLSCtx> getContext()
+  std::shared_ptr<TLSCtx> getContext() const
   {
     return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire);
   }
 
+  void setParent(std::shared_ptr<const TLSFrontend> parent)
+  {
+    std::atomic_store_explicit(&d_parentFrontend, std::move(parent), std::memory_order_release);
+  }
+
   void cleanup()
   {
     d_ctx.reset();
@@ -242,6 +247,7 @@ public:
   bool d_proxyProtocolOutsideTLS{false};
 protected:
   std::shared_ptr<TLSCtx> d_ctx{nullptr};
+  std::shared_ptr<const TLSFrontend> d_parentFrontend{nullptr};
 };
 
 class TCPIOHandler