]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Refactor the DoH code to be able to have two libraries
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 31 Jul 2023 14:18:02 +0000 (16:18 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Sep 2023 07:19:15 +0000 (09:19 +0200)
17 files changed:
pdns/dnsdist-carbon.cc
pdns/dnsdist-doh-common.hh [new file with mode: 0644]
pdns/dnsdist-idstate.hh
pdns/dnsdist-lua-inspection.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-web.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-async.cc
pdns/dnsdistdist/dnsdist-backend.cc
pdns/dnsdistdist/dnsdist-doh-common.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-doh-common.hh [new symlink]
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc
pdns/doh.hh
pdns/test-dnsdist_cc.cc

index 693f498c435f8564126f9a53c0934d09c273a5ef..d73d9ff8b7efcdfb68c6dbee089bdae07927848f 100644 (file)
@@ -147,7 +147,7 @@ static bool doOneCarbonExport(const Carbon::Endpoint& endpoint)
         errorCounters = &front->tlsFrontend->d_tlsCounters;
       }
       else if (front->dohFrontend != nullptr) {
-        errorCounters = &front->dohFrontend->d_tlsCounters;
+        errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters;
       }
       if (errorCounters != nullptr) {
         str << base << "tlsdhkeytoosmall" << ' ' << errorCounters->d_dhKeyTooSmall << " " << now << "\r\n";
@@ -204,7 +204,7 @@ static bool doOneCarbonExport(const Carbon::Endpoint& endpoint)
       std::map<std::string, uint64_t> dohFrontendDuplicates;
       const string base = "dnsdist." + hostname + ".main.doh.";
       for (const auto& doh : g_dohlocals) {
-        string name = doh->d_local.toStringWithPort();
+        string name = doh->d_tlsContext.d_addr.toStringWithPort();
         boost::replace_all(name, ".", "_");
         boost::replace_all(name, ":", "_");
         boost::replace_all(name, "[", "_");
diff --git a/pdns/dnsdist-doh-common.hh b/pdns/dnsdist-doh-common.hh
new file mode 100644 (file)
index 0000000..44ad826
--- /dev/null
@@ -0,0 +1,240 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include <unordered_map>
+#include <set>
+
+#include "config.h"
+#include "iputils.hh"
+#include "libssl.hh"
+#include "noinitvector.hh"
+#include "stat_t.hh"
+#include "tcpiohandler.hh"
+
+struct DOHServerConfig;
+
+class DOHResponseMapEntry
+{
+public:
+  DOHResponseMapEntry(const std::string& regex, uint16_t status, const PacketBuffer& content, const boost::optional<std::unordered_map<std::string, std::string>>& headers) :
+    d_regex(regex), d_customHeaders(headers), d_content(content), d_status(status)
+  {
+    if (status >= 400 && !d_content.empty() && d_content.at(d_content.size() - 1) != 0) {
+      // we need to make sure it's null-terminated
+      d_content.push_back(0);
+    }
+  }
+
+  bool matches(const std::string& path) const
+  {
+    return d_regex.match(path);
+  }
+
+  uint16_t getStatusCode() const
+  {
+    return d_status;
+  }
+
+  const PacketBuffer& getContent() const
+  {
+    return d_content;
+  }
+
+  const boost::optional<std::unordered_map<std::string, std::string>>& getHeaders() const
+  {
+    return d_customHeaders;
+  }
+
+private:
+  Regex d_regex;
+  boost::optional<std::unordered_map<std::string, std::string>> d_customHeaders;
+  PacketBuffer d_content;
+  uint16_t d_status;
+};
+
+struct DOHFrontend
+{
+  DOHFrontend()
+  {
+  }
+  DOHFrontend(std::shared_ptr<TLSCtx> tlsCtx) :
+    d_tlsContext(std::move(tlsCtx))
+  {
+  }
+
+  virtual ~DOHFrontend()
+  {
+  }
+
+  std::shared_ptr<DOHServerConfig> d_dsc{nullptr};
+  std::shared_ptr<std::vector<std::shared_ptr<DOHResponseMapEntry>>> d_responsesMap;
+  TLSFrontend d_tlsContext{TLSFrontend::ALPN::DoH};
+  std::string d_serverTokens{"h2o/dnsdist"};
+  std::unordered_map<std::string, std::string> d_customResponseHeaders;
+  std::string d_library;
+
+  uint32_t d_idleTimeout{30}; // HTTP idle timeout in seconds
+  std::set<std::string, std::less<>> d_urls;
+
+  pdns::stat_t d_httpconnects{0}; // number of TCP/IP connections established
+  pdns::stat_t d_getqueries{0}; // valid DNS queries received via GET
+  pdns::stat_t d_postqueries{0}; // valid DNS queries received via POST
+  pdns::stat_t d_badrequests{0}; // request could not be converted to dns query
+  pdns::stat_t d_errorresponses{0}; // dnsdist set 'error' on response
+  pdns::stat_t d_redirectresponses{0}; // dnsdist set 'redirect' on response
+  pdns::stat_t d_validresponses{0}; // valid responses sent out
+
+  struct HTTPVersionStats
+  {
+    pdns::stat_t d_nbQueries{0}; // valid DNS queries received
+    pdns::stat_t d_nb200Responses{0};
+    pdns::stat_t d_nb400Responses{0};
+    pdns::stat_t d_nb403Responses{0};
+    pdns::stat_t d_nb500Responses{0};
+    pdns::stat_t d_nb502Responses{0};
+    pdns::stat_t d_nbOtherResponses{0};
+  };
+
+  HTTPVersionStats d_http1Stats;
+  HTTPVersionStats d_http2Stats;
+#ifdef __linux__
+  // On Linux this gives us 128k pending queries (default is 8192 queries),
+  // which should be enough to deal with huge spikes
+  uint32_t d_internalPipeBufferSize{1024 * 1024};
+#else
+  uint32_t d_internalPipeBufferSize{0};
+#endif
+  bool d_sendCacheControlHeaders{true};
+  bool d_trustForwardedForHeader{false};
+  /* whether we require tue query path to exactly match one of configured ones,
+     or accept everything below these paths. */
+  bool d_exactPathMatching{true};
+  bool d_keepIncomingHeaders{false};
+
+  time_t getTicketsKeyRotationDelay() const
+  {
+    return d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay;
+  }
+
+  bool isHTTPS() const
+  {
+    return !d_tlsContext.d_tlsConfig.d_certKeyPairs.empty();
+  }
+
+#ifndef HAVE_DNS_OVER_HTTPS
+  virtual void setup()
+  {
+  }
+
+  virtual void reloadCertificates()
+  {
+  }
+
+  virtual void rotateTicketsKey(time_t /* now */)
+  {
+  }
+
+  virtual void loadTicketsKeys(const std::string& /* keyFile */)
+  {
+  }
+
+  virtual void handleTicketsKeyRotation()
+  {
+  }
+
+  virtual std::string getNextTicketsKeyRotation()
+  {
+    return std::string();
+  }
+
+  virtual size_t getTicketsKeysCount() const
+  {
+    size_t res = 0;
+    return res;
+  }
+
+#else
+  virtual void setup();
+  virtual void reloadCertificates();
+
+  virtual void rotateTicketsKey(time_t now);
+  virtual void loadTicketsKeys(const std::string& keyFile);
+  virtual void handleTicketsKeyRotation();
+  virtual std::string getNextTicketsKeyRotation() const;
+  virtual size_t getTicketsKeysCount();
+#endif /* HAVE_DNS_OVER_HTTPS */
+};
+
+#include "dnsdist-idstate.hh"
+
+struct DownstreamState;
+
+#ifndef HAVE_DNS_OVER_HTTPS
+struct DOHUnitInterface
+{
+  virtual ~DOHUnitInterface()
+  {
+  }
+  static void handleTimeout(std::unique_ptr<DOHUnitInterface>)
+  {
+  }
+
+  static void handleUDPResponse(std::unique_ptr<DOHUnitInterface>, PacketBuffer&&, InternalQueryState&&, const std::shared_ptr<DownstreamState>&)
+  {
+  }
+};
+#else /* HAVE_DNS_OVER_HTTPS */
+struct DOHUnitInterface
+{
+  virtual ~DOHUnitInterface()
+  {
+  }
+
+  virtual std::string getHTTPPath() const = 0;
+  virtual std::string getHTTPQueryString() const = 0;
+  virtual const std::string& getHTTPHost() const = 0;
+  virtual const std::string& getHTTPScheme() const = 0;
+  virtual const std::unordered_map<std::string, std::string>& getHTTPHeaders() const = 0;
+  virtual void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "") = 0;
+  virtual void handleTimeout() = 0;
+  virtual void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>&) = 0;
+
+  static void handleTimeout(std::unique_ptr<DOHUnitInterface> unit)
+  {
+    if (unit) {
+      unit->handleTimeout();
+      unit.release();
+    }
+  }
+
+  static void handleUDPResponse(std::unique_ptr<DOHUnitInterface> unit, PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& ds)
+  {
+    if (unit) {
+      unit->handleUDPResponse(std::move(response), std::move(state), ds);
+      unit.release();
+    }
+  }
+
+  std::shared_ptr<DownstreamState> downstream{nullptr};
+};
+#endif /* HAVE_DNS_OVER_HTTPS  */
index cf5442fea0b254b16acadc0e5e6406f5bccfa6dd..456e703fb3db9b4e307de71940340582cc409757 100644 (file)
@@ -22,6 +22,7 @@
 #pragma once
 
 #include "config.h"
+#include "dnscrypt.hh"
 #include "dnsname.hh"
 #include "dnsdist-protocols.hh"
 #include "gettime.hh"
 #include "uuid-utils.hh"
 
 struct ClientState;
-struct DOHUnit;
+struct DOHUnitInterface;
 class DNSCryptQuery;
 class DNSDistPacketCache;
 
 using QTag = std::unordered_map<string, string>;
+using HeadersMap = std::unordered_map<std::string, std::string>;
 
 struct StopWatch
 {
@@ -89,6 +91,8 @@ private:
   bool d_needRealTime;
 };
 
+class CrossProtocolContext;
+
 struct InternalQueryState
 {
   struct ProtoBufData
@@ -125,7 +129,9 @@ struct InternalQueryState
   std::unique_ptr<ProtoBufData> d_protoBufData{nullptr};
   boost::optional<uint32_t> tempFailureTTL{boost::none}; // 8
   ClientState* cs{nullptr}; // 8
-  std::unique_ptr<DOHUnit> du{nullptr}; // 8
+  std::unique_ptr<DOHUnitInterface> du; // 8
+  size_t d_proxyProtocolPayloadSize{0}; // 8
+  int32_t d_streamID{-1}; // 4
   uint32_t cacheKey{0}; // 4
   uint32_t cacheKeyNoECS{0}; // 4
   // DoH-only */
index 66200df0eaa5e82e3785964f580aeabb69a92014..f778a492ddedaacbc2dd303107bcd0e41d8144bb 100644 (file)
@@ -706,7 +706,7 @@ void setupLuaInspection(LuaContext& luaCtx)
           errorCounters = &f->tlsFrontend->d_tlsCounters;
         }
         else if (f->dohFrontend != nullptr) {
-          errorCounters = &f->dohFrontend->d_tlsCounters;
+          errorCounters = &f->dohFrontend->d_tlsContext.d_tlsCounters;
         }
         if (errorCounters == nullptr) {
           continue;
index 3224b8763fd5074be0d609b8d2b4036c445dd32f..c829c2e1b5c1e35d57dc43eea9b92bb4481d7ba0 100644 (file)
@@ -57,6 +57,7 @@
 #include "dnsdist-web.hh"
 
 #include "base64.hh"
+#include "doh.hh"
 #include "dolog.hh"
 #include "sodcrypto.hh"
 #include "threadname.hh"
@@ -2336,31 +2337,39 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     setLuaSideEffect();
 
     auto frontend = std::make_shared<DOHFrontend>();
+#ifdef HAVE_LIBH2OEVLOOP
+    frontend = std::make_shared<H2ODOHFrontend>();
+    frontend->d_library = "h2o";
+#else /* HAVE_LIBH2OEVLOOP */
+    errlog("DOH bind %s is configured to use libh2o but the library is not available", addr);
+    return;
+#endif /* HAVE_LIBH2OEVLOOP */
+
     if (certFiles && !certFiles->empty()) {
-      if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) {
+      if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) {
         return;
       }
 
-      frontend->d_local = ComboAddress(addr, 443);
+      frontend->d_tlsContext.d_addr = ComboAddress(addr, 443);
     }
     else {
-      frontend->d_local = ComboAddress(addr, 80);
-      infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_local.toStringWithPort());
+      frontend->d_tlsContext.d_addr = ComboAddress(addr, 80);
+      infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort());
     }
 
     if (urls) {
       if (urls->type() == typeid(std::string)) {
-        frontend->d_urls.push_back(boost::get<std::string>(*urls));
+        frontend->d_urls.insert(boost::get<std::string>(*urls));
       }
       else if (urls->type() == typeid(LuaArray<std::string>)) {
         auto urlsVect = boost::get<LuaArray<std::string>>(*urls);
         for (const auto& p : urlsVect) {
-          frontend->d_urls.push_back(p.second);
+          frontend->d_urls.insert(p.second);
         }
       }
     }
     else {
-      frontend->d_urls = {"/dns-query"};
+      frontend->d_urls.insert("/dns-query");
     }
 
     bool reusePort = false;
@@ -2405,7 +2414,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         }
       }
 
-      parseTLSConfig(frontend->d_tlsConfig, "addDOHLocal", vars);
+      parseTLSConfig(frontend->d_tlsContext.d_tlsConfig, "addDOHLocal", vars);
 
       bool ignoreTLSConfigurationErrors = false;
       if (getOptionalValue<bool>(vars, "ignoreTLSConfigurationErrors", ignoreTLSConfigurationErrors) > 0 && ignoreTLSConfigurationErrors) {
@@ -2413,7 +2422,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         // and properly ignore the frontend before actually launching it
         try {
           std::map<int, std::string> ocspResponses = {};
-          auto ctx = libssl_init_server_context(frontend->d_tlsConfig, ocspResponses);
+          auto ctx = libssl_init_server_context(frontend->d_tlsContext.d_tlsConfig, ocspResponses);
         }
         catch (const std::runtime_error& e) {
           errlog("Ignoring DoH frontend: '%s'", e.what());
@@ -2424,7 +2433,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       checkAllParametersConsumed("addDOHLocal", vars);
     }
     g_dohlocals.push_back(frontend);
-    auto cs = std::make_unique<ClientState>(frontend->d_local, true, reusePort, tcpFastOpenQueueSize, interface, cpus);
+    auto cs = std::make_unique<ClientState>(frontend->d_tlsContext.d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus);
     cs->dohFrontend = frontend;
     cs->d_additionalAddresses = std::move(additionalAddresses);
 
@@ -2435,9 +2444,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       cs->d_tcpConcurrentConnectionsLimit = tcpMaxConcurrentConnections;
     }
     g_frontends.push_back(std::move(cs));
-#else
+#else /* HAVE_DNS_OVER_HTTPS */
       throw std::runtime_error("addDOHLocal() called but DNS over HTTPS support is not present!");
-#endif
+#endif /* HAVE_DNS_OVER_HTTPS */
   });
 
   luaCtx.writeFunction("showDOHFrontends", []() {
@@ -2449,7 +2458,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       ret << (fmt % "#" % "Address" % "HTTP" % "HTTP/1" % "HTTP/2" % "GET" % "POST" % "Bad" % "Errors" % "Redirects" % "Valid" % "# ticket keys" % "Rotation delay" % "Next rotation") << endl;
       size_t counter = 0;
       for (const auto& ctx : g_dohlocals) {
-        ret << (fmt % counter % ctx->d_local.toStringWithPort() % ctx->d_httpconnects % ctx->d_http1Stats.d_nbQueries % ctx->d_http2Stats.d_nbQueries % ctx->d_getqueries % ctx->d_postqueries % ctx->d_badrequests % ctx->d_errorresponses % ctx->d_redirectresponses % ctx->d_validresponses % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl;
+        ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_httpconnects % ctx->d_http1Stats.d_nbQueries % ctx->d_http2Stats.d_nbQueries % ctx->d_getqueries % ctx->d_postqueries % ctx->d_badrequests % ctx->d_errorresponses % ctx->d_redirectresponses % ctx->d_validresponses % ctx->getTicketsKeysCount() % ctx->getTicketsKeyRotationDelay() % ctx->getNextTicketsKeyRotation()) << endl;
         counter++;
       }
       g_outputBuffer = ret.str();
@@ -2473,7 +2482,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       ret << (fmt % "#" % "Address" % "200" % "400" % "403" % "500" % "502" % "Others") << endl;
       size_t counter = 0;
       for (const auto& ctx : g_dohlocals) {
-        ret << (fmt % counter % ctx->d_local.toStringWithPort() % ctx->d_http1Stats.d_nb200Responses % ctx->d_http1Stats.d_nb400Responses % ctx->d_http1Stats.d_nb403Responses % ctx->d_http1Stats.d_nb500Responses % ctx->d_http1Stats.d_nb502Responses % ctx->d_http1Stats.d_nbOtherResponses) << endl;
+        ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_http1Stats.d_nb200Responses % ctx->d_http1Stats.d_nb400Responses % ctx->d_http1Stats.d_nb403Responses % ctx->d_http1Stats.d_nb500Responses % ctx->d_http1Stats.d_nb502Responses % ctx->d_http1Stats.d_nbOtherResponses) << endl;
         counter++;
       }
       g_outputBuffer += ret.str();
@@ -2483,7 +2492,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       ret << (fmt % "#" % "Address" % "200" % "400" % "403" % "500" % "502" % "Others") << endl;
       counter = 0;
       for (const auto& ctx : g_dohlocals) {
-        ret << (fmt % counter % ctx->d_local.toStringWithPort() % ctx->d_http2Stats.d_nb200Responses % ctx->d_http2Stats.d_nb400Responses % ctx->d_http2Stats.d_nb403Responses % ctx->d_http2Stats.d_nb500Responses % ctx->d_http2Stats.d_nb502Responses % ctx->d_http2Stats.d_nbOtherResponses) << endl;
+        ret << (fmt % counter % ctx->d_tlsContext.d_addr.toStringWithPort() % ctx->d_http2Stats.d_nb200Responses % ctx->d_http2Stats.d_nb400Responses % ctx->d_http2Stats.d_nb403Responses % ctx->d_http2Stats.d_nb500Responses % ctx->d_http2Stats.d_nb502Responses % ctx->d_http2Stats.d_nbOtherResponses) << endl;
         counter++;
       }
       g_outputBuffer += ret.str();
@@ -2537,7 +2546,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
   luaCtx.registerFunction<void (std::shared_ptr<DOHFrontend>::*)(boost::variant<std::string, std::shared_ptr<TLSCertKeyPair>, LuaArray<std::string>, LuaArray<std::shared_ptr<TLSCertKeyPair>>> certFiles, boost::variant<std::string, LuaArray<std::string>> keyFiles)>("loadNewCertificatesAndKeys", [](std::shared_ptr<DOHFrontend> frontend, boost::variant<std::string, std::shared_ptr<TLSCertKeyPair>, LuaArray<std::string>, LuaArray<std::shared_ptr<TLSCertKeyPair>>> certFiles, boost::variant<std::string, LuaArray<std::string>> keyFiles) {
 #ifdef HAVE_DNS_OVER_HTTPS
     if (frontend != nullptr) {
-      if (loadTLSCertificateAndKeys("DOHFrontend::loadNewCertificatesAndKeys", frontend->d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) {
+      if (loadTLSCertificateAndKeys("DOHFrontend::loadNewCertificatesAndKeys", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) {
         frontend->reloadCertificates();
       }
     }
@@ -2579,7 +2588,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     }
     setLuaSideEffect();
 
-    shared_ptr<TLSFrontend> frontend = std::make_shared<TLSFrontend>(TLSFrontend::ALPN::DoT);
+    auto frontend = std::make_shared<TLSFrontend>(TLSFrontend::ALPN::DoT);
     if (!loadTLSCertificateAndKeys("addTLSLocal", frontend->d_tlsConfig.d_certKeyPairs, certFiles, keyFiles)) {
       return;
     }
index d1132d37839fc9609bee776ac1cab5907444949c..50379859725ce33c9434bdf4483cd1c471275adf 100644 (file)
@@ -741,7 +741,7 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp)
           errorCounters = &front->tlsFrontend->d_tlsCounters;
         }
         else if (front->dohFrontend != nullptr) {
-          errorCounters = &front->dohFrontend->d_tlsCounters;
+          errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters;
         }
 
         if (errorCounters != nullptr) {
@@ -779,7 +779,7 @@ static void handlePrometheus(const YaHTTP::Request& req, YaHTTP::Response& resp)
 #ifdef HAVE_DNS_OVER_HTTPS
   std::map<std::string,uint64_t> dohFrontendDuplicates;
   for(const auto& doh : g_dohlocals) {
-    const string frontName = doh->d_local.toStringWithPort();
+    const string frontName = doh->d_tlsContext.d_addr.toStringWithPort();
     uint64_t threadNumber = 0;
     auto dupPair = frontendDuplicates.emplace(frontName, 1);
     if (!dupPair.second) {
@@ -1149,7 +1149,7 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp)
       errorCounters = &front->tlsFrontend->d_tlsCounters;
     }
     else if (front->dohFrontend != nullptr) {
-      errorCounters = &front->dohFrontend->d_tlsCounters;
+      errorCounters = &front->dohFrontend->d_tlsContext.d_tlsCounters;
     }
     if (errorCounters != nullptr) {
       frontend["tlsHandshakeFailuresDHKeyTooSmall"] = (double)errorCounters->d_dhKeyTooSmall;
@@ -1172,7 +1172,7 @@ static void handleStats(const YaHTTP::Request& req, YaHTTP::Response& resp)
     for (const auto& doh : g_dohlocals) {
       dohs.emplace_back(Json::object{
         { "id", num++ },
-        { "address", doh->d_local.toStringWithPort() },
+        { "address", doh->d_tlsContext.d_addr.toStringWithPort() },
         { "http-connects", (double) doh->d_httpconnects },
         { "http1-queries", (double) doh->d_http1Stats.d_nbQueries },
         { "http2-queries", (double) doh->d_http2Stats.d_nbQueries },
index a673bd6f54e2ed77c51a385c270ba1d1f6f028a7..fdf2797104c14916769ae5eb9aafcdf0dccc95d9 100644 (file)
@@ -69,6 +69,7 @@
 #include "base64.hh"
 #include "capabilities.hh"
 #include "delaypipe.hh"
+#include "doh.hh"
 #include "dolog.hh"
 #include "dnsname.hh"
 #include "dnsparser.hh"
@@ -784,7 +785,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
         if (du) {
 #ifdef HAVE_DNS_OVER_HTTPS
           // DoH query, we cannot touch du after that
-          handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids));
+          DOHUnitInterface::handleUDPResponse(std::move(du), std::move(response), std::move(*ids), dss);
 #endif
           continue;
         }
@@ -1539,19 +1540,14 @@ ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::sha
   return ProcessQueryResult::Drop;
 }
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
 {
   bool doh = dq.ids.du != nullptr;
 
   bool failed = false;
-  size_t proxyPayloadSize = 0;
   if (ds->d_config.useProxyProtocol) {
     try {
-      if (addProxyProtocol(dq, &proxyPayloadSize)) {
-        if (dq.ids.du) {
-          dq.ids.du->proxyProtocolPayloadSize = proxyPayloadSize;
-        }
-      }
+      addProxyProtocol(dq, &dq.ids.d_proxyProtocolPayloadSize);
     }
     catch (const std::exception& e) {
       vinfolog("Adding proxy protocol payload to %s query from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what());
@@ -1559,6 +1555,10 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint1
     }
   }
 
+  if (doh && !dq.ids.d_packet) {
+    dq.ids.d_packet = std::make_unique<PacketBuffer>(query);
+  }
+
   try {
     int fd = ds->pickSocketForSending();
     dq.ids.backendFD = fd;
@@ -1569,7 +1569,7 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint1
 
     auto idOffset = ds->saveState(std::move(dq.ids));
     /* set the correct ID */
-    memcpy(query.data() + proxyPayloadSize, &idOffset, sizeof(idOffset));
+    memcpy(query.data() + dq.ids.d_proxyProtocolPayloadSize, &idOffset, sizeof(idOffset));
 
     /* you can't touch ids or du after this line, unless the call returned a non-negative value,
        because it might already have been freed */
@@ -1585,9 +1585,6 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint1
       auto cleared = ds->getState(idOffset);
       if (cleared) {
         dq.ids.du = std::move(cleared->du);
-        if (dq.ids.du) {
-          dq.ids.du->status_code = 502;
-        }
       }
       ++dnsdist::metrics::g_stats.downstreamSendErrors;
       ++ds->sendErrors;
@@ -1720,7 +1717,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       return;
     }
 
-    assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, dest);
+    assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query);
   }
   catch(const std::exception& e){
     vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", ids.origRemote.toStringWithPort(), queryId, e.what());
index 56b7421655d453e2fec9f6f57264ff96ddc8b1c6..a9ecef0170809c6af5cfc0053e2d84bed317563e 100644 (file)
@@ -42,7 +42,7 @@
 #include "dnsdist-lbpolicies.hh"
 #include "dnsdist-protocols.hh"
 #include "dnsname.hh"
-#include "doh.hh"
+#include "dnsdist-doh-common.hh"
 #include "ednsoptions.hh"
 #include "iputils.hh"
 #include "misc.hh"
@@ -1088,10 +1088,6 @@ struct LocalHolders
 
 void tcpAcceptorThread(std::vector<ClientState*> states);
 
-#ifdef HAVE_DNS_OVER_HTTPS
-void dohThread(ClientState* cs);
-#endif /* HAVE_DNS_OVER_HTTPS */
-
 void setLuaNoSideEffect(); // if nothing has been declared, set that there are no side effects
 void setLuaSideEffect();   // set to report a side effect, cancelling all _no_ side effect calls
 bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect
@@ -1123,7 +1119,7 @@ bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRu
 bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop);
 bool processResponseAfterRules(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted);
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest);
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query);
 
 ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck = false);
 bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote);
index 9b951a58669444df7ef415a935fbd226d0680908..99d7cdbe6475655094ba72a714a2ec6c7f6049a6 100644 (file)
@@ -147,6 +147,7 @@ dnsdist_SOURCES = \
        dnsdist-discovery.cc dnsdist-discovery.hh \
        dnsdist-dnscrypt.cc \
        dnsdist-dnsparser.cc dnsdist-dnsparser.hh \
+       dnsdist-doh-common.cc dnsdist-doh-common.hh \
        dnsdist-downstream-connection.hh \
        dnsdist-dynblocks.cc dnsdist-dynblocks.hh \
        dnsdist-dynbpf.cc dnsdist-dynbpf.hh \
@@ -256,6 +257,7 @@ testrunner_SOURCES = \
        dnsdist-cache.cc dnsdist-cache.hh \
        dnsdist-concurrent-connections.hh \
        dnsdist-dnsparser.cc dnsdist-dnsparser.hh \
+       dnsdist-doh-common.cc dnsdist-doh-common.hh \
        dnsdist-downstream-connection.hh \
        dnsdist-dynblocks.cc dnsdist-dynblocks.hh \
        dnsdist-dynbpf.cc dnsdist-dynbpf.hh \
index 19426468df74c04b2f1f1c96748da2b4afa315d7..f54b1c0b144644d9ff5c90b60b5f2bf0f74eb43b 100644 (file)
@@ -282,7 +282,6 @@ bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query)
     return resumeResponse(std::move(query));
   }
 
-  auto& ids = query->query.d_idstate;
   DNSQuestion dnsQuestion = query->getDQ();
   LocalHolders holders;
 
@@ -311,7 +310,7 @@ bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query)
     /* at this point 'du', if it is not nullptr, is owned by the DoHCrossProtocolQuery
        which will stop existing when we return, so we need to increment the reference count
     */
-    return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dnsQuestion, query->query.d_buffer, ids.origDest);
+    return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dnsQuestion, query->query.d_buffer);
   }
   if (result == ProcessQueryResult::SendAnswer) {
     auto sender = query->getTCPQuerySender();
index 45a50446da71105f47323cc1808f0f1d59c6809b..44b3d9c39dc4685cf041fb5261c797acc8200af5 100644 (file)
@@ -360,7 +360,7 @@ void DownstreamState::handleUDPTimeout(IDState& ids)
 {
   ids.age = 0;
   ids.inUse = false;
-  handleDOHTimeout(std::move(ids.internal.du));
+  DOHUnitInterface::handleTimeout(std::move(ids.internal.du));
   ++reuseds;
   --outstanding;
   ++dnsdist::metrics::g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout
@@ -463,7 +463,7 @@ uint16_t DownstreamState::saveState(InternalQueryState&& state)
         auto oldDU = std::move(it->second.internal.du);
         ++reuseds;
         ++dnsdist::metrics::g_stats.downstreamTimeouts;
-        handleDOHTimeout(std::move(oldDU));
+        DOHUnitInterface::handleTimeout(std::move(oldDU));
       }
       else {
         ++outstanding;
@@ -490,7 +490,7 @@ uint16_t DownstreamState::saveState(InternalQueryState&& state)
       auto oldDU = std::move(ids.internal.du);
       ++reuseds;
       ++dnsdist::metrics::g_stats.downstreamTimeouts;
-      handleDOHTimeout(std::move(oldDU));
+      DOHUnitInterface::handleTimeout(std::move(oldDU));
     }
     else {
       ++outstanding;
@@ -513,7 +513,7 @@ void DownstreamState::restoreState(uint16_t id, InternalQueryState&& state)
       /* already used */
       ++reuseds;
       ++dnsdist::metrics::g_stats.downstreamTimeouts;
-      handleDOHTimeout(std::move(state.du));
+      DOHUnitInterface::handleTimeout(std::move(state.du));
     }
     else {
       it->second.internal = std::move(state);
@@ -528,14 +528,14 @@ void DownstreamState::restoreState(uint16_t id, InternalQueryState&& state)
     /* already used */
     ++reuseds;
     ++dnsdist::metrics::g_stats.downstreamTimeouts;
-    handleDOHTimeout(std::move(state.du));
+    DOHUnitInterface::handleTimeout(std::move(state.du));
     return;
   }
   if (ids.isInUse()) {
     /* already used */
     ++reuseds;
     ++dnsdist::metrics::g_stats.downstreamTimeouts;
-    handleDOHTimeout(std::move(state.du));
+    DOHUnitInterface::handleTimeout(std::move(state.du));
     return;
   }
   ids.internal = std::move(state);
diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc
new file mode 100644 (file)
index 0000000..15fcb96
--- /dev/null
@@ -0,0 +1,129 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include "dnsdist-doh-common.hh"
+#include "dnsdist-rules.hh"
+
+#ifdef HAVE_DNS_OVER_HTTPS
+
+HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& regex) :
+  d_header(toLower(header)), d_regex(regex), d_visual("http[" + header + "] ~ " + regex)
+{
+}
+
+bool HTTPHeaderRule::matches(const DNSQuestion* dq) const
+{
+  if (!dq->ids.du) {
+    return false;
+  }
+
+  const auto& headers = dq->ids.du->getHTTPHeaders();
+  for (const auto& header : headers) {
+    if (header.first == d_header) {
+      return d_regex.match(header.second);
+    }
+  }
+  return false;
+}
+
+string HTTPHeaderRule::toString() const
+{
+  return d_visual;
+}
+
+HTTPPathRule::HTTPPathRule(const std::string& path) :
+  d_path(path)
+{
+}
+
+bool HTTPPathRule::matches(const DNSQuestion* dq) const
+{
+  if (!dq->ids.du) {
+    return false;
+  }
+
+  const auto path = dq->ids.du->getHTTPPath();
+  return d_path == path;
+}
+
+string HTTPPathRule::toString() const
+{
+  return "url path == " + d_path;
+}
+
+HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex) :
+  d_regex(regex), d_visual("http path ~ " + regex)
+{
+}
+
+bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const
+{
+  if (!dq->ids.du) {
+    return false;
+  }
+
+  return d_regex.match(dq->ids.du->getHTTPPath());
+}
+
+string HTTPPathRegexRule::toString() const
+{
+  return d_visual;
+}
+
+void DOHFrontend::rotateTicketsKey(time_t now)
+{
+  return d_tlsContext.rotateTicketsKey(now);
+}
+
+void DOHFrontend::loadTicketsKeys(const std::string& keyFile)
+{
+  return d_tlsContext.loadTicketsKeys(keyFile);
+}
+
+void DOHFrontend::handleTicketsKeyRotation()
+{
+}
+
+std::string DOHFrontend::getNextTicketsKeyRotation() const
+{
+  return d_tlsContext.getNextTicketsKeyRotation();
+}
+
+size_t DOHFrontend::getTicketsKeysCount()
+{
+  return d_tlsContext.getTicketsKeysCount();
+}
+
+void DOHFrontend::reloadCertificates()
+{
+  d_tlsContext.setupTLS();
+}
+
+void DOHFrontend::setup()
+{
+  if (isHTTPS()) {
+    if (!d_tlsContext.setupTLS()) {
+      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort());
+    }
+  }
+}
+
+#endif /* HAVE_DNS_OVER_HTTPS */
diff --git a/pdns/dnsdistdist/dnsdist-doh-common.hh b/pdns/dnsdistdist/dnsdist-doh-common.hh
new file mode 120000 (symlink)
index 0000000..5692084
--- /dev/null
@@ -0,0 +1 @@
+../dnsdist-doh-common.hh
\ No newline at end of file
index eeb0af48084743300e25cc84e1518792a34496a4..91dcd9ad767418df8fe93538bcd662783ec50da9 100644 (file)
@@ -167,6 +167,8 @@ private:
   std::atomic_flag d_rotatingTicketsKey;
 };
 
+struct DOHUnit;
+
 // we create one of these per thread, and pass around a pointer to it
 // through the bowels of h2o
 struct DOHServerConfig
@@ -215,6 +217,61 @@ struct DOHServerConfig
   pdns::channel::Receiver<DOHUnit> d_responseReceiver;
 };
 
+struct DOHUnit : public DOHUnitInterface
+{
+  DOHUnit(PacketBuffer&& q, std::string&& p, std::string&& h): path(std::move(p)), host(std::move(h)), query(std::move(q))
+  {
+    ids.ednsAdded = false;
+  }
+  ~DOHUnit()
+  {
+    if (self) {
+      *self = nullptr;
+    }
+  }
+
+  DOHUnit(const DOHUnit&) = delete;
+  DOHUnit& operator=(const DOHUnit&) = delete;
+
+  InternalQueryState ids;
+  std::string sni;
+  std::string path;
+  std::string scheme;
+  std::string host;
+  std::string contentType;
+  PacketBuffer query;
+  PacketBuffer response;
+  std::unique_ptr<std::unordered_map<std::string, std::string>> headers;
+  st_h2o_req_t* req{nullptr};
+  DOHUnit** self{nullptr};
+  DOHServerConfig* dsc{nullptr};
+  pdns::channel::Sender<DOHUnit>* responseSender{nullptr};
+  size_t query_at{0};
+  int rsock{-1};
+  /* the status_code is set from
+     processDOHQuery() (which is executed in
+     the DOH client thread) so that the correct
+     response can be sent in on_dnsdist(),
+     after the DOHUnit has been passed back to
+     the main DoH thread.
+  */
+  uint16_t status_code{200};
+  /* whether the query was re-sent to the backend over
+     TCP after receiving a truncated answer over UDP */
+  bool tcp{false};
+  bool truncated{false};
+
+  std::string getHTTPPath() const override;
+  std::string getHTTPQueryString() const override;
+  const std::string& getHTTPHost() const override;
+  const std::string& getHTTPScheme() const override;
+  const std::unordered_map<std::string, std::string>& getHTTPHeaders() const override;
+  void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType="") override;
+  virtual void handleTimeout() override;
+  virtual void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>&) override;
+};
+using DOHUnitUniquePtr = std::unique_ptr<DOHUnit>;
+
 /* This internal function sends back the object to the main thread to send a reply.
    The caller should NOT release or touch the unit after calling this function */
 static void sendDoHUnitToTheMainThread(DOHUnitUniquePtr&& du, const char* description)
@@ -233,18 +290,11 @@ static void sendDoHUnitToTheMainThread(DOHUnitUniquePtr&& du, const char* descri
 }
 
 /* This function is called from other threads than the main DoH one,
-   instructing it to send a 502 error to the client.
-   It takes ownership of the unit. */
-void handleDOHTimeout(DOHUnitUniquePtr&& oldDU)
+   instructing it to send a 502 error to the client. */
+void DOHUnit::handleTimeout()
 {
-  if (oldDU == nullptr) {
-    return;
-  }
-
-  /* we are about to erase an existing DU */
-  oldDU->status_code = 502;
-
-  sendDoHUnitToTheMainThread(std::move(oldDU), "DoH timeout");
+  status_code = 502;
+  sendDoHUnitToTheMainThread(std::unique_ptr<DOHUnit>(this), "DoH timeout");
 }
 
 struct DOHConnection
@@ -385,7 +435,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo
         h2o_send_error_400(req, getReasonFromStatusCode(statusCode).c_str(), "invalid DNS query" , 0);
         break;
       case 403:
-        h2o_send_error_403(req, getReasonFromStatusCode(statusCode).c_str(), "dns query not allowed", 0);
+        h2o_send_error_403(req, getReasonFromStatusCode(statusCode).c_str(), "DoH query not allowed", 0);
         break;
       case 502:
         h2o_send_error_502(req, getReasonFromStatusCode(statusCode).c_str(), "no downstream server available", 0);
@@ -402,6 +452,12 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo
   }
 }
 
+static std::unique_ptr<DOHUnit> getDUFromIDS(InternalQueryState& ids)
+{
+  auto du = std::unique_ptr<DOHUnit>(dynamic_cast<DOHUnit*>(ids.du.release()));
+  return du;
+}
+
 class DoHTCPCrossQuerySender : public TCPQuerySender
 {
 public:
@@ -420,7 +476,7 @@ public:
       return;
     }
 
-    auto du = std::move(response.d_idstate.du);
+    auto du = getDUFromIDS(response.d_idstate);
     if (du->responseSender == nullptr) {
       return;
     }
@@ -438,10 +494,11 @@ public:
 
       dr.ids.du = std::move(du);
 
-      if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) {
+      if (!processResponse(dynamic_cast<DOHUnit*>(dr.ids.du.get())->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) {
         if (dr.ids.du) {
-          dr.ids.du->status_code = 503;
-          sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules");
+          du = getDUFromIDS(dr.ids);
+          du->status_code = 503;
+          sendDoHUnitToTheMainThread(std::move(du), "Response dropped by rules");
         }
         return;
       }
@@ -450,7 +507,7 @@ public:
         return;
       }
 
-      du = std::move(dr.ids.du);
+      du = getDUFromIDS(dr.ids);
     }
 
     if (!du->ids.selfGenerated) {
@@ -483,11 +540,11 @@ public:
       return;
     }
 
-    if (query.du->responseSender == nullptr) {
+    auto du = getDUFromIDS(query);
+    if (du->responseSender == nullptr) {
       return;
     }
 
-    auto du = std::move(query.du);
     du->ids = std::move(query);
     du->status_code = 502;
     sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response");
@@ -519,20 +576,23 @@ public:
        Leave it for now because we know that the onky case where the payload has been
        added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT,
        and we know the TCP/DoT code can handle it. */
-    query.d_proxyProtocolPayloadAdded = query.d_idstate.du->proxyProtocolPayloadSize > 0;
+    query.d_proxyProtocolPayloadAdded = query.d_idstate.d_proxyProtocolPayloadSize > 0;
     downstream = query.d_idstate.du->downstream;
-    proxyProtocolPayloadSize = query.d_idstate.du->proxyProtocolPayloadSize;
   }
 
   void handleInternalError()
   {
-    query.d_idstate.du->status_code = 502;
-    sendDoHUnitToTheMainThread(std::move(query.d_idstate.du), "DoH internal error");
+    auto du = getDUFromIDS(query.d_idstate);
+    if (!du) {
+      return;
+    }
+    du->status_code = 502;
+    sendDoHUnitToTheMainThread(std::move(du), "DoH internal error");
   }
 
   std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
   {
-    query.d_idstate.du->downstream = downstream;
+    dynamic_cast<DOHUnit*>(query.d_idstate.du.get())->downstream = downstream;
     return s_sender;
   }
 
@@ -550,9 +610,9 @@ public:
     return dr;
    }
 
-  DOHUnitUniquePtr&& releaseDU()
+  DOHUnitUniquePtr releaseDU()
   {
-    return std::move(query.d_idstate.du);
+    return getDUFromIDS(query.d_idstate);
   }
 
 private:
@@ -567,7 +627,7 @@ std::unique_ptr<CrossProtocolQuery> getDoHCrossProtocolQueryFromDQ(DNSQuestion&
     throw std::runtime_error("Trying to create a DoH cross protocol query without a valid DoH unit");
   }
 
-  auto du = std::move(dq.ids.du);
+  auto du = getDUFromIDS(dq.ids);
   if (&dq.ids != &du->ids) {
    du->ids = std::move(dq.ids);
   }
@@ -606,121 +666,116 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
   };
 
   auto& ids = unit->ids;
-  ids.du = std::move(unit);
-  auto& du = ids.du;
   uint16_t queryId = 0;
   ComboAddress remote;
 
   try {
-    if (!du->req) {
+    if (!unit->req) {
       // we got closed meanwhile. XXX small race condition here
       // but we should be fine as long as we don't touch du->req
       // outside of the main DoH thread
-      du->status_code = 500;
-      handleImmediateResponse(std::move(du), "DoH killed in flight");
+      unit->status_code = 500;
+      handleImmediateResponse(std::move(unit), "DoH killed in flight");
       return;
     }
 
-    {
-      // if there was no EDNS, we add it with a large buffer size
-      // so we can use UDP to talk to the backend.
-      auto dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(du->query.data()));
-
-      if (!dh->arcount) {
-        if (generateOptRR(std::string(), du->query, 4096, 4096, 0, false)) {
-          dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(du->query.data())); // may have reallocated
-          dh->arcount = htons(1);
-          du->ids.ednsAdded = true;
-        }
-      }
-      else {
-        // we leave existing EDNS in place
-      }
-    }
-
-    remote = du->ids.origRemote;
-    DOHServerConfig* dsc = du->dsc;
+    remote = ids.origRemote;
+    DOHServerConfig* dsc = unit->dsc;
     auto& holders = dsc->holders;
     ClientState& cs = *dsc->cs;
 
-    if (du->query.size() < sizeof(dnsheader)) {
+    if (unit->query.size() < sizeof(dnsheader)) {
       ++dnsdist::metrics::g_stats.nonCompliantQueries;
       ++cs.nonCompliantQueries;
-      du->status_code = 400;
-      handleImmediateResponse(std::move(du), "DoH non-compliant query");
+      unit->status_code = 400;
+      handleImmediateResponse(std::move(unit), "DoH non-compliant query");
       return;
     }
 
     ++cs.queries;
     ++dnsdist::metrics::g_stats.queries;
-    du->ids.queryRealTime.start();
+    ids.queryRealTime.start();
 
     {
       /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
-      struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(du->query.data());
+      struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(unit->query.data());
 
       if (!checkQueryHeaders(dh, cs)) {
-        du->status_code = 400;
-        handleImmediateResponse(std::move(du), "DoH invalid headers");
+        unit->status_code = 400;
+        handleImmediateResponse(std::move(unit), "DoH invalid headers");
         return;
       }
 
       if (dh->qdcount == 0) {
         dh->rcode = RCode::NotImp;
         dh->qr = true;
-        du->response = std::move(du->query);
+        unit->response = std::move(unit->query);
 
-        handleImmediateResponse(std::move(du), "DoH empty query");
+        handleImmediateResponse(std::move(unit), "DoH empty query");
         return;
       }
 
       queryId = ntohs(dh->id);
     }
 
-    auto downstream = du->downstream;
-    du->ids.qname = DNSName(reinterpret_cast<const char*>(du->query.data()), du->query.size(), sizeof(dnsheader), false, &du->ids.qtype, &du->ids.qclass);
-    DNSQuestion dq(du->ids, du->query);
+    {
+      // if there was no EDNS, we add it with a large buffer size
+      // so we can use UDP to talk to the backend.
+      auto dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(unit->query.data()));
+      if (!dh->arcount) {
+        if (addEDNS(unit->query, 4096, false, 4096, 0)) {
+          ids.ednsAdded = true;
+        }
+      }
+    }
+
+    auto downstream = unit->downstream;
+    ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), unit->query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
+    DNSQuestion dq(ids, unit->query);
     const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
     ids.origFlags = *flags;
-    du->ids.cs = &cs;
-    dq.sni = std::move(du->sni);
-
+    ids.cs = &cs;
+    dq.sni = std::move(unit->sni);
+    ids.du = std::move(unit);
     auto result = processQuery(dq, holders, downstream);
 
     if (result == ProcessQueryResult::Drop) {
-      du->status_code = 403;
-      handleImmediateResponse(std::move(du), "DoH dropped query");
+      unit = getDUFromIDS(ids);
+      unit->status_code = 403;
+      handleImmediateResponse(std::move(unit), "DoH dropped query");
       return;
     }
     else if (result == ProcessQueryResult::Asynchronous) {
       return;
     }
     else if (result == ProcessQueryResult::SendAnswer) {
-      if (du->response.empty()) {
-        du->response = std::move(du->query);
+      unit = getDUFromIDS(ids);
+      if (unit->response.empty()) {
+        unit->response = std::move(unit->query);
       }
-      if (du->response.size() >= sizeof(dnsheader) && du->contentType.empty()) {
-        auto dh = reinterpret_cast<const struct dnsheader*>(du->response.data());
+      if (unit->response.size() >= sizeof(dnsheader) && unit->contentType.empty()) {
+        auto dh = reinterpret_cast<const struct dnsheader*>(unit->response.data());
 
-        handleResponseSent(du->ids.qname, QType(du->ids.qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH, false);
+        handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH, false);
       }
-      handleImmediateResponse(std::move(du), "DoH self-answered response");
+      handleImmediateResponse(std::move(unit), "DoH self-answered response");
       return;
     }
 
+    unit = getDUFromIDS(ids);
     if (result != ProcessQueryResult::PassToBackend) {
-      du->status_code = 500;
-      handleImmediateResponse(std::move(du), "DoH no backend available");
+      unit->status_code = 500;
+      handleImmediateResponse(std::move(unit), "DoH no backend available");
       return;
     }
 
     if (downstream == nullptr) {
-      du->status_code = 502;
-      handleImmediateResponse(std::move(du), "DoH no backend available");
+      unit->status_code = 502;
+      handleImmediateResponse(std::move(unit), "DoH no backend available");
       return;
     }
 
-    du->downstream = downstream;
+    unit->downstream = downstream;
 
     if (downstream->isTCPOnly()) {
       std::string proxyProtocolPayload;
@@ -730,11 +785,11 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
         proxyProtocolPayload = getProxyProtocolPayload(dq);
       }
 
-      du->ids.origID = htons(queryId);
-      du->tcp = true;
+      unit->ids.origID = htons(queryId);
+      unit->tcp = true;
 
       /* this moves du->ids, careful! */
-      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du), false);
+      auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(unit), false);
       cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
 
       if (downstream->passCrossProtocolQuery(std::move(cpq))) {
@@ -742,9 +797,9 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
       }
       else {
         if (inMainThread) {
-          du = cpq->releaseDU();
-          du->status_code = 502;
-          handleImmediateResponse(std::move(du), "DoH internal error");
+          unit = cpq->releaseDU();
+          unit->status_code = 502;
+          handleImmediateResponse(std::move(unit), "DoH internal error");
         }
         else {
           cpq->handleInternalError();
@@ -753,17 +808,19 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
       }
     }
 
-    ComboAddress dest = dq.ids.origDest;
-    if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, du->query, dest)) {
-      du->status_code = 502;
-      handleImmediateResponse(std::move(du), "DoH internal error");
+    auto& query = unit->query;
+    ids.du = std::move(unit);
+    if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, query)) {
+      unit = getDUFromIDS(ids);
+      unit->status_code = 502;
+      handleImmediateResponse(std::move(unit), "DoH internal error");
       return;
     }
   }
   catch (const std::exception& e) {
     vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
-    du->status_code = 500;
-    handleImmediateResponse(std::move(du), "DoH internal error");
+    unit->status_code = 500;
+    handleImmediateResponse(std::move(unit), "DoH internal error");
     return;
   }
 
@@ -838,7 +895,7 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re
     /* we are doing quite some copies here, sorry about that,
        but we can't keep accessing the req object once we are in a different thread
        because the request might get killed by h2o at pretty much any time */
-    auto du = std::make_unique<DOHUnit>(std::move(query), std::move(path), std::string(req->authority.base, req->authority.len));
+    auto du = DOHUnitUniquePtr(new DOHUnit(std::move(query), std::move(path), std::string(req->authority.base, req->authority.len)));
     du->dsc = dsc;
     du->req = req;
     du->ids.origDest = local;
@@ -869,7 +926,7 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re
     *(du->self) = du.get();
 
 #ifdef USE_SINGLE_ACCEPTOR_THREAD
-    processDOHQuery(du, true);
+    processDOHQuery(std::move(du), true);
 #else /* USE_SINGLE_ACCEPTOR_THREAD */
     try {
       if (!dsc->d_querySender.send(std::move(du))) {
@@ -1102,85 +1159,13 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req)
   }
 }
 
-HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& regex)
-  : d_header(toLower(header)), d_regex(regex), d_visual("http[" + header+ "] ~ " + regex)
-{
-}
-
-bool HTTPHeaderRule::matches(const DNSQuestion* dq) const
-{
-  if (!dq->ids.du || !dq->ids.du->headers) {
-    return false;
-  }
-
-  for (const auto& header : *dq->ids.du->headers) {
-    if (header.first == d_header) {
-      return d_regex.match(header.second);
-    }
-  }
-  return false;
-}
-
-string HTTPHeaderRule::toString() const
-{
-  return d_visual;
-}
-
-HTTPPathRule::HTTPPathRule(const std::string& path)
-  :  d_path(path)
-{
-
-}
-
-bool HTTPPathRule::matches(const DNSQuestion* dq) const
-{
-  if (!dq->ids.du) {
-    return false;
-  }
-
-  if (dq->ids.du->query_at == SIZE_MAX) {
-    return dq->ids.du->path == d_path;
-  }
-  else {
-    return d_path.compare(0, d_path.size(), dq->ids.du->path, 0, dq->ids.du->query_at) == 0;
-  }
-}
-
-string HTTPPathRule::toString() const
-{
-  return "url path == " + d_path;
-}
-
-HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex): d_regex(regex), d_visual("http path ~ " + regex)
-{
-}
-
-bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const
-{
-  if (!dq->ids.du) {
-    return false;
-  }
-
-  return d_regex.match(dq->ids.du->getHTTPPath());
-}
-
-string HTTPPathRegexRule::toString() const
-{
-  return d_visual;
-}
-
-std::unordered_map<std::string, std::string> DOHUnit::getHTTPHeaders() const
+const std::unordered_map<std::string, std::string>& DOHUnit::getHTTPHeaders() const
 {
-  std::unordered_map<std::string, std::string> results;
-  if (headers) {
-    results.reserve(headers->size());
-
-    for (const auto& header : *headers) {
-      results.insert(header);
-    }
+  if (!headers) {
+    static const HeadersMap empty{};
+    return empty;
   }
-
-  return results;
+  return *headers;
 }
 
 std::string DOHUnit::getHTTPPath() const
@@ -1193,12 +1178,12 @@ std::string DOHUnit::getHTTPPath() const
   }
 }
 
-std::string DOHUnit::getHTTPHost() const
+const std::string& DOHUnit::getHTTPHost() const
 {
   return host;
 }
 
-std::string DOHUnit::getHTTPScheme() const
+const std::string& DOHUnit::getHTTPScheme() const
 {
   return scheme;
 }
@@ -1280,7 +1265,7 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
      memory and likely coming up too late after the client has gone away */
   auto* dsc = static_cast<DOHServerConfig*>(listener->data);
   while (true) {
-    std::unique_ptr<DOHUnit> du{nullptr};
+    DOHUnitUniquePtr du{nullptr};
     try {
       auto tmp = dsc->d_responseReceiver.receive();
       if (!tmp) {
@@ -1300,10 +1285,10 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
 
     if (!du->tcp &&
         du->truncated &&
-        du->query.size() > du->proxyProtocolPayloadSize &&
-        (du->query.size() - du->proxyProtocolPayloadSize) > sizeof(dnsheader)) {
+        du->query.size() > du->ids.d_proxyProtocolPayloadSize &&
+        (du->query.size() - du->ids.d_proxyProtocolPayloadSize) > sizeof(dnsheader)) {
       /* restoring the original ID */
-      dnsheader* queryDH = reinterpret_cast<struct dnsheader*>(du->query.data() + du->proxyProtocolPayloadSize);
+      dnsheader* queryDH = reinterpret_cast<struct dnsheader*>(du->query.data() + du->ids.d_proxyProtocolPayloadSize);
       queryDH->id = du->ids.origID;
       du->ids.forwardedOverUDP = false;
       du->tcp = true;
@@ -1494,84 +1479,22 @@ static void setupAcceptContext(DOHAcceptContext& ctx, DOHServerConfig& dsc, bool
   auto nativeCtx = ctx.get();
   nativeCtx->ctx = &dsc.h2o_ctx;
   nativeCtx->hosts = dsc.h2o_config.hosts;
-  ctx.d_ticketsKeyRotationDelay = dsc.df->d_tlsConfig.d_ticketsKeyRotationDelay;
+  auto df = std::atomic_load_explicit(&dsc.df, std::memory_order_acquire);
+  ctx.d_ticketsKeyRotationDelay = df->d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay;
 
-  if (setupTLS && dsc.df->isHTTPS()) {
+  if (setupTLS && df->isHTTPS()) {
     try {
       setupTLSContext(ctx,
-                      dsc.df->d_tlsConfig,
-                      dsc.df->d_tlsCounters);
+                      df->d_tlsContext.d_tlsConfig,
+                      df->d_tlsContext.d_tlsCounters);
     }
     catch (const std::runtime_error& e) {
-      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + dsc.df->d_local.toStringWithPort() + "': " + e.what());
+      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + df->d_tlsContext.d_addr.toStringWithPort() + "': " + e.what());
     }
   }
   ctx.d_cs = dsc.cs;
 }
 
-void DOHFrontend::rotateTicketsKey(time_t now)
-{
-  if (d_dsc && d_dsc->accept_ctx) {
-    d_dsc->accept_ctx->rotateTicketsKey(now);
-  }
-}
-
-void DOHFrontend::loadTicketsKeys(const std::string& keyFile)
-{
-  if (d_dsc && d_dsc->accept_ctx) {
-    d_dsc->accept_ctx->loadTicketsKeys(keyFile);
-  }
-}
-
-void DOHFrontend::handleTicketsKeyRotation()
-{
-  if (d_dsc && d_dsc->accept_ctx) {
-    d_dsc->accept_ctx->handleTicketsKeyRotation();
-  }
-}
-
-time_t DOHFrontend::getNextTicketsKeyRotation() const
-{
-  if (d_dsc && d_dsc->accept_ctx) {
-    return d_dsc->accept_ctx->getNextTicketsKeyRotation();
-  }
-  return 0;
-}
-
-size_t DOHFrontend::getTicketsKeysCount() const
-{
-  size_t res = 0;
-  if (d_dsc && d_dsc->accept_ctx) {
-    res = d_dsc->accept_ctx->getTicketsKeysCount();
-  }
-  return res;
-}
-
-void DOHFrontend::reloadCertificates()
-{
-  auto newAcceptContext = std::make_shared<DOHAcceptContext>();
-  setupAcceptContext(*newAcceptContext, *d_dsc, true);
-  std::atomic_store_explicit(&d_dsc->accept_ctx, newAcceptContext, std::memory_order_release);
-}
-
-void DOHFrontend::setup()
-{
-  registerOpenSSLUser();
-
-  d_dsc = std::make_shared<DOHServerConfig>(d_idleTimeout, d_internalPipeBufferSize);
-
-  if  (isHTTPS()) {
-    try {
-      setupTLSContext(*d_dsc->accept_ctx,
-                      d_tlsConfig,
-                      d_tlsCounters);
-    }
-    catch (const std::runtime_error& e) {
-      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_local.toStringWithPort() + "': " + e.what());
-    }
-  }
-}
-
 static h2o_pathconf_t *register_handler(h2o_hostconf_t *hostconf, const char *path, int (*on_req)(h2o_handler_t *, h2o_req_t *))
 {
   h2o_pathconf_t *pathconf = h2o_config_register_path(hostconf, path, 0);
@@ -1598,7 +1521,7 @@ void dohThread(ClientState* cs)
     std::shared_ptr<DOHFrontend>& df = cs->dohFrontend;
     auto& dsc = df->d_dsc;
     dsc->cs = cs;
-    dsc->df = cs->dohFrontend;
+    std::atomic_store_explicit(&dsc->df, cs->dohFrontend, std::memory_order_release);
     dsc->h2o_config.server_name = h2o_iovec_init(df->d_serverTokens.c_str(), df->d_serverTokens.size());
 
 #ifndef USE_SINGLE_ACCEPTOR_THREAD
@@ -1609,11 +1532,11 @@ void dohThread(ClientState* cs)
     setThreadName("dnsdist/doh");
     // I wonder if this registers an IP address.. I think it does
     // this may mean we need to actually register a site "name" here and not the IP address
-    h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(df->d_local.toString().c_str(), df->d_local.toString().size()), 65535);
+    h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(df->d_tlsContext.d_addr.toString().c_str(), df->d_tlsContext.d_addr.toString().size()), 65535);
 
-    for(const auto& url : df->d_urls) {
+    dsc->paths = df->d_urls;
+    for (const auto& url : dsc->paths) {
       register_handler(hostconf, url.c_str(), doh_handler);
-      dsc->paths.insert(url);
     }
 
     h2o_context_init(&dsc->h2o_ctx, h2o_evloop_create(), &dsc->h2o_config);
@@ -1632,11 +1555,11 @@ void dohThread(ClientState* cs)
     setupAcceptContext(*dsc->accept_ctx, *dsc, false);
 
     if (create_listener(dsc, cs->tcpFD) != 0) {
-      throw std::runtime_error("DOH server failed to listen on " + df->d_local.toStringWithPort() + ": " + strerror(errno));
+      throw std::runtime_error("DOH server failed to listen on " + df->d_tlsContext.d_addr.toStringWithPort() + ": " + strerror(errno));
     }
     for (const auto& [addr, fd] : cs->d_additionalAddresses) {
       if (create_listener(dsc, fd) != 0) {
-        throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + df->d_local.toStringWithPort() + ": " + strerror(errno));
+        throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + df->d_tlsContext.d_addr.toStringWithPort() + ": " + strerror(errno));
       }
     }
 
@@ -1661,25 +1584,31 @@ void dohThread(ClientState* cs)
   }
 }
 
-void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, InternalQueryState&& state)
+void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, InternalQueryState&& state, const std::shared_ptr<DownstreamState>&)
 {
-  du->response = std::move(udpResponse);
+  auto du = std::unique_ptr<DOHUnit>(this);
   du->ids = std::move(state);
 
-  const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(du->response.data());
-  if (!dh->tc) {
+  {
+    const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(udpResponse.data());
+    if (dh->tc) {
+      du->truncated = true;
+    }
+  }
+  if (!du->truncated) {
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
     static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
 
-    DNSResponse dr(du->ids, du->response, du->downstream);
+    DNSResponse dr(du->ids, udpResponse, du->downstream);
     dnsheader cleartextDH;
     memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
 
     dr.ids.du = std::move(du);
-    if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) {
+    if (!processResponse(udpResponse, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) {
       if (dr.ids.du) {
-        dr.ids.du->status_code = 503;
-        sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules");
+        du = getDUFromIDS(dr.ids);
+        du->status_code = 503;
+        sendDoHUnitToTheMainThread(std::move(du), "Response dropped by rules");
       }
       return;
     }
@@ -1688,7 +1617,8 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse,
       return;
     }
 
-    du = std::move(dr.ids.du);
+    du = getDUFromIDS(dr.ids);
+    du->response = std::move(udpResponse);
     double udiff = du->ids.queryRealTime.udiff();
     vinfolog("Got answer from %s, relayed to %s (https), took %f us", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff);
 
@@ -1699,17 +1629,72 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse,
       ++du->ids.cs->responses;
     }
   }
-  else {
-    du->truncated = true;
-  }
 
   sendDoHUnitToTheMainThread(std::move(du), "DoH response");
 }
-#endif /* HAVE_LIBH2OEVLOOP */
-#else /* HAVE_DNS_OVER_HTTPS */
 
-void handleDOHTimeout(DOHUnitUniquePtr&& oldDU)
+void H2ODOHFrontend::rotateTicketsKey(time_t now)
+{
+  if (d_dsc && d_dsc->accept_ctx) {
+    d_dsc->accept_ctx->rotateTicketsKey(now);
+  }
+}
+
+void H2ODOHFrontend::loadTicketsKeys(const std::string& keyFile)
+{
+  if (d_dsc && d_dsc->accept_ctx) {
+    d_dsc->accept_ctx->loadTicketsKeys(keyFile);
+  }
+}
+
+void H2ODOHFrontend::handleTicketsKeyRotation()
+{
+  if (d_dsc && d_dsc->accept_ctx) {
+    d_dsc->accept_ctx->handleTicketsKeyRotation();
+  }
+}
+
+std::string H2ODOHFrontend::getNextTicketsKeyRotation() const
+{
+  if (d_dsc && d_dsc->accept_ctx) {
+    return std::to_string(d_dsc->accept_ctx->getNextTicketsKeyRotation());
+  }
+  return 0;
+}
+
+size_t H2ODOHFrontend::getTicketsKeysCount()
+{
+  size_t res = 0;
+  if (d_dsc && d_dsc->accept_ctx) {
+    res = d_dsc->accept_ctx->getTicketsKeysCount();
+  }
+  return res;
+}
+
+void H2ODOHFrontend::reloadCertificates()
+{
+  auto newAcceptContext = std::make_shared<DOHAcceptContext>();
+  setupAcceptContext(*newAcceptContext, *d_dsc, true);
+  std::atomic_store_explicit(&d_dsc->accept_ctx, newAcceptContext, std::memory_order_release);
+}
+
+void H2ODOHFrontend::setup()
 {
+  registerOpenSSLUser();
+
+  d_dsc = std::make_shared<DOHServerConfig>(d_idleTimeout, d_internalPipeBufferSize);
+
+  if  (isHTTPS()) {
+    try {
+      setupTLSContext(*d_dsc->accept_ctx,
+                      d_tlsContext.d_tlsConfig,
+                      d_tlsContext.d_tlsCounters);
+    }
+    catch (const std::runtime_error& e) {
+      throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort() + "': " + e.what());
+    }
+  }
 }
 
-#endif /* HAVE_DNS_OVER_HTTPS */
+#endif /* HAVE_LIBH2OEVLOOP */
+#endif /* HAVE_LIBH2OEVLOOP */
index c7e638b219674524ca5c28a00cd53ec6e609b6b9..9d437578f77fb48fe8e71844841e38ddbf141e27 100644 (file)
@@ -34,39 +34,10 @@ std::vector<std::unique_ptr<ClientState>> g_frontends;
 /* add stub implementations, we don't want to include the corresponding object files
    and their dependencies */
 
-#ifdef HAVE_DNS_OVER_HTTPS
-std::unordered_map<std::string, std::string> DOHUnit::getHTTPHeaders() const
-{
-  return {};
-}
-
-std::string DOHUnit::getHTTPPath() const
-{
-  return "";
-}
-
-std::string DOHUnit::getHTTPHost() const
-{
-  return "";
-}
-
-std::string DOHUnit::getHTTPScheme() const
-{
-  return "";
-}
-
-std::string DOHUnit::getHTTPQueryString() const
-{
-  return "";
-}
-
-void DOHUnit::setHTTPResponse(uint16_t statusCode, PacketBuffer&& body_, const std::string& contentType_)
-{
-}
-#endif /* HAVE_DNS_OVER_HTTPS */
-
-void handleDOHTimeout(DOHUnitUniquePtr&& oldDU)
+// NOLINTNEXTLINE(readability-convert-member-functions-to-static): this is a stub, the real one is not that simple..
+bool TLSFrontend::setupTLS()
 {
+  return true;
 }
 
 std::string DNSQuestion::getTrailingData() const
index 6f3816c300fa8f131eed7da0b86aed551e081df5..58a26f16918f6e33da6ecd2ed728800c25e324c3 100644 (file)
  */
 #pragma once
 
-#pragma once
-
-#include <unordered_map>
-
-#include "channel.hh"
-#include "iputils.hh"
-#include "libssl.hh"
-#include "noinitvector.hh"
-#include "stat_t.hh"
-
-struct DOHServerConfig;
-
-class DOHResponseMapEntry
-{
-public:
-  DOHResponseMapEntry(const std::string& regex, uint16_t status, const PacketBuffer& content, const boost::optional<std::unordered_map<std::string, std::string>>& headers): d_regex(regex), d_customHeaders(headers), d_content(content), d_status(status)
-  {
-    if (status >= 400 && !d_content.empty() && d_content.at(d_content.size() -1) != 0) {
-      // we need to make sure it's null-terminated
-      d_content.push_back(0);
-    }
-  }
-
-  bool matches(const std::string& path) const
-  {
-    return d_regex.match(path);
-  }
-
-  uint16_t getStatusCode() const
-  {
-    return d_status;
-  }
-
-  const PacketBuffer& getContent() const
-  {
-    return d_content;
-  }
-
-  const boost::optional<std::unordered_map<std::string, std::string>>& getHeaders() const
-  {
-    return d_customHeaders;
-  }
-
-private:
-  Regex d_regex;
-  boost::optional<std::unordered_map<std::string, std::string>> d_customHeaders;
-  PacketBuffer d_content;
-  uint16_t d_status;
-};
-
-struct DOHFrontend
-{
-  DOHFrontend()
-  {
-  }
-
-  std::shared_ptr<DOHServerConfig> d_dsc{nullptr};
-  std::shared_ptr<std::vector<std::shared_ptr<DOHResponseMapEntry>>> d_responsesMap;
-  TLSConfig d_tlsConfig;
-  TLSErrorCounters d_tlsCounters;
-  std::string d_serverTokens{"h2o/dnsdist"};
-  std::unordered_map<std::string, std::string> d_customResponseHeaders;
-  ComboAddress d_local;
-
-  uint32_t d_idleTimeout{30};             // HTTP idle timeout in seconds
-  std::vector<std::string> d_urls;
-
-  pdns::stat_t d_httpconnects{0};   // number of TCP/IP connections established
-  pdns::stat_t d_getqueries{0};     // valid DNS queries received via GET
-  pdns::stat_t d_postqueries{0};    // valid DNS queries received via POST
-  pdns::stat_t d_badrequests{0};     // request could not be converted to dns query
-  pdns::stat_t d_errorresponses{0}; // dnsdist set 'error' on response
-  pdns::stat_t d_redirectresponses{0}; // dnsdist set 'redirect' on response
-  pdns::stat_t d_validresponses{0}; // valid responses sent out
-
-  struct HTTPVersionStats
-  {
-    pdns::stat_t d_nbQueries{0}; // valid DNS queries received
-    pdns::stat_t d_nb200Responses{0};
-    pdns::stat_t d_nb400Responses{0};
-    pdns::stat_t d_nb403Responses{0};
-    pdns::stat_t d_nb500Responses{0};
-    pdns::stat_t d_nb502Responses{0};
-    pdns::stat_t d_nbOtherResponses{0};
-  };
-
-  HTTPVersionStats d_http1Stats;
-  HTTPVersionStats d_http2Stats;
-#ifdef __linux__
-  // On Linux this gives us 128k pending queries (default is 8192 queries),
-  // which should be enough to deal with huge spikes
-  uint32_t d_internalPipeBufferSize{1024*1024};
-#else
-  uint32_t d_internalPipeBufferSize{0};
-#endif
-  bool d_sendCacheControlHeaders{true};
-  bool d_trustForwardedForHeader{false};
-  /* whether we require tue query path to exactly match one of configured ones,
-     or accept everything below these paths. */
-  bool d_exactPathMatching{true};
-  bool d_keepIncomingHeaders{false};
-
-  time_t getTicketsKeyRotationDelay() const
-  {
-    return d_tlsConfig.d_ticketsKeyRotationDelay;
-  }
-
-  bool isHTTPS() const
-  {
-    return !d_tlsConfig.d_certKeyPairs.empty();
-  }
-
-#ifndef HAVE_DNS_OVER_HTTPS
-  void setup()
-  {
-  }
-
-  void reloadCertificates()
-  {
-  }
+#include "config.h"
 
-  void rotateTicketsKey(time_t /* now */)
-  {
-  }
-
-  void loadTicketsKeys(const std::string& /* keyFile */)
-  {
-  }
-
-  void handleTicketsKeyRotation()
-  {
-  }
-
-  time_t getNextTicketsKeyRotation() const
-  {
-    return 0;
-  }
-
-  size_t getTicketsKeysCount() const
-  {
-    size_t res = 0;
-    return res;
-  }
-
-#else
-  void setup();
-  void reloadCertificates();
-
-  void rotateTicketsKey(time_t now);
-  void loadTicketsKeys(const std::string& keyFile);
-  void handleTicketsKeyRotation();
-  time_t getNextTicketsKeyRotation() const;
-  size_t getTicketsKeysCount() const;
-#endif /* HAVE_DNS_OVER_HTTPS */
-};
+#ifdef HAVE_DNS_OVER_HTTPS
+#ifdef HAVE_LIBH2OEVLOOP
 
-#ifndef HAVE_DNS_OVER_HTTPS
-struct DOHUnit
-{
-  size_t proxyProtocolPayloadSize{0};
-  uint16_t status_code{200};
-};
+#include <ctime>
+#include <memory>
+#include <string>
 
-#else /* HAVE_DNS_OVER_HTTPS */
-#ifdef HAVE_LIBH2OEVLOOP
-#include <unordered_map>
+struct CrossProtocolQuery;
+struct DNSQuestion;
 
-#include "dnsdist-idstate.hh"
+std::unique_ptr<CrossProtocolQuery> getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse);
 
-struct st_h2o_req_t;
-struct DownstreamState;
+#include "dnsdist-doh-common.hh"
 
-struct DOHUnit
+struct H2ODOHFrontend : public DOHFrontend
 {
-  DOHUnit(PacketBuffer&& q, std::string&& p, std::string&& h): path(std::move(p)), host(std::move(h)), query(std::move(q))
-  {
-    ids.ednsAdded = false;
-  }
-
-  DOHUnit(const DOHUnit&) = delete;
-  DOHUnit& operator=(const DOHUnit&) = delete;
+public:
 
-  InternalQueryState ids;
-  std::string sni;
-  std::string path;
-  std::string scheme;
-  std::string host;
-  std::string contentType;
-  PacketBuffer query;
-  PacketBuffer response;
-  std::shared_ptr<DownstreamState> downstream{nullptr};
-  std::unique_ptr<std::unordered_map<std::string, std::string>> headers;
-  st_h2o_req_t* req{nullptr};
-  DOHUnit** self{nullptr};
-  DOHServerConfig* dsc{nullptr};
-  pdns::channel::Sender<DOHUnit>* responseSender{nullptr};
-  size_t query_at{0};
-  size_t proxyProtocolPayloadSize{0};
-  int rsock{-1};
-  /* the status_code is set from
-     processDOHQuery() (which is executed in
-     the DOH client thread) so that the correct
-     response can be sent in on_dnsdist(),
-     after the DOHUnit has been passed back to
-     the main DoH thread.
-  */
-  uint16_t status_code{200};
-  /* whether the query was re-sent to the backend over
-     TCP after receiving a truncated answer over UDP */
-  bool tcp{false};
-  bool truncated{false};
+  void setup() override;
+  void reloadCertificates() override;
 
-  std::string getHTTPPath() const;
-  std::string getHTTPHost() const;
-  std::string getHTTPScheme() const;
-  std::string getHTTPQueryString() const;
-  std::unordered_map<std::string, std::string> getHTTPHeaders() const;
-  void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType="");
+  void rotateTicketsKey(time_t now) override;
+  void loadTicketsKeys(const std::string& keyFile) override;
+  void handleTicketsKeyRotation() override;
+  std::string getNextTicketsKeyRotation() const override;
+  size_t getTicketsKeysCount() override;
 };
 
-void handleUDPResponseForDoH(std::unique_ptr<DOHUnit>&&, PacketBuffer&& response, InternalQueryState&& state);
-
-struct CrossProtocolQuery;
-struct DNSQuestion;
-
-std::unique_ptr<CrossProtocolQuery> getDoHCrossProtocolQueryFromDQ(DNSQuestion& dq, bool isResponse);
+void dohThread(ClientState* clientState);
 
 #endif /* HAVE_LIBH2OEVLOOP */
 #endif /* HAVE_DNS_OVER_HTTPS  */
-
-using DOHUnitUniquePtr = std::unique_ptr<DOHUnit>;
-
-void handleDOHTimeout(DOHUnitUniquePtr&& oldDU);
index c4fe42b8aafe46586aa0b81b34c89d167cf88287..c51a930c04f8566c09a959e927374133f4a11a5a 100644 (file)
@@ -54,9 +54,9 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs
   return false;
 }
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
 {
-  return false;
+  return true;
 }
 
 namespace dnsdist {