git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Implement incoming DoH support via nghttp2
author Remi Gacogne <remi.gacogne@powerdns.com>
Mon, 31 Jul 2023 15:07:05 +0000 (17:07 +0200)
committer Remi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Sep 2023 07:19:16 +0000 (09:19 +0200)
31 files changed:
pdns/dnsdist-doh-common.hh
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/configure.ac
pdns/dnsdistdist/dnsdist-async.cc
pdns/dnsdistdist/dnsdist-healthchecks.cc
pdns/dnsdistdist/dnsdist-internal-queries.cc
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist-nghttp2-in.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-nghttp2-in.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-nghttp2.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/dnsdist-tcp.hh
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/m4/dnsdist_enable_doh.m4
pdns/dnsdistdist/m4/pdns_with_nghttp2.m4
pdns/dnsdistdist/test-dnsdistasync.cc
pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc
pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/doh.hh
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh
pdns/test-dnsdist_cc.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_DOH.py
regression-tests.dnsdist/test_Protobuf.py

index 44ad826a88d8b693a3e09b444adfa954e868beb9..41166de9f3023bc2908d96f7bdae260b7576be96 100644 (file)
@@ -77,10 +77,6 @@ struct DOHFrontend
   DOHFrontend()
   {
   }
-  DOHFrontend(std::shared_ptr<TLSCtx> tlsCtx) :
-    d_tlsContext(std::move(tlsCtx))
-  {
-  }
 
   virtual ~DOHFrontend()
   {
@@ -126,6 +122,7 @@ struct DOHFrontend
 #endif
   bool d_sendCacheControlHeaders{true};
   bool d_trustForwardedForHeader{false};
+  bool d_earlyACLDrop{true};
   /* whether we require the query path to exactly match one of configured ones,
      or accept everything below these paths. */
   bool d_exactPathMatching{true};
index a29c17c3ade06f31bd6ba907ee3d8ded6ca94acb..f71a9bbf4d97863bbd574da3bd66805ec10f48d3 100644 (file)
@@ -284,7 +284,7 @@ public:
 
     struct timeval now;
     gettimeofday(&now, nullptr);
-    sender->notifyIOError(std::move(object->query.d_idstate), now);
+    sender->notifyIOError(now, TCPResponse(std::move(object->query)));
     return true;
   }
 
index c829c2e1b5c1e35d57dc43eea9b92bb4481d7ba0..fedd9c5d868a1130dbfdaaf7870be956a56c24a1 100644 (file)
@@ -2337,14 +2337,34 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     setLuaSideEffect();
 
     auto frontend = std::make_shared<DOHFrontend>();
+    if (getOptionalValue<std::string>(vars, "library", frontend->d_library) == 0) {
+#ifdef HAVE_NGHTTP2
+      frontend->d_library = "nghttp2";
+#else /* HAVE_NGHTTP2 */
+        frontend->d_library = "h2o";
+#endif /* HAVE_NGHTTP2 */
+    }
+    if (frontend->d_library == "h2o") {
 #ifdef HAVE_LIBH2OEVLOOP
-    frontend = std::make_shared<H2ODOHFrontend>();
-    frontend->d_library = "h2o";
+      frontend = std::make_shared<H2ODOHFrontend>();
+      frontend->d_library = "h2o";
 #else /* HAVE_LIBH2OEVLOOP */
-    errlog("DOH bind %s is configured to use libh2o but the library is not available", addr);
-    return;
+        errlog("DOH bind %s is configured to use libh2o but the library is not available", addr);
+        return;
 #endif /* HAVE_LIBH2OEVLOOP */
+    }
+    else if (frontend->d_library == "nghttp2") {
+#ifndef HAVE_NGHTTP2
+      errlog("DOH bind %s is configured to use nghttp2 but the library is not available", addr);
+      return;
+#endif /* HAVE_NGHTTP2 */
+    }
+    else {
+      errlog("DOH bind %s is configured to use an unknown library ('%s')", addr, frontend->d_library);
+      return;
+    }
 
+    bool useTLS = true;
     if (certFiles && !certFiles->empty()) {
       if (!loadTLSCertificateAndKeys("addDOHLocal", frontend->d_tlsContext.d_tlsConfig.d_certKeyPairs, *certFiles, *keyFiles)) {
         return;
@@ -2355,6 +2375,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     else {
       frontend->d_tlsContext.d_addr = ComboAddress(addr, 80);
       infolog("No certificate provided for DoH endpoint %s, running in DNS over HTTP mode instead of DNS over HTTPS", frontend->d_tlsContext.d_addr.toStringWithPort());
+      useTLS = false;
     }
 
     if (urls) {
@@ -2385,6 +2406,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       parseLocalBindVars(vars, reusePort, tcpFastOpenQueueSize, interface, cpus, tcpListenQueueSize, maxInFlightQueriesPerConn, tcpMaxConcurrentConnections);
       getOptionalValue<int>(vars, "idleTimeout", frontend->d_idleTimeout);
       getOptionalValue<std::string>(vars, "serverTokens", frontend->d_serverTokens);
+      getOptionalValue<std::string>(vars, "provider", frontend->d_tlsContext.d_provider);
+      boost::algorithm::to_lower(frontend->d_tlsContext.d_provider);
 
       LuaAssociativeTable<std::string> customResponseHeaders;
       if (getOptionalValue<decltype(customResponseHeaders)>(vars, "customResponseHeaders", customResponseHeaders) > 0) {
@@ -2397,6 +2420,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       getOptionalValue<bool>(vars, "sendCacheControlHeaders", frontend->d_sendCacheControlHeaders);
       getOptionalValue<bool>(vars, "keepIncomingHeaders", frontend->d_keepIncomingHeaders);
       getOptionalValue<bool>(vars, "trustForwardedForHeader", frontend->d_trustForwardedForHeader);
+      getOptionalValue<bool>(vars, "earlyACLDrop", frontend->d_earlyACLDrop);
       getOptionalValue<int>(vars, "internalPipeBufferSize", frontend->d_internalPipeBufferSize);
       getOptionalValue<bool>(vars, "exactPathMatching", frontend->d_exactPathMatching);
 
@@ -2432,6 +2456,21 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 
       checkAllParametersConsumed("addDOHLocal", vars);
     }
+
+    if (useTLS && frontend->d_library == "nghttp2") {
+      if (!frontend->d_tlsContext.d_provider.empty()) {
+        vinfolog("Loading TLS provider '%s'", frontend->d_tlsContext.d_provider);
+      }
+      else {
+#ifdef HAVE_LIBSSL
+        const std::string provider("openssl");
+#else
+          const std::string provider("gnutls");
+#endif
+        vinfolog("Loading default TLS provider '%s'", provider);
+      }
+    }
+
     g_dohlocals.push_back(frontend);
     auto cs = std::make_unique<ClientState>(frontend->d_tlsContext.d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus);
     cs->dohFrontend = frontend;
@@ -2648,10 +2687,11 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       }
       else {
 #ifdef HAVE_LIBSSL
-        vinfolog("Loading default TLS provider 'openssl'");
+        const std::string provider("openssl");
 #else
-          vinfolog("Loading default TLS provider 'gnutls'");
+          const std::string provider("gnutls");
 #endif
+        vinfolog("Loading default TLS provider '%s'", provider);
       }
       // only works pre-startup, so no sync necessary
       auto cs = std::make_unique<ClientState>(frontend->d_addr, true, reusePort, tcpFastOpenQueueSize, interface, cpus);
index 14af2564e32c08e2ecd58f20ba0e03b52d8c199c..751ba986722a4075d1ebbc3aa75db7443c5ebae1 100644 (file)
@@ -27,6 +27,7 @@
 #include "dnsdist.hh"
 #include "dnsdist-concurrent-connections.hh"
 #include "dnsdist-ecs.hh"
+#include "dnsdist-nghttp2-in.hh"
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-rings.hh"
 #include "dnsdist-tcp.hh"
@@ -96,6 +97,17 @@ IncomingTCPConnectionState::~IncomingTCPConnectionState()
   d_handler.close();
 }
 
+dnsdist::Protocol IncomingTCPConnectionState::getProtocol() const
+{
+  if (d_ci.cs->dohFrontend) {
+    return dnsdist::Protocol::DoH;
+  }
+  if (d_handler.isTLS()) {
+    return dnsdist::Protocol::DoT;
+  }
+  return dnsdist::Protocol::DoTCP;
+}
+
 size_t IncomingTCPConnectionState::clearAllDownstreamConnections()
 {
   return t_downstreamTCPConnectionsManager.clear();
@@ -173,7 +185,7 @@ static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>&
     TCPResponse resp = std::move(state->d_queuedResponses.front());
     state->d_queuedResponses.pop_front();
     state->d_state = IncomingTCPConnectionState::State::idle;
-    result = state->sendResponse(state, now, std::move(resp));
+    result = state->sendResponse(now, std::move(resp));
     if (result != IOState::Done) {
       return result;
     }
@@ -183,28 +195,28 @@ static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>&
   return IOState::Done;
 }
 
-static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, TCPResponse& currentResponse)
+void IncomingTCPConnectionState::handleResponseSent(TCPResponse& currentResponse)
 {
   if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) {
     return;
   }
 
-  --state->d_currentQueriesCount;
+  --d_currentQueriesCount;
 
   const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds;
   if (currentResponse.d_idstate.selfGenerated == false && ds) {
     const auto& ids = currentResponse.d_idstate;
     double udiff = ids.queryRealTime.udiff();
-    vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), currentResponse.d_buffer.size(), udiff);
+    vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), currentResponse.d_buffer.size(), udiff);
 
     auto backendProtocol = ds->getProtocol();
-    if (backendProtocol == dnsdist::Protocol::DoUDP) {
+    if (backendProtocol == dnsdist::Protocol::DoUDP && !currentResponse.d_idstate.forwardedOverUDP) {
       backendProtocol = dnsdist::Protocol::DoTCP;
     }
-    ::handleResponseSent(ids, udiff, state->d_ci.remote, ds->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true);
+    ::handleResponseSent(ids, udiff, d_ci.remote, ds->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true);
   } else {
     const auto& ids = currentResponse.d_idstate;
-    ::handleResponseSent(ids, 0., state->d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false);
+    ::handleResponseSent(ids, 0., d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false);
   }
 
   currentResponse.d_buffer.clear();
@@ -232,7 +244,8 @@ bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now)
     return false;
   }
 
-  if (d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
+  // for DoH, this is already handled by the underlying library
+  if (!d_ci.cs->dohFrontend && d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) {
     DEBUGLOG("not accepting new queries because we already have "<<d_currentQueriesCount<<" out of "<<d_ci.cs->d_maxInFlightQueriesPerConn);
     return false;
   }
@@ -284,9 +297,9 @@ void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_p
 }
 
 /* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */
-IOState IncomingTCPConnectionState::sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
+IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response)
 {
-  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
+  d_state = IncomingTCPConnectionState::State::sendingResponse;
 
   uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size());
   const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) };
@@ -294,27 +307,27 @@ IOState IncomingTCPConnectionState::sendResponse(std::shared_ptr<IncomingTCPConn
      that could occur if we had to deal with the size during the processing,
      especially alignment issues */
   response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2);
-  state->d_currentPos = 0;
-  state->d_currentResponse = std::move(response);
+  d_currentPos = 0;
+  d_currentResponse = std::move(response);
 
   try {
-    auto iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
+    auto iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
     if (iostate == IOState::Done) {
       DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
-      handleResponseSent(state, state->d_currentResponse);
+      handleResponseSent(d_currentResponse);
       return iostate;
     } else {
-      state->d_lastIOBlocked = true;
+      d_lastIOBlocked = true;
       DEBUGLOG("partial write");
       return iostate;
     }
   }
   catch (const std::exception& e) {
-    vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
+    vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what());
     DEBUGLOG("Closing TCP client connection: "<<e.what());
-    ++state->d_ci.cs->tcpDiedSendingResponse;
+    ++d_ci.cs->tcpDiedSendingResponse;
 
-    state->terminateClientConnection();
+    terminateClientConnection();
 
     return IOState::Done;
   }
@@ -408,9 +421,7 @@ void IncomingTCPConnectionState::handleAsyncReady(int fd, FDMultiplexer::funcpar
 
   if (state->active()) {
     /* and now we restart our own I/O state machine */
-    struct timeval now;
-    gettimeofday(&now, nullptr);
-    handleIO(state, now);
+    state->handleIO();
   }
   else {
     /* we were only waiting for the engine to come back,
@@ -476,16 +487,17 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe
     try {
       auto& ids = response.d_idstate;
       unsigned int qnameWireLength;
-      if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) {
+      std::shared_ptr<DownstreamState> ds = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
+      if (!ds || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, ds, qnameWireLength)) {
         state->terminateClientConnection();
         return;
       }
 
-      if (response.d_connection->getDS()) {
-        ++response.d_connection->getDS()->responses;
+      if (ds) {
+        ++ds->responses;
       }
 
-      DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS());
+      DNSResponse dr(ids, response.d_buffer, ds);
       dr.d_incomingTCPState = state;
 
       memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH));
@@ -529,7 +541,6 @@ class TCPCrossProtocolQuery : public CrossProtocolQuery
 public:
   TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
   {
-    proxyProtocolPayloadSize = 0;
   }
 
   ~TCPCrossProtocolQuery()
@@ -561,6 +572,11 @@ private:
   std::shared_ptr<IncomingTCPConnectionState> d_sender;
 };
 
+std::unique_ptr<CrossProtocolQuery> IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& ds)
+{
+  return std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(state), ds, shared_from_this());
+}
+
 std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq)
 {
   auto state = dq.getIncomingTCPState();
@@ -587,60 +603,63 @@ void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeva
   }
 }
 
-static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
+IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::handleQuery(PacketBuffer&& queryIn, const struct timeval& now, std::optional<int32_t> streamID)
 {
-  if (state->d_querySize < sizeof(dnsheader)) {
+  auto query = std::move(queryIn);
+  if (query.size() < sizeof(dnsheader)) {
     ++dnsdist::metrics::g_stats.nonCompliantQueries;
-    ++state->d_ci.cs->nonCompliantQueries;
-    state->terminateClientConnection();
-    return;
+    ++d_ci.cs->nonCompliantQueries;
+    return QueryProcessingResult::TooSmall;
   }
 
-  ++state->d_queriesCount;
-  ++state->d_ci.cs->queries;
+  ++d_queriesCount;
+  ++d_ci.cs->queries;
   ++dnsdist::metrics::g_stats.queries;
 
-  if (state->d_handler.isTLS()) {
-    auto tlsVersion = state->d_handler.getTLSVersion();
+  if (d_handler.isTLS()) {
+    auto tlsVersion = d_handler.getTLSVersion();
     switch (tlsVersion) {
     case LibsslTLSVersion::TLS10:
-      ++state->d_ci.cs->tls10queries;
+      ++d_ci.cs->tls10queries;
       break;
     case LibsslTLSVersion::TLS11:
-      ++state->d_ci.cs->tls11queries;
+      ++d_ci.cs->tls11queries;
       break;
     case LibsslTLSVersion::TLS12:
-      ++state->d_ci.cs->tls12queries;
+      ++d_ci.cs->tls12queries;
       break;
     case LibsslTLSVersion::TLS13:
-      ++state->d_ci.cs->tls13queries;
+      ++d_ci.cs->tls13queries;
       break;
     default:
-      ++state->d_ci.cs->tlsUnknownqueries;
+      ++d_ci.cs->tlsUnknownqueries;
     }
   }
 
+  auto state = shared_from_this();
   InternalQueryState ids;
-  ids.origDest = state->d_proxiedDestination;
-  ids.origRemote = state->d_proxiedRemote;
-  ids.cs = state->d_ci.cs;
+  ids.origDest = d_proxiedDestination;
+  ids.origRemote = d_proxiedRemote;
+  ids.cs = d_ci.cs;
   ids.queryRealTime.start();
+  if (streamID) {
+    ids.d_streamID = *streamID;
+  }
 
-  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
+  auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
   if (dnsCryptResponse) {
     TCPResponse response;
-    state->d_state = IncomingTCPConnectionState::State::idle;
-    ++state->d_currentQueriesCount;
-    state->queueResponse(state, now, std::move(response));
-    return;
+    d_state = IncomingTCPConnectionState::State::idle;
+    ++d_currentQueriesCount;
+    queueResponse(state, now, std::move(response));
+    return QueryProcessingResult::SelfAnswered;
   }
 
   {
     /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
-    auto* dh = reinterpret_cast<dnsheader*>(state->d_buffer.data());
-    if (!checkQueryHeaders(dh, *state->d_ci.cs)) {
-      state->terminateClientConnection();
-      return;
+    auto* dh = reinterpret_cast<dnsheader*>(query.data());
+    if (!checkQueryHeaders(dh, *d_ci.cs)) {
+      return QueryProcessingResult::InvalidHeaders;
     }
 
     if (dh->qdcount == 0) {
@@ -648,81 +667,105 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
       dh->rcode = RCode::NotImp;
       dh->qr = true;
       response.d_idstate.selfGenerated = true;
-      response.d_buffer = std::move(state->d_buffer);
-      state->d_state = IncomingTCPConnectionState::State::idle;
-      ++state->d_currentQueriesCount;
-      state->queueResponse(state, now, std::move(response));
-      return;
+      response.d_buffer = std::move(query);
+      d_state = IncomingTCPConnectionState::State::idle;
+      ++d_currentQueriesCount;
+      queueResponse(state, now, std::move(response));
+      return QueryProcessingResult::Empty;
     }
   }
 
-  ids.qname = DNSName(reinterpret_cast<const char*>(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
-  ids.protocol = dnsdist::Protocol::DoTCP;
+  ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
+  ids.protocol = getProtocol();
   if (ids.dnsCryptQuery) {
     ids.protocol = dnsdist::Protocol::DNSCryptTCP;
   }
-  else if (state->d_handler.isTLS()) {
-    ids.protocol = dnsdist::Protocol::DoT;
-  }
 
-  DNSQuestion dq(ids, state->d_buffer);
+  DNSQuestion dq(ids, query);
   const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
   ids.origFlags = *flags;
   dq.d_incomingTCPState = state;
-  dq.sni = state->d_handler.getServerNameIndication();
+  dq.sni = d_handler.getServerNameIndication();
 
-  if (state->d_proxyProtocolValues) {
+  if (d_proxyProtocolValues) {
     /* we need to copy them, because the next queries received on that connection will
        need to get the _unaltered_ values */
-    dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*state->d_proxyProtocolValues);
+    dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*d_proxyProtocolValues);
   }
 
   if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) {
     dq.ids.skipCache = true;
   }
 
-  std::shared_ptr<DownstreamState> ds;
-  auto result = processQuery(dq, state->d_threadData.holders, ds);
+  if (forwardViaUDPFirst()) {
+    // if there was no EDNS, we add it with a large buffer size
+    // so we can use UDP to talk to the backend.
+    auto dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(query.data()));
+    if (!dh->arcount) {
+      if (addEDNS(query, 4096, false, 4096, 0)) {
+        dq.ids.ednsAdded = true;
+      }
+    }
+  }
 
-  if (result == ProcessQueryResult::Drop) {
-    state->terminateClientConnection();
-    return;
+  if (streamID) {
+    auto unit = getDOHUnit(*streamID);
+    dq.ids.du = std::move(unit);
   }
-  else if (result == ProcessQueryResult::Asynchronous) {
+
+  std::shared_ptr<DownstreamState> ds;
+  auto result = processQuery(dq, d_threadData.holders, ds);
+
+  if (result == ProcessQueryResult::Asynchronous) {
     /* we are done for now */
-    ++state->d_currentQueriesCount;
-    return;
+    ++d_currentQueriesCount;
+    return QueryProcessingResult::Asynchronous;
+  }
+
+  if (streamID) {
+    restoreDOHUnit(std::move(dq.ids.du));
+  }
+
+  if (result == ProcessQueryResult::Drop) {
+    return QueryProcessingResult::Dropped;
   }
 
   // the buffer might have been invalidated by now
-  const dnsheader* dh = dq.getHeader();
+  uint16_t queryID;
+  {
+    const dnsheader* dh = dq.getHeader();
+    queryID = dh->id;
+  }
+
   if (result == ProcessQueryResult::SendAnswer) {
     TCPResponse response;
-    memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH));
+    {
+      const dnsheader* dh = dq.getHeader();
+      memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH));
+    }
     response.d_idstate = std::move(ids);
-    response.d_idstate.origID = dh->id;
+    response.d_idstate.origID = queryID;
     response.d_idstate.selfGenerated = true;
-    response.d_idstate.cs = state->d_ci.cs;
-    response.d_buffer = std::move(state->d_buffer);
+    response.d_idstate.cs = d_ci.cs;
+    response.d_buffer = std::move(query);
 
-    state->d_state = IncomingTCPConnectionState::State::idle;
-    ++state->d_currentQueriesCount;
-    state->queueResponse(state, now, std::move(response));
-    return;
+    d_state = IncomingTCPConnectionState::State::idle;
+    ++d_currentQueriesCount;
+    queueResponse(state, now, std::move(response));
+    return QueryProcessingResult::SelfAnswered;
   }
 
   if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
-    state->terminateClientConnection();
-    return;
+    return QueryProcessingResult::NoBackend;
   }
 
-  dq.ids.origID = dh->id;
+  dq.ids.origID = queryID;
 
-  ++state->d_currentQueriesCount;
+  ++d_currentQueriesCount;
 
   std::string proxyProtocolPayload;
   if (ds->isDoH()) {
-    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), state->d_buffer.size(), ds->getNameWithAddr());
+    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), ds->getNameWithAddr());
 
     /* we need to do this _before_ creating the cross protocol query because
        after that the buffer will have been moved */
@@ -730,21 +773,30 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
       proxyProtocolPayload = getProxyProtocolPayload(dq);
     }
 
-    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state);
+    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(ids), ds, state);
     cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
 
     ds->passCrossProtocolQuery(std::move(cpq));
-    return;
+    return QueryProcessingResult::Forwarded;
+  }
+  else if (!ds->isTCPOnly() && forwardViaUDPFirst()) {
+    auto unit = getDOHUnit(*streamID);
+    dq.ids.du = std::move(unit);
+    if (assignOutgoingUDPQueryToBackend(ds, queryID, dq, query)) {
+      return QueryProcessingResult::Forwarded;
+    }
+    restoreDOHUnit(std::move(dq.ids.du));
+    // fallback to the normal flow
   }
 
-  prependSizeToTCPQuery(state->d_buffer, 0);
+  prependSizeToTCPQuery(query, 0);
 
-  auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
+  auto downstreamConnection = getDownstreamConnection(ds, dq.proxyProtocolValues, now);
 
   if (ds->d_config.useProxyProtocol) {
     /* if we ever sent a TLV over a connection, we can never go back */
-    if (!state->d_proxyProtocolPayloadHasTLV) {
-      state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
+    if (!d_proxyProtocolPayloadHasTLV) {
+      d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
     }
 
     proxyProtocolPayload = getProxyProtocolPayload(dq);
@@ -754,12 +806,13 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
     downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues));
   }
 
-  TCPQuery query(std::move(state->d_buffer), std::move(ids));
-  query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
+  TCPQuery tcpquery(std::move(query), std::move(ids));
+  tcpquery.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
 
-  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getNameWithAddr());
+  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), ds->getNameWithAddr());
   std::shared_ptr<TCPQuerySender> incoming = state;
-  downstreamConnection->queueQuery(incoming, std::move(query));
+  downstreamConnection->queueQuery(incoming, std::move(tcpquery));
+  return QueryProcessingResult::Forwarded;
 }
 
 void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
@@ -769,159 +822,194 @@ void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcpar
     throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor()));
   }
 
-  struct timeval now;
-  gettimeofday(&now, nullptr);
-  handleIO(conn, now);
+  conn->handleIO();
 }
 
-void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
+/* Bookkeeping done by handleIO() once tryHandshake() has reported completion:
+   for a TLS connection, bump the TLS session counters reachable via d_ci.cs,
+   then remember 'now' as the time at which the handshake finished. */
+void IncomingTCPConnectionState::handleHandshakeDone(const struct timeval& now)
+{
+  if (d_handler.isTLS()) {
+    /* a session is either resumed or brand new, never both */
+    if (d_handler.hasTLSSessionBeenResumed()) {
+      ++d_ci.cs->tlsResumptions;
+    }
+    else {
+      ++d_ci.cs->tlsNewSessions;
+    }
+
+    /* the two ticket key counters are tracked independently of the above */
+    if (d_handler.getResumedFromInactiveTicketKey()) {
+      ++d_ci.cs->tlsInactiveTicketKey;
+    }
+
+    if (d_handler.getUnknownTicketKey()) {
+      ++d_ci.cs->tlsUnknownTicketKey;
+    }
+  }
+
+  /* recorded for every connection, whether it uses TLS or not */
+  d_handshakeDoneTime = now;
+}
+
+/* Reads the proxy protocol header from the client and hands it to handleProxyProtocol().
+   Returns:
+   - ProxyProtocolResult::Done once handleProxyProtocol() has accepted the header; any
+     values it produced are stored in d_proxyProtocolValues
+   - ProxyProtocolResult::Error if isProxyHeaderComplete() returns 0 or if
+     handleProxyProtocol() fails
+   - ProxyProtocolResult::Reading if the loop ends without a complete header, either
+     because the read did not complete (d_lastIOBlocked is then set) or because the
+     connection is no longer active. */
+IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::handleProxyProtocolPayload()
+{
+  do {
+    DEBUGLOG("reading proxy protocol header");
+    const auto readResult = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed);
+    if (readResult != IOState::Done) {
+      /* the read did not complete, the loop condition below will stop us */
+      d_lastIOBlocked = true;
+      continue;
+    }
+
+    d_buffer.resize(d_currentPos);
+    const ssize_t missing = isProxyHeaderComplete(d_buffer);
+    if (missing == 0) {
+      vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", d_ci.remote.toStringWithPort());
+      ++dnsdist::metrics::g_stats.proxyProtocolInvalid;
+      return ProxyProtocolResult::Error;
+    }
+
+    if (missing < 0) {
+      /* the header is not complete yet: grow the amount to read by -missing bytes.
+         we need to keep reading, since we might have buffered data */
+      d_proxyProtocolNeed += -missing;
+      d_buffer.resize(d_currentPos + d_proxyProtocolNeed);
+      continue;
+    }
+
+    /* proxy header received */
+    std::vector<ProxyProtocolValue> values;
+    if (!handleProxyProtocol(d_ci.remote, true, *d_threadData.holders.acl, d_buffer, d_proxiedRemote, d_proxiedDestination, values)) {
+      vinfolog("Error handling the Proxy Protocol received from TCP client %s", d_ci.remote.toStringWithPort());
+      return ProxyProtocolResult::Error;
+    }
+
+    /* only keep the values around when there is something to keep */
+    if (!values.empty()) {
+      d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(values));
+    }
+
+    return ProxyProtocolResult::Done;
+  }
+  while (active() && !d_lastIOBlocked);
+
+  return ProxyProtocolResult::Reading;
+}
+
+void IncomingTCPConnectionState::handleIO()
 {
   // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
   // even though the underlying socket is not ready, so we need to actually ask for the data first
   IOState iostate = IOState::Done;
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+
   do {
     iostate = IOState::Done;
-    IOStateGuard ioGuard(state->d_ioState);
+    IOStateGuard ioGuard(d_ioState);
 
-    if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
-      vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+    if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+      vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
       // will be handled by the ioGuard
       //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
       return;
     }
 
-    state->d_lastIOBlocked = false;
+    d_lastIOBlocked = false;
 
     try {
-      if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
+      if (d_state == IncomingTCPConnectionState::State::doingHandshake) {
         DEBUGLOG("doing handshake");
-        iostate = state->d_handler.tryHandshake();
+        iostate = d_handler.tryHandshake();
         if (iostate == IOState::Done) {
           DEBUGLOG("handshake done");
-          if (state->d_handler.isTLS()) {
-            if (!state->d_handler.hasTLSSessionBeenResumed()) {
-              ++state->d_ci.cs->tlsNewSessions;
-            }
-            else {
-              ++state->d_ci.cs->tlsResumptions;
-            }
-            if (state->d_handler.getResumedFromInactiveTicketKey()) {
-              ++state->d_ci.cs->tlsInactiveTicketKey;
-            }
-            if (state->d_handler.getUnknownTicketKey()) {
-              ++state->d_ci.cs->tlsUnknownTicketKey;
-            }
-          }
+          handleHandshakeDone(now);
 
-          state->d_handshakeDoneTime = now;
-          if (expectProxyProtocolFrom(state->d_ci.remote)) {
-            state->d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
-            state->d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
-            state->d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+          if (expectProxyProtocolFrom(d_ci.remote)) {
+            d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
+            d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+            d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
           }
           else {
-            state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+            d_state = IncomingTCPConnectionState::State::readingQuerySize;
           }
         }
         else {
-          state->d_lastIOBlocked = true;
+          d_lastIOBlocked = true;
         }
       }
 
-      if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
-        do {
-          DEBUGLOG("reading proxy protocol header");
-          iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_proxyProtocolNeed);
-          if (iostate == IOState::Done) {
-            state->d_buffer.resize(state->d_currentPos);
-            ssize_t remaining = isProxyHeaderComplete(state->d_buffer);
-            if (remaining == 0) {
-              vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", state->d_ci.remote.toStringWithPort());
-              ++dnsdist::metrics::g_stats.proxyProtocolInvalid;
-              break;
-            }
-            else if (remaining < 0) {
-              state->d_proxyProtocolNeed += -remaining;
-              state->d_buffer.resize(state->d_currentPos + state->d_proxyProtocolNeed);
-              /* we need to keep reading, since we might have buffered data */
-              iostate = IOState::NeedRead;
-            }
-            else {
-              /* proxy header received */
-              std::vector<ProxyProtocolValue> proxyProtocolValues;
-              if (!handleProxyProtocol(state->d_ci.remote, true, *state->d_threadData.holders.acl, state->d_buffer, state->d_proxiedRemote, state->d_proxiedDestination, proxyProtocolValues)) {
-                vinfolog("Error handling the Proxy Protocol received from TCP client %s", state->d_ci.remote.toStringWithPort());
-                break;
-              }
-
-              if (!proxyProtocolValues.empty()) {
-                state->d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
-              }
-
-              state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
-              state->d_buffer.resize(sizeof(uint16_t));
-              state->d_currentPos = 0;
-              state->d_proxyProtocolNeed = 0;
-              break;
-            }
-          }
-          else {
-            state->d_lastIOBlocked = true;
-          }
+      if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
+        auto status = handleProxyProtocolPayload();
+        if (status == ProxyProtocolResult::Done) {
+          d_state = IncomingTCPConnectionState::State::readingQuerySize;
+          d_buffer.resize(sizeof(uint16_t));
+          d_currentPos = 0;
+          d_proxyProtocolNeed = 0;
+        }
+        else if (status == ProxyProtocolResult::Error) {
+          iostate = IOState::Done;
+        }
+        else {
+          iostate = IOState::NeedRead;
         }
-        while (state->active() && !state->d_lastIOBlocked);
       }
 
-      if (!state->d_lastIOBlocked && (state->d_state == IncomingTCPConnectionState::State::waitingForQuery ||
-                                      state->d_state == IncomingTCPConnectionState::State::readingQuerySize)) {
+      if (!d_lastIOBlocked && (d_state == IncomingTCPConnectionState::State::waitingForQuery ||
+                                      d_state == IncomingTCPConnectionState::State::readingQuerySize)) {
         DEBUGLOG("reading query size");
-        state->d_buffer.resize(sizeof(uint16_t));
-        iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
-        if (state->d_currentPos > 0) {
+        d_buffer.resize(sizeof(uint16_t));
+        iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t));
+        if (d_currentPos > 0) {
           /* if we got at least one byte, we can't go around sending responses */
-          state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+          d_state = IncomingTCPConnectionState::State::readingQuerySize;
         }
 
         if (iostate == IOState::Done) {
           DEBUGLOG("query size received");
-          state->d_state = IncomingTCPConnectionState::State::readingQuery;
-          state->d_querySizeReadTime = now;
-          if (state->d_queriesCount == 0) {
-            state->d_firstQuerySizeReadTime = now;
+          d_state = IncomingTCPConnectionState::State::readingQuery;
+          d_querySizeReadTime = now;
+          if (d_queriesCount == 0) {
+            d_firstQuerySizeReadTime = now;
           }
-          state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
-          if (state->d_querySize < sizeof(dnsheader)) {
+          d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1);
+          if (d_querySize < sizeof(dnsheader)) {
             /* go away */
-            state->terminateClientConnection();
+            terminateClientConnection();
             return;
           }
 
           /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
              or to add ECS without allocating a new buffer */
-          state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
-          state->d_currentPos = 0;
+          d_buffer.resize(std::max(d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
+          d_currentPos = 0;
         }
         else {
-          state->d_lastIOBlocked = true;
+          d_lastIOBlocked = true;
         }
       }
 
-      if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+      if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::readingQuery) {
         DEBUGLOG("reading query");
-        iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
+        iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize);
         if (iostate == IOState::Done) {
           DEBUGLOG("query received");
-          state->d_buffer.resize(state->d_querySize);
+          d_buffer.resize(d_querySize);
+
+          d_state = IncomingTCPConnectionState::State::idle;
+          auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt);
+          switch (processingResult) {
+          case QueryProcessingResult::TooSmall:
+            /* fall-through */
+          case QueryProcessingResult::InvalidHeaders:
+            /* fall-through */
+          case QueryProcessingResult::Dropped:
+            /* fall-through */
+          case QueryProcessingResult::NoBackend:
+            terminateClientConnection();
+            break;
+          default:
+            break;
+          }
 
-          state->d_state = IncomingTCPConnectionState::State::idle;
-          handleQuery(state, now);
           /* the state might have been updated in the meantime, we don't want to override it
              in that case */
-          if (state->active() && state->d_state != IncomingTCPConnectionState::State::idle) {
-            if (state->d_ioState->isWaitingForRead()) {
+          if (active() && d_state != IncomingTCPConnectionState::State::idle) {
+            if (d_ioState->isWaitingForRead()) {
               iostate = IOState::NeedRead;
             }
-            else if (state->d_ioState->isWaitingForWrite()) {
+            else if (d_ioState->isWaitingForWrite()) {
               iostate = IOState::NeedWrite;
             }
             else {
@@ -930,55 +1018,56 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
           }
         }
         else {
-          state->d_lastIOBlocked = true;
+          d_lastIOBlocked = true;
         }
       }
 
-      if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      if (!d_lastIOBlocked && d_state == IncomingTCPConnectionState::State::sendingResponse) {
         DEBUGLOG("sending response");
-        iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
+        iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size());
         if (iostate == IOState::Done) {
           DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
-          handleResponseSent(state, state->d_currentResponse);
-          state->d_state = IncomingTCPConnectionState::State::idle;
+          handleResponseSent(d_currentResponse);
+          d_state = IncomingTCPConnectionState::State::idle;
         }
         else {
-          state->d_lastIOBlocked = true;
+          d_lastIOBlocked = true;
         }
       }
 
-      if (state->active() &&
-          !state->d_lastIOBlocked &&
+      if (active() &&
+          !d_lastIOBlocked &&
           iostate == IOState::Done &&
-          (state->d_state == IncomingTCPConnectionState::State::idle ||
-           state->d_state == IncomingTCPConnectionState::State::waitingForQuery))
+          (d_state == IncomingTCPConnectionState::State::idle ||
+           d_state == IncomingTCPConnectionState::State::waitingForQuery))
       {
         // try sending queued responses
         DEBUGLOG("send responses, if any");
+        auto state = shared_from_this();
         iostate = sendQueuedResponses(state, now);
 
-        if (!state->d_lastIOBlocked && state->active() && iostate == IOState::Done) {
+        if (!d_lastIOBlocked && active() && iostate == IOState::Done) {
           // if the query has been passed to a backend, or dropped, and the responses have been sent,
           // we can start reading again
-          if (state->canAcceptNewQueries(now)) {
-            state->resetForNewQuery();
+          if (canAcceptNewQueries(now)) {
+            resetForNewQuery();
             iostate = IOState::NeedRead;
           }
           else {
-            state->d_state = IncomingTCPConnectionState::State::idle;
+            d_state = IncomingTCPConnectionState::State::idle;
             iostate = IOState::Done;
           }
         }
       }
 
-      if (state->d_state != IncomingTCPConnectionState::State::idle &&
-          state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
-          state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader &&
-          state->d_state != IncomingTCPConnectionState::State::waitingForQuery &&
-          state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
-          state->d_state != IncomingTCPConnectionState::State::readingQuery &&
-          state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
-        vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
+      if (d_state != IncomingTCPConnectionState::State::idle &&
+          d_state != IncomingTCPConnectionState::State::doingHandshake &&
+          d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader &&
+          d_state != IncomingTCPConnectionState::State::waitingForQuery &&
+          d_state != IncomingTCPConnectionState::State::readingQuerySize &&
+          d_state != IncomingTCPConnectionState::State::readingQuery &&
+          d_state != IncomingTCPConnectionState::State::sendingResponse) {
+        vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(d_state));
       }
     }
     catch (const std::exception& e) {
@@ -986,55 +1075,56 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
          but it might also be a real IO error or something else.
          Let's just drop the connection
       */
-      if (state->d_state == IncomingTCPConnectionState::State::idle ||
-          state->d_state == IncomingTCPConnectionState::State::waitingForQuery) {
+      if (d_state == IncomingTCPConnectionState::State::idle ||
+          d_state == IncomingTCPConnectionState::State::waitingForQuery) {
         /* no need to increase any counters in that case, the client is simply done with us */
       }
-      else if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
-               state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
-               state->d_state == IncomingTCPConnectionState::State::waitingForQuery ||
-               state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
-               state->d_state == IncomingTCPConnectionState::State::readingQuery) {
-        ++state->d_ci.cs->tcpDiedReadingQuery;
+      else if (d_state == IncomingTCPConnectionState::State::doingHandshake ||
+               d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
+               d_state == IncomingTCPConnectionState::State::waitingForQuery ||
+               d_state == IncomingTCPConnectionState::State::readingQuerySize ||
+               d_state == IncomingTCPConnectionState::State::readingQuery) {
+        /* the connection died before a complete query was read. Every state is tested
+           with '==': a '!=' on readingProxyProtocolHeader would match almost every state,
+           including sendingResponse, and make the tcpDiedSendingResponse branch below unreachable */
+        ++d_ci.cs->tcpDiedReadingQuery;
       }
-      else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      else if (d_state == IncomingTCPConnectionState::State::sendingResponse) {
         /* unlikely to happen here, the exception should be handled in sendResponse() */
-        ++state->d_ci.cs->tcpDiedSendingResponse;
+        ++d_ci.cs->tcpDiedSendingResponse;
       }
 
-      if (state->d_ioState->isWaitingForWrite() || state->d_queriesCount == 0) {
+      if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) {
         DEBUGLOG("Got an exception while handling TCP query: "<<e.what());
-        vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_ioState->isWaitingForRead() ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+        vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (d_ioState->isWaitingForRead() ? "reading" : "writing"), d_ci.remote.toStringWithPort(), e.what());
       }
       else {
-        vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
+        vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what());
         DEBUGLOG("Closing TCP client connection: "<<e.what());
       }
       /* remove this FD from the IO multiplexer */
-      state->terminateClientConnection();
+      terminateClientConnection();
     }
 
-    if (!state->active()) {
+    if (!active()) {
       DEBUGLOG("state is no longer active");
       return;
     }
 
+    auto state = shared_from_this();
     if (iostate == IOState::Done) {
-      state->d_ioState->update(iostate, handleIOCallback, state);
+      d_ioState->update(iostate, handleIOCallback, state);
     }
     else {
       updateIO(state, iostate, now);
     }
     ioGuard.release();
   }
-  while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !state->d_lastIOBlocked);
+  while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !d_lastIOBlocked);
 }
 
-void IncomingTCPConnectionState::notifyIOError(InternalQueryState&& query, const struct timeval& now)
+void IncomingTCPConnectionState::notifyIOError(const struct timeval& now, TCPResponse&& response)
 {
   if (std::this_thread::get_id() != d_creatorThreadID) {
     /* empty buffer will signal an IO error */
-    TCPResponse response(PacketBuffer(), std::move(query), nullptr, nullptr);
+    response.d_buffer.clear();
     handleCrossProtocolResponse(now, std::move(response));
     return;
   }
@@ -1115,8 +1205,17 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param
 
   struct timeval now;
   gettimeofday(&now, nullptr);
-  auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
-  IncomingTCPConnectionState::handleIO(state, now);
+
+  if (citmp->cs->dohFrontend) {
+#ifdef HAVE_NGHTTP2
+    auto state = std::make_shared<IncomingHTTP2Connection>(std::move(*citmp), *threadData, now);
+    state->handleIO();
+#endif /* HAVE_NGHTTP2 */
+  }
+  else {
+    auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now);
+    state->handleIO();
+  }
 }
 
 static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
@@ -1141,20 +1240,18 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par
   std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender();
   auto query = std::move(cpq->query);
   auto downstreamServer = std::move(cpq->downstream);
-  auto proxyProtocolPayloadSize = cpq->proxyProtocolPayloadSize;
 
   try {
     auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
 
-    prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
-    query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize;
+    prependSizeToTCPQuery(query.d_buffer, query.d_idstate.d_proxyProtocolPayloadSize);
 
     vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr());
 
     downstream->queueQuery(tqs, std::move(query));
   }
   catch (...) {
-    tqs->notifyIOError(std::move(query.d_idstate), now);
+    tqs->notifyIOError(now, std::move(query));
   }
 }
 
@@ -1178,7 +1275,7 @@ static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t&
 
   try {
     if (response.d_response.d_buffer.empty()) {
-      response.d_state->notifyIOError(std::move(response.d_response.d_idstate), response.d_now);
+      response.d_state->notifyIOError(response.d_now, std::move(response.d_response));
     }
     else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) {
       response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response));
@@ -1337,7 +1434,8 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa
 {
   auto& cs = param.cs;
   auto& acl = param.acl;
-  int socket = param.socket;
+  const bool checkACL = !cs.dohFrontend || (!cs.dohFrontend->d_trustForwardedForHeader && cs.dohFrontend->d_earlyACLDrop);
+  const int socket = param.socket;
   bool tcpClientCountIncremented = false;
   ComboAddress remote;
   remote.sin4.sin_family = param.local.sin4.sin_family;
@@ -1358,7 +1456,7 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa
       throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str());
     }
 
-    if (!acl->match(remote)) {
+    if (checkACL && !acl->match(remote)) {
       ++dnsdist::metrics::g_stats.aclDrops;
       vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
       return;
@@ -1395,6 +1493,7 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa
     vinfolog("Got TCP connection from %s", remote.toStringWithPort());
 
     ci.remote = remote;
+
     if (threadData == nullptr) {
       if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(ci)))) {
         if (tcpClientCountIncremented) {
@@ -1405,8 +1504,17 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa
     else {
       struct timeval now;
       gettimeofday(&now, nullptr);
-      auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now);
-      IncomingTCPConnectionState::handleIO(state, now);
+
+      if (ci.cs->dohFrontend) {
+#ifdef HAVE_NGHTTP2        
+        auto state = std::make_shared<IncomingHTTP2Connection>(std::move(ci), *threadData, now);
+        state->handleIO();
+#endif /* HAVE_NGHTTP2 */
+      }
+      else {
+        auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now);
+        state->handleIO();
+      }
     }
   }
   catch (const std::exception& e) {
index fdf2797104c14916769ae5eb9aafcdf0dccc95d9..19e477ce798fa9f65516af850c73676087a6aa7e 100644 (file)
@@ -1469,7 +1469,7 @@ public:
     return handleResponse(now, std::move(response));
   }
 
-  void notifyIOError(InternalQueryState&& query, const struct timeval& now) override
+  void notifyIOError(const struct timeval&, TCPResponse&&) override
   {
     // nothing to do
   }
@@ -2573,18 +2573,24 @@ int main(int argc, char** argv)
         cout<<"gnutls";
 #ifdef HAVE_LIBSSL
         cout<<" ";
-#endif /* HAVE_LIBSSL */
+#endif
 #endif /* HAVE_GNUTLS */
 #ifdef HAVE_LIBSSL
         cout<<"openssl";
-#endif /* HAVE_LIBSSL */
+#endif
         cout<<") ";
 #endif /* HAVE_DNS_OVER_TLS */
 #ifdef HAVE_DNS_OVER_HTTPS
         cout<<"dns-over-https(";
 #ifdef HAVE_LIBH2OEVLOOP
         cout<<"h2o";
+#ifdef HAVE_NGHTTP2
+        cout<<" ";
+#endif
 #endif /* HAVE_LIBH2OEVLOOP */
+#ifdef HAVE_NGHTTP2
+        cout<<"nghttp2";
+#endif
         cout<<") ";
 #endif /* HAVE_DNS_OVER_HTTPS */
 #ifdef HAVE_DNSCRYPT
@@ -2608,9 +2614,6 @@ int main(int argc, char** argv)
 #ifdef HAVE_LMDB
         cout<<"lmdb ";
 #endif
-#ifdef HAVE_NGHTTP2
-        cout<<"outgoing-dns-over-https(nghttp2) ";
-#endif
 #ifndef DISABLE_PROTOBUF
         cout<<"protobuf ";
 #endif
@@ -2914,8 +2917,8 @@ int main(int argc, char** argv)
 
     std::vector<ClientState*> tcpStates;
     std::vector<ClientState*> udpStates;
-    for(auto& cs : g_frontends) {
-      if (cs->dohFrontend != nullptr) {
+    for (auto& cs : g_frontends) {
+      if (cs->dohFrontend != nullptr && cs->dohFrontend->d_library == "h2o") {
 #ifdef HAVE_DNS_OVER_HTTPS
 #ifdef HAVE_LIBH2OEVLOOP
         std::thread t1(dohThread, cs.get());
index 99d7cdbe6475655094ba72a714a2ec6c7f6049a6..e4f30eaa83ab7dbb629b902cc03c51e590583248 100644 (file)
@@ -80,6 +80,10 @@ if HAVE_LIBSSL
 AM_CPPFLAGS += $(LIBSSL_CFLAGS)
 endif
 
+if HAVE_GNUTLS
+AM_CPPFLAGS += $(GNUTLS_CFLAGS)
+endif
+
 if HAVE_LIBH2OEVLOOP
 AM_CPPFLAGS += $(LIBH2OEVLOOP_CFLAGS)
 endif
@@ -178,6 +182,7 @@ dnsdist_SOURCES = \
        dnsdist-lua.cc dnsdist-lua.hh \
        dnsdist-mac-address.cc dnsdist-mac-address.hh \
        dnsdist-metrics.cc dnsdist-metrics.hh \
+       dnsdist-nghttp2-in.cc dnsdist-nghttp2-in.hh \
        dnsdist-nghttp2.cc dnsdist-nghttp2.hh \
        dnsdist-prometheus.hh \
        dnsdist-protobuf.cc dnsdist-protobuf.hh \
@@ -274,6 +279,7 @@ testrunner_SOURCES = \
        dnsdist-lua-vars.cc \
        dnsdist-mac-address.cc dnsdist-mac-address.hh \
        dnsdist-metrics.cc dnsdist-metrics.hh \
+       dnsdist-nghttp2-in.cc dnsdist-nghttp2-in.hh \
        dnsdist-nghttp2.cc dnsdist-nghttp2.hh \
        dnsdist-protocols.cc dnsdist-protocols.hh \
        dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \
@@ -411,6 +417,10 @@ endif
 
 if HAVE_DNS_OVER_HTTPS
 
+if HAVE_GNUTLS
+dnsdist_LDADD += -lgnutls
+endif
+
 if HAVE_LIBH2OEVLOOP
 dnsdist_LDADD += $(LIBH2OEVLOOP_LIBS)
 endif
index 5805bbdfe8f35cd209e65bf93f487fcca202f168..af9f8c422a15327bc2687744694c57d1d42486fb 100644 (file)
@@ -71,6 +71,7 @@ AM_CONDITIONAL([HAVE_GNUTLS], [false])
 AM_CONDITIONAL([HAVE_LIBH2OEVLOOP], [false])
 AM_CONDITIONAL([HAVE_LIBSSL], [false])
 AM_CONDITIONAL([HAVE_LMDB], [false])
+AM_CONDITIONAL([HAVE_NGHTTP2], [false])
 
 PDNS_CHECK_LIBCRYPTO
 
@@ -81,30 +82,28 @@ DNSDIST_ENABLE_DNS_OVER_HTTPS
 
 AS_IF([test "x$enable_dns_over_tls" != "xno" -o "x$enable_dns_over_https" != "xno"], [
   PDNS_WITH_LIBSSL
+  PDNS_WITH_GNUTLS
 ])
 
 AS_IF([test "x$enable_dns_over_tls" != "xno"], [
-  PDNS_WITH_GNUTLS
-
   AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [
     AC_MSG_ERROR([DNS over TLS support requested but neither GnuTLS nor OpenSSL are available])
   ])
 ])
 
 AS_IF([test "x$enable_dns_over_https" != "xno"], [
+  PDNS_WITH_NGHTTP2
   PDNS_WITH_LIBH2OEVLOOP
 
-  AS_IF([test "x$HAVE_LIBH2OEVLOOP" != "x1"], [
-    AC_MSG_ERROR([DNS over HTTPS support requested but libh2o-evloop was not found])
+  dnl DNS over HTTPS needs at least one of libh2o-evloop or nghttp2
+  AS_IF([test "x$HAVE_LIBH2OEVLOOP" != "x1" -a "x$HAVE_NGHTTP2" != "x1" ], [
+    AC_MSG_ERROR([DNS over HTTPS support requested but neither libh2o-evloop nor nghttp2 was found])
   ])
 
-  AS_IF([test "x$HAVE_LIBSSL" != "x1"], [
-    AC_MSG_ERROR([DNS over HTTPS support requested but OpenSSL was not found])
+  AS_IF([test "x$HAVE_GNUTLS" != "x1" -a "x$HAVE_LIBSSL" != "x1"], [
+    AC_MSG_ERROR([DNS over HTTPS support requested but neither GnuTLS nor OpenSSL are available])
   ])
 ])
 
-PDNS_WITH_NGHTTP2
-
 DNSDIST_WITH_CDB
 PDNS_CHECK_LMDB
 PDNS_ENABLE_IPCIPHER
index f54b1c0b144644d9ff5c90b60b5f2bf0f74eb43b..9cb96d83a226b0e88a1075d247e195a13def212c 100644 (file)
@@ -137,7 +137,8 @@ void AsynchronousHolder::mainThread(std::shared_ptr<Data> data)
         vinfolog("Asynchronous query %d has expired at %d.%d, notifying the sender", queryID, now.tv_sec, now.tv_usec);
         auto sender = query->getTCPQuerySender();
         if (sender) {
-          sender->notifyIOError(std::move(query->query.d_idstate), now);
+          TCPResponse tresponse(std::move(query->query));
+          sender->notifyIOError(now, std::move(tresponse));
         }
       }
       else {
index 66f421479e711d24ac19ab515b790a34fb8b915b..67b65796d42965b18bccd99da2cccdaf8df523ff 100644 (file)
@@ -168,7 +168,7 @@ public:
     throw std::runtime_error("Unexpected XFR reponse to a health check query");
   }
 
-  void notifyIOError(InternalQueryState&& query, const struct timeval& now) override
+  void notifyIOError(const struct timeval& now, TCPResponse&&) override
   {
     ++d_data->d_ds->d_healthCheckMetrics.d_networkErrors;
     d_data->d_ds->submitHealthCheckResult(d_data->d_initial, false);
index 49f95e42b4b4f5af51805ab37503c69ce9fe5f54..ea4c5413970946d56f885862baa8df019260429e 100644 (file)
@@ -20,6 +20,7 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 #include "dnsdist-internal-queries.hh"
+#include "dnsdist-nghttp2-in.hh"
 #include "dnsdist-tcp.hh"
 #include "doh.hh"
 
@@ -35,7 +36,12 @@ std::unique_ptr<CrossProtocolQuery> getInternalQueryFromDQ(DNSQuestion& dq, bool
   }
 #ifdef HAVE_DNS_OVER_HTTPS
   else if (protocol == dnsdist::Protocol::DoH) {
-    return getDoHCrossProtocolQueryFromDQ(dq, isResponse);
+#ifdef HAVE_LIBH2OEVLOOP
+    if (dq.ids.cs->dohFrontend->d_library == "h2o") {
+      return getDoHCrossProtocolQueryFromDQ(dq, isResponse);
+    }
+#endif /* HAVE_LIBH2OEVLOOP */
+    return getTCPCrossProtocolQueryFromDQ(dq);
   }
 #endif
   else {
index 20c866931a847a2ee74db6d53aac84be7569811c..48ce507da8e7331fe56e48aab3fb0a90ff7bd97d 100644 (file)
@@ -929,7 +929,8 @@ bool dnsdist_ffi_drop_from_async(uint16_t asyncID, uint16_t queryID)
 
   struct timeval now;
   gettimeofday(&now, nullptr);
-  sender->notifyIOError(std::move(query->query.d_idstate), now);
+  TCPResponse tresponse(std::move(query->query));
+  sender->notifyIOError(now, std::move(tresponse));
 
   return true;
 }
diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc
new file mode 100644 (file)
index 0000000..aefa50d
--- /dev/null
@@ -0,0 +1,1214 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "base64.hh"
+#include "dnsdist-nghttp2-in.hh"
+#include "dnsdist-proxy-protocol.hh"
+#include "dnsparser.hh"
+
+#ifdef HAVE_NGHTTP2
+
+/* Disabled (compiled out by the '#if 0' guard): a CrossProtocolContext-based
+   variant of IncomingDoHCrossProtocolContext. The class of the same name defined
+   right after this block, deriving from DOHUnitInterface, is the one that gets compiled.
+   NOTE(review): this variant never stores the 'streamID' constructor argument, and its
+   handleResponse() / handleTimeout() bodies do nothing after locking the connection. */
+#if 0
+class IncomingDoHCrossProtocolContext : public CrossProtocolContext
+{
+public:
+  IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, std::shared_ptr<IncomingHTTP2Connection> connection, IncomingHTTP2Connection::StreamID streamID): CrossProtocolContext(std::move(query.d_buffer)), d_connection(connection), d_query(std::move(query))
+  {
+  }
+
+  std::optional<std::string> getHTTPPath() const override
+  {
+    return d_query.d_path;
+  }
+
+  std::optional<std::string> getHTTPScheme() const override
+  {
+    return d_query.d_scheme;
+  }
+
+  std::optional<std::string> getHTTPHost() const override
+  {
+    return d_query.d_host;
+  }
+
+  std::optional<std::string> getHTTPQueryString() const override
+  {
+    return d_query.d_queryString;
+  }
+
+  std::optional<HeadersMap> getHTTPHeaders() const override
+  {
+    if (!d_query.d_headers) {
+      return std::nullopt;
+    }
+    return *d_query.d_headers;
+  }
+
+  void handleResponse(PacketBuffer&& response, InternalQueryState&& state) override
+  {
+    auto conn = d_connection.lock();
+    if (!conn) {
+      /* the connection has been closed in the meantime */
+      return;
+    }
+  }
+
+  void handleTimeout() override
+  {
+    auto conn = d_connection.lock();
+    if (!conn) {
+      /* the connection has been closed in the meantime */
+      return;
+    }
+  }
+
+  ~IncomingDoHCrossProtocolContext() override
+  {
+  }
+
+private:
+  std::weak_ptr<IncomingHTTP2Connection> d_connection;
+  IncomingHTTP2Connection::PendingQuery d_query;
+  IncomingHTTP2Connection::StreamID d_streamID{-1};
+};
+#endif
+
+/* DoH unit attached to a query received over an incoming HTTP/2 (nghttp2) connection.
+   It holds the pending query (path, scheme, host, query string, headers and an
+   optional custom HTTP response), a weak reference to the connection the query came
+   from and the HTTP/2 stream ID, so that a response or a timeout can be handed back
+   to that connection if it still exists. */
+class IncomingDoHCrossProtocolContext : public DOHUnitInterface
+{
+public:
+  /* Takes over the content of 'query'; only a weak reference to 'connection' is kept. */
+  IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, std::shared_ptr<IncomingHTTP2Connection> connection, IncomingHTTP2Connection::StreamID streamID) :
+    d_connection(connection), d_query(std::move(query)), d_streamID(streamID)
+  {
+  }
+
+  /* Returns a copy of the path of the HTTP query */
+  std::string getHTTPPath() const override
+  {
+    return d_query.d_path;
+  }
+
+  /* Returns the scheme of the HTTP query */
+  const std::string& getHTTPScheme() const override
+  {
+    return d_query.d_scheme;
+  }
+
+  /* Returns the host of the HTTP query */
+  const std::string& getHTTPHost() const override
+  {
+    return d_query.d_host;
+  }
+
+  /* Returns a copy of the query string of the HTTP query */
+  std::string getHTTPQueryString() const override
+  {
+    return d_query.d_queryString;
+  }
+
+  /* Returns the headers stored along the query, or a shared empty map
+     when no headers have been kept */
+  const HeadersMap& getHTTPHeaders() const override
+  {
+    if (!d_query.d_headers) {
+      static const HeadersMap empty{};
+      return empty;
+    }
+    return *d_query.d_headers;
+  }
+
+  /* Records a custom HTTP response (status code, body and content type) in the
+     pending query; a non-zero status code is what sendResponse() later checks
+     to pick this response over the DNS one */
+  void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "") override
+  {
+    d_query.d_statusCode = statusCode;
+    d_query.d_response = std::move(body);
+    d_query.d_contentTypeOut = contentType;
+  }
+
+  /* Passes a response received over UDP back to the connection, if it still exists.
+     This object takes ownership of itself: it is either destroyed on early return
+     or handed over to the query state. */
+  void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& ds) override
+  {
+    /* from now on 'this' is owned by 'unit' and released when 'unit' goes away */
+    std::unique_ptr<DOHUnitInterface> unit(this);
+    auto conn = d_connection.lock();
+    if (!conn) {
+      /* the connection has been closed in the meantime */
+      return;
+    }
+
+    /* ownership of this unit moves to the state travelling with the response */
+    state.du = std::move(unit);
+    TCPResponse resp(std::move(response), std::move(state), nullptr, nullptr);
+    resp.d_ds = ds;
+    struct timeval now;
+    gettimeofday(&now, nullptr);
+    conn->handleResponse(now, std::move(resp));
+  }
+
+  /* Reports a timeout to the connection, if it still exists, as an IO error
+     for our stream. This object is destroyed when this function returns. */
+  void handleTimeout() override
+  {
+    /* take ownership of ourselves so that we are released on every return path */
+    std::unique_ptr<DOHUnitInterface> unit(this);
+    auto conn = d_connection.lock();
+    if (!conn) {
+      /* the connection has been closed in the meantime */
+      return;
+    }
+    struct timeval now;
+    gettimeofday(&now, nullptr);
+    /* only the stream ID is filled in, the response buffer stays empty */
+    TCPResponse resp;
+    resp.d_idstate.d_streamID = d_streamID;
+    conn->notifyIOError(now, std::move(resp));
+  }
+
+  ~IncomingDoHCrossProtocolContext() override
+  {
+  }
+
+  /* connection the query was received on, which might be gone by the time we need it */
+  std::weak_ptr<IncomingHTTP2Connection> d_connection;
+  /* the query itself, moved out of the connection's map of streams by getDOHUnit() */
+  IncomingHTTP2Connection::PendingQuery d_query;
+  /* HTTP/2 stream the query belongs to, -1 when unset */
+  IncomingHTTP2Connection::StreamID d_streamID{-1};
+};
+
+void IncomingHTTP2Connection::handleResponse(const struct timeval& now, TCPResponse&& response)
+{
+  if (std::this_thread::get_id() != d_creatorThreadID) {
+    handleCrossProtocolResponse(now, std::move(response));
+    return;
+  }
+
+  auto& state = response.d_idstate;
+  if (state.forwardedOverUDP) {
+    dnsheader* responseDH = reinterpret_cast<struct dnsheader*>(response.d_buffer.data());
+
+    if (responseDH->tc && state.d_packet && state.d_packet->size() > state.d_proxyProtocolPayloadSize && state.d_packet->size() - state.d_proxyProtocolPayloadSize > sizeof(dnsheader)) {
+      auto& query = *state.d_packet;
+      dnsheader* queryDH = reinterpret_cast<struct dnsheader*>(query.data() + state.d_proxyProtocolPayloadSize);
+      /* restoring the original ID */
+      queryDH->id = state.origID;
+
+      state.forwardedOverUDP = false;
+      auto cpq = getCrossProtocolQuery(std::move(query), std::move(state), response.d_ds);
+      cpq->query.d_proxyProtocolPayloadAdded = state.d_proxyProtocolPayloadSize > 0;
+      if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
+        return;
+      }
+      else {
+        vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP");
+        notifyIOError(now, std::move(response));
+        return;
+      }
+    }
+  }
+
+  IncomingTCPConnectionState::handleResponse(now, std::move(response));
+}
+
+/* Builds a DoH unit for the given stream: the pending query is moved out of the
+   map of current streams (the entry itself stays in the map) and wrapped together
+   with a reference to this connection. Throws if the stream is not known. */
+std::unique_ptr<DOHUnitInterface> IncomingHTTP2Connection::getDOHUnit(uint32_t streamID)
+{
+  auto pending = std::move(d_currentStreams.at(streamID));
+  auto self = std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this());
+  return std::make_unique<IncomingDoHCrossProtocolContext>(std::move(pending), self, streamID);
+}
+
+/* Puts the pending query held by a DoH unit back into the map of current streams,
+   undoing getDOHUnit(). The unit is consumed on success, and also when the stream
+   is no longer known (in which case the map lookup throws).
+   A null unit, or a unit that is not an IncomingDoHCrossProtocolContext, is left
+   untouched in the caller's hands instead of being dereferenced or leaked. */
+void IncomingHTTP2Connection::restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&& unit)
+{
+  /* check the dynamic type before taking ownership, so a mismatch does not leak */
+  auto* context = dynamic_cast<IncomingDoHCrossProtocolContext*>(unit.get());
+  if (context == nullptr) {
+    return;
+  }
+
+  /* 'context' stays valid for as long as 'owned' is alive */
+  std::unique_ptr<DOHUnitInterface> owned(std::move(unit));
+  d_currentStreams.at(context->d_streamID) = std::move(context->d_query);
+}
+
+/* Replaces the pending query stored for an existing stream with 'context'.
+   Throws if the stream is not known. */
+void IncomingHTTP2Connection::restoreContext(uint32_t streamID, IncomingHTTP2Connection::PendingQuery&& context)
+{
+  auto& slot = d_currentStreams.at(streamID);
+  slot = std::move(context);
+}
+
+IncomingHTTP2Connection::IncomingHTTP2Connection(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now) :
+  IncomingTCPConnectionState(std::move(ci), threadData, now)
+{
+  nghttp2_session_callbacks* cbs = nullptr;
+  if (nghttp2_session_callbacks_new(&cbs) != 0) {
+    throw std::runtime_error("Unable to create a callback object for a new incoming HTTP/2 session");
+  }
+  std::unique_ptr<nghttp2_session_callbacks, void (*)(nghttp2_session_callbacks*)> callbacks(cbs, nghttp2_session_callbacks_del);
+  cbs = nullptr;
+
+  nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback);
+  nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback);
+  nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback);
+  nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback);
+  nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback);
+  nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback);
+  nghttp2_session_callbacks_set_error_callback2(callbacks.get(), on_error_callback);
+
+  nghttp2_session* sess = nullptr;
+  if (nghttp2_session_server_new(&sess, callbacks.get(), this) != 0) {
+    throw std::runtime_error("Coult not allocate a new incoming HTTP/2 session");
+  }
+
+  d_session = std::unique_ptr<nghttp2_session, decltype(&nghttp2_session_del)>(sess, nghttp2_session_del);
+  sess = nullptr;
+}
+
+/* Tells whether the protocol reported by the IO handler is exactly 'h2',
+   logging the value we got otherwise. */
+bool IncomingHTTP2Connection::checkALPN()
+{
+  constexpr std::array<uint8_t, 2> expected{'h', '2'};
+  auto negotiated = d_handler.getNextProtocol();
+
+  const bool isH2 = negotiated.size() == expected.size() && memcmp(negotiated.data(), expected.data(), expected.size()) == 0;
+  if (!isH2) {
+    vinfolog("DoH connection from %s expected ALPN value 'h2', got '%s'", d_ci.remote.toStringWithPort(), std::string(negotiated.begin(), negotiated.end()));
+    return false;
+  }
+
+  return true;
+}
+
+/* Called once the connection is ready to speak HTTP/2: submits our settings
+   (at most 100 concurrent streams) and asks nghttp2 to send what is pending.
+   Throws std::runtime_error if either step fails. */
+void IncomingHTTP2Connection::handleConnectionReady()
+{
+  const auto throwOnFailure = [](int code) {
+    if (code != 0) {
+      throw std::runtime_error("Fatal error: " + std::string(nghttp2_strerror(code)));
+    }
+  };
+
+  constexpr std::array<nghttp2_settings_entry, 1> settings{{{NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100U}}};
+  throwOnFailure(nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, settings.data(), settings.size()));
+  throwOnFailure(nghttp2_session_send(d_session.get()));
+}
+
+/* Drives the connection forward, depending on its current state:
+   - closes it if the maximum TCP connection duration has been reached;
+   - doingHandshake: tries to complete the handshake, checks the ALPN value for TLS
+     connections, then moves on to reading the proxy protocol header (when one is
+     expected from this client) or to waiting for queries;
+   - readingProxyProtocolHeader: processes the proxy protocol payload;
+   - waitingForQuery: reads HTTP data.
+   If the connection is still alive afterwards, read and/or write events are
+   registered based on what the nghttp2 session wants.
+   Any std::exception marks the connection as dead and stops IO. */
+void IncomingHTTP2Connection::handleIO()
+{
+  IOState iostate = IOState::Done;
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+
+  try {
+    if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+      vinfolog("Terminating DoH connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
+      stopIO();
+      d_connectionDied = true;
+      return;
+    }
+
+    if (d_state == State::doingHandshake) {
+      /* NOTE(review): the returned state is only compared against Done here; when the
+         handshake needs more IO, what gets registered at the end of this function
+         depends on what nghttp2 wants - confirm this is intended */
+      iostate = d_handler.tryHandshake();
+      if (iostate == IOState::Done) {
+        handleHandshakeDone(now);
+        if (d_handler.isTLS()) {
+          /* for TLS connections, anything other than 'h2' is refused */
+          if (!checkALPN()) {
+            d_connectionDied = true;
+            stopIO();
+            return;
+          }
+        }
+
+        if (expectProxyProtocolFrom(d_ci.remote)) {
+          /* start by reading the minimum size of a proxy protocol header */
+          d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
+          d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+          d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+        }
+        else {
+          d_state = State::waitingForQuery;
+          handleConnectionReady();
+        }
+      }
+    }
+
+    if (d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
+      auto status = handleProxyProtocolPayload();
+      if (status == ProxyProtocolResult::Done) {
+        /* reset the read state before switching to HTTP/2 */
+        d_currentPos = 0;
+        d_proxyProtocolNeed = 0;
+        d_buffer.clear();
+        d_state = State::waitingForQuery;
+        handleConnectionReady();
+      }
+      else if (status == ProxyProtocolResult::Error) {
+        d_connectionDied = true;
+        stopIO();
+        return;
+      }
+      /* any other result leaves the state unchanged */
+    }
+
+    if (d_state == State::waitingForQuery) {
+      readHTTPData();
+    }
+
+    if (!d_connectionDied) {
+      /* the registered callbacks get a shared pointer to this connection */
+      auto shared = std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this());
+      if (nghttp2_session_want_read(d_session.get())) {
+        d_ioState->add(IOState::NeedRead, &handleReadableIOCallback, shared, boost::none);
+      }
+      if (nghttp2_session_want_write(d_session.get())) {
+        d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, shared, boost::none);
+      }
+    }
+  }
+  catch (const std::exception& e) {
+    vinfolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+    d_connectionDied = true;
+    stopIO();
+  }
+}
+
+/* nghttp2 send callback, 'user_data' being the connection: the data is appended to
+   the output buffer of the connection and, if that buffer was empty before, an
+   attempt to write it is made right away.
+   The whole length is always reported as consumed, since the data has been
+   buffered regardless of how much of it could be written. */
+ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+{
+  IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+  bool bufferWasEmpty = conn->d_out.empty();
+  conn->d_out.insert(conn->d_out.end(), data, data + length);
+
+  /* when data was already pending, the new data is only queued behind it.
+     NOTE(review): presumably flushed later by the writable callback - not visible here */
+  if (bufferWasEmpty) {
+    try {
+      auto state = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
+      if (state == IOState::Done) {
+        /* everything has been written, reset the output buffer */
+        conn->d_out.clear();
+        conn->d_outPos = 0;
+        if (!conn->isIdle()) {
+          conn->updateIO(IOState::NeedRead, handleReadableIOCallback);
+        }
+        else {
+          conn->watchForRemoteHostClosingConnection();
+        }
+      }
+      else {
+        /* wait for the event the IO handler asked for before writing again */
+        conn->updateIO(state, handleWritableIOCallback);
+      }
+    }
+    catch (const std::exception& e) {
+      vinfolog("Exception while trying to write (send) to incoming HTTP connection: %s", e.what());
+      conn->handleIOError();
+    }
+  }
+
+  return length;
+}
+
+/* Header names and values, indexed by a symbolic key. Entries are looked up with
+   s_constants.at() by NGHTTP2Headers::addStaticHeader() and
+   NGHTTP2Headers::addDynamicHeader(), so an unknown key makes these throw. */
+static const std::unordered_map<std::string, std::string> s_constants{
+  {"200-value", "200"},
+  {"method-name", ":method"},
+  {"method-value", "POST"},
+  {"scheme-name", ":scheme"},
+  {"scheme-value", "https"},
+  {"authority-name", ":authority"},
+  {"x-forwarded-for-name", "x-forwarded-for"},
+  {"path-name", ":path"},
+  {"content-length-name", "content-length"},
+  {"status-name", ":status"},
+  {"location-name", "location"},
+  {"accept-name", "accept"},
+  {"accept-value", "application/dns-message"},
+  {"cache-control-name", "cache-control"},
+  {"content-type-name", "content-type"},
+  {"content-type-value", "application/dns-message"},
+  {"user-agent-name", "user-agent"},
+  {"user-agent-value", "nghttp2-" NGHTTP2_VERSION "/dnsdist"},
+  {"x-forwarded-port-name", "x-forwarded-port"},
+  {"x-forwarded-proto-name", "x-forwarded-proto"},
+  {"x-forwarded-proto-value-dns-over-udp", "dns-over-udp"},
+  {"x-forwarded-proto-value-dns-over-tcp", "dns-over-tcp"},
+  {"x-forwarded-proto-value-dns-over-tls", "dns-over-tls"},
+  {"x-forwarded-proto-value-dns-over-http", "dns-over-http"},
+  {"x-forwarded-proto-value-dns-over-https", "dns-over-https"},
+};
+
+/* Names of the HTTP/2 pseudo-headers and of the X-Forwarded-For header.
+   s_xForwardedForHeaderName is used by processForwardedForHeader();
+   NOTE(review): the other names are presumably matched against incoming
+   headers further down in this file - not visible from here */
+static const std::string s_authorityHeaderName(":authority");
+static const std::string s_pathHeaderName(":path");
+static const std::string s_methodHeaderName(":method");
+static const std::string s_schemeHeaderName(":scheme");
+static const std::string s_xForwardedForHeaderName("x-forwarded-for");
+
+/* Appends a header whose name and value both come from the s_constants table,
+   'nameKey' and 'valueKey' being keys of that table (an unknown key throws).
+   The entry points into the table and is flagged NO_COPY for both name and value. */
+void NGHTTP2Headers::addStaticHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& valueKey)
+{
+  const auto& headerName = s_constants.at(nameKey);
+  const auto& headerValue = s_constants.at(valueKey);
+
+  auto* namePtr = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(headerName.c_str()));
+  auto* valuePtr = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(headerValue.c_str()));
+
+  headers.push_back({namePtr, valuePtr, headerName.size(), headerValue.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
+}
+
+/* Appends a header using the supplied name and value as they are.
+   The entry only points to the bytes of 'name' and 'value' and is flagged NO_COPY
+   for both, so the caller has to keep them alive for as long as the entry is used. */
+void NGHTTP2Headers::addCustomDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& name, const std::string_view& value)
+{
+  auto* namePtr = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.data()));
+  auto* valuePtr = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.data()));
+
+  headers.push_back({namePtr, valuePtr, name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
+}
+
+/* Appends a header whose name comes from the s_constants table ('nameKey' being a
+   key of that table, an unknown key throws) and whose value is supplied by the caller. */
+void NGHTTP2Headers::addDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string_view& value)
+{
+  NGHTTP2Headers::addCustomDynamicHeader(headers, s_constants.at(nameKey), value);
+}
+
+IOState IncomingHTTP2Connection::sendResponse(const struct timeval& now, TCPResponse&& response)
+{
+  assert(response.d_idstate.d_streamID != -1);
+  auto& context = d_currentStreams.at(response.d_idstate.d_streamID);
+
+  uint32_t statusCode = 200U;
+  std::string contentType;
+  bool sendContentType = true;
+  auto& responseBuffer = context.d_buffer;
+  if (context.d_statusCode != 0) {
+    responseBuffer = std::move(context.d_response);
+    statusCode = context.d_statusCode;
+    contentType = std::move(context.d_contentTypeOut);
+  }
+  else {
+    responseBuffer = std::move(response.d_buffer);
+  }
+
+  sendResponse(response.d_idstate.d_streamID, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType);
+  handleResponseSent(response);
+
+  return IOState::Done;
+}
+
+/* Reports a failure to get a response for a query.
+   From a thread other than the one that created the connection, the response is
+   emptied and passed to handleCrossProtocolResponse(). Otherwise a 502 is sent
+   over the stream of the query, using whatever the response buffer holds as body. */
+void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPResponse&& response)
+{
+  const bool fromAnotherThread = std::this_thread::get_id() != d_creatorThreadID;
+  if (fromAnotherThread) {
+    /* empty buffer will signal an IO error */
+    response.d_buffer.clear();
+    handleCrossProtocolResponse(now, std::move(response));
+    return;
+  }
+
+  const auto streamID = response.d_idstate.d_streamID;
+  assert(streamID != -1);
+  auto& stream = d_currentStreams.at(streamID);
+  stream.d_buffer = std::move(response.d_buffer);
+  sendResponse(streamID, 502, d_ci.cs->dohFrontend->d_customResponseHeaders);
+}
+
+/* Builds and submits the HTTP response for a stream, then asks nghttp2 to send it.
+   The body is whatever the buffer of the stream (d_currentStreams[streamID].d_buffer)
+   holds when this function is called:
+   - 200: a content type is added if 'addContentType' is set ('contentType', or the
+     DNS message one when it is empty), as well as a cache-control header derived from
+     the minimum TTL of the response when the frontend asks for it;
+   - 3xx: the body is used as the location, then wrapped into a small HTML document;
+   - anything else: an empty body is replaced with a default message for that code.
+   The frontend counters are updated accordingly, and 'customResponseHeaders' are
+   added to the response.
+   Returns false, after removing the stream from the map of current streams, if the
+   response could not be submitted or flushed, true otherwise. */
+bool IncomingHTTP2Connection::sendResponse(IncomingHTTP2Connection::StreamID streamID, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType, bool addContentType)
+{
+  /* if data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set.
+   */
+  nghttp2_data_provider data_provider;
+
+  data_provider.source.ptr = this;
+  /* called to get the next chunk of the body: copies up to 'length' bytes of the
+     buffer of the stream, starting at d_queryPos, and sets the EOF flag once the
+     whole buffer has been consumed. The connection is retrieved from 'cb_data'. */
+  data_provider.read_callback = [](nghttp2_session*, IncomingHTTP2Connection::StreamID stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* cb_data) -> ssize_t {
+    auto connection = reinterpret_cast<IncomingHTTP2Connection*>(cb_data);
+    /* NOTE(review): at() throws if the stream is gone - confirm an exception can
+       not escape into nghttp2 from here */
+    auto& obj = connection->d_currentStreams.at(stream_id);
+    size_t toCopy = 0;
+    if (obj.d_queryPos < obj.d_buffer.size()) {
+      size_t remaining = obj.d_buffer.size() - obj.d_queryPos;
+      toCopy = length > remaining ? remaining : length;
+      memcpy(buf, &obj.d_buffer.at(obj.d_queryPos), toCopy);
+      obj.d_queryPos += toCopy;
+    }
+
+    if (obj.d_queryPos >= obj.d_buffer.size()) {
+      *data_flags |= NGHTTP2_DATA_FLAG_EOF;
+    }
+    return toCopy;
+  };
+
+  const auto& df = d_ci.cs->dohFrontend;
+  auto& responseBody = d_currentStreams.at(streamID).d_buffer;
+
+  std::vector<nghttp2_nv> headers;
+  std::string responseCodeStr;
+  std::string cacheControlValue;
+  std::string location;
+  /* remember that dynamic header values should be kept alive
+     until we have called nghttp2_submit_response(), at least */
+
+  if (responseCode == 200) {
+    NGHTTP2Headers::addStaticHeader(headers, "status-name", "200-value");
+    ++df->d_validresponses;
+    ++df->d_http2Stats.d_nb200Responses;
+
+    if (addContentType) {
+      if (contentType.empty()) {
+        NGHTTP2Headers::addStaticHeader(headers, "content-type-name", "content-type-value");
+      }
+      else {
+        NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", contentType);
+      }
+    }
+
+    /* only look for a TTL if there is something after the DNS header */
+    if (df->d_sendCacheControlHeaders && responseBody.size() > sizeof(dnsheader)) {
+      uint32_t minTTL = getDNSPacketMinTTL(reinterpret_cast<const char*>(responseBody.data()), responseBody.size());
+      if (minTTL != std::numeric_limits<uint32_t>::max()) {
+        cacheControlValue = "max-age=" + std::to_string(minTTL);
+        NGHTTP2Headers::addDynamicHeader(headers, "cache-control-name", cacheControlValue);
+      }
+    }
+  }
+  else {
+    responseCodeStr = std::to_string(responseCode);
+    NGHTTP2Headers::addDynamicHeader(headers, "status-name", responseCodeStr);
+
+    if (responseCode >= 300 && responseCode < 400) {
+      /* the body holds the target of the redirection: keep a copy of it for the
+         location header before wrapping the body into an HTML document */
+      location = std::string(reinterpret_cast<const char*>(responseBody.data()), responseBody.size());
+      NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", "text/html; charset=utf-8");
+      NGHTTP2Headers::addDynamicHeader(headers, "location-name", location);
+      static const std::string s_redirectStart{"<!DOCTYPE html><TITLE>Moved</TITLE><P>The document has moved <A HREF=\""};
+      static const std::string s_redirectEnd{"\">here</A>"};
+      responseBody.reserve(s_redirectStart.size() + responseBody.size() + s_redirectEnd.size());
+      responseBody.insert(responseBody.begin(), s_redirectStart.begin(), s_redirectStart.end());
+      responseBody.insert(responseBody.end(), s_redirectEnd.begin(), s_redirectEnd.end());
+      ++df->d_redirectresponses;
+    }
+    else {
+      ++df->d_errorresponses;
+      switch (responseCode) {
+      case 400:
+        ++df->d_http2Stats.d_nb400Responses;
+        break;
+      case 403:
+        ++df->d_http2Stats.d_nb403Responses;
+        break;
+      case 500:
+        ++df->d_http2Stats.d_nb500Responses;
+        break;
+      case 502:
+        ++df->d_http2Stats.d_nb502Responses;
+        break;
+      default:
+        ++df->d_http2Stats.d_nbOtherResponses;
+        break;
+      }
+
+      if (!responseBody.empty()) {
+        /* a body has been provided, send it as plain text */
+        NGHTTP2Headers::addDynamicHeader(headers, "content-type-name", "text/plain; charset=utf-8");
+      }
+      else {
+        /* no body: use a default message matching the status code.
+           Note that no content type is added in that case. */
+        static const std::string invalid{"invalid DNS query"};
+        static const std::string notAllowed{"dns query not allowed"};
+        static const std::string noDownstream{"no downstream server available"};
+        static const std::string internalServerError{"Internal Server Error"};
+
+        switch (responseCode) {
+        case 400:
+          responseBody.insert(responseBody.begin(), invalid.begin(), invalid.end());
+          break;
+        case 403:
+          responseBody.insert(responseBody.begin(), notAllowed.begin(), notAllowed.end());
+          break;
+        case 502:
+          responseBody.insert(responseBody.begin(), noDownstream.begin(), noDownstream.end());
+          break;
+        case 500:
+          /* fall-through */
+        default:
+          responseBody.insert(responseBody.begin(), internalServerError.begin(), internalServerError.end());
+          break;
+        }
+      }
+    }
+  }
+
+  /* computed last, since the body might have been altered above */
+  const std::string contentLength = std::to_string(responseBody.size());
+  NGHTTP2Headers::addDynamicHeader(headers, "content-length-name", contentLength);
+
+  for (const auto& [key, value] : customResponseHeaders) {
+    NGHTTP2Headers::addCustomDynamicHeader(headers, key, value);
+  }
+
+  auto ret = nghttp2_submit_response(d_session.get(), streamID, headers.data(), headers.size(), &data_provider);
+  if (ret != 0) {
+    d_currentStreams.erase(streamID);
+    vinfolog("Error submitting HTTP response for stream %d: %s", streamID, nghttp2_strerror(ret));
+    return false;
+  }
+
+  ret = nghttp2_session_send(d_session.get());
+  if (ret != 0) {
+    d_currentStreams.erase(streamID);
+    vinfolog("Error flushing HTTP response for stream %d: %s", streamID, nghttp2_strerror(ret));
+    return false;
+  }
+
+  return true;
+}
+
+static void processForwardedForHeader(const std::unique_ptr<HeadersMap>& headers, ComboAddress& remote)
+{
+  if (!headers) {
+    return;
+  }
+
+  auto it = headers->find(s_xForwardedForHeaderName);
+  if (it == headers->end()) {
+    return;
+  }
+
+  std::string_view value = it->second;
+  try {
+    auto pos = value.rfind(',');
+    if (pos != std::string_view::npos) {
+      ++pos;
+      for (; pos < value.size() && value[pos] == ' '; ++pos) {
+      }
+
+      if (pos < value.size()) {
+        value = value.substr(pos);
+      }
+    }
+    auto newRemote = ComboAddress(std::string(value));
+    remote = newRemote;
+  }
+  catch (const std::exception& e) {
+    vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.what());
+  }
+  catch (const PDNSException& e) {
+    vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.reason);
+  }
+}
+
+static std::optional<PacketBuffer> getPayloadFromPath(const std::string_view& path)
+{
+  std::optional<PacketBuffer> result{std::nullopt};
+
+  if (path.size() <= 5) {
+    return result;
+  }
+
+  auto pos = path.find("?dns=");
+  if (pos == string::npos) {
+    pos = path.find("&dns=");
+  }
+
+  if (pos == string::npos) {
+    return result;
+  }
+
+  // need to base64url decode this
+  string sdns(path.substr(pos + 5));
+  boost::replace_all(sdns, "-", "+");
+  boost::replace_all(sdns, "_", "/");
+
+  // re-add padding that may have been missing
+  switch (sdns.size() % 4) {
+  case 2:
+    sdns.append(2, '=');
+    break;
+  case 3:
+    sdns.append(1, '=');
+    break;
+  }
+
+  PacketBuffer decoded;
+  /* rough estimate so we hopefully don't need a new allocation later */
+  /* We reserve at few additional bytes to be able to add EDNS later */
+  const size_t estimate = ((sdns.size() * 3) / 4);
+  decoded.reserve(estimate);
+  if (B64Decode(sdns, decoded) < 0) {
+    return result;
+  }
+
+  result = std::move(decoded);
+  return result;
+}
+
+/* Processes a complete HTTP query received on a stream. In order:
+   - applies the ACL to the address taken from X-Forwarded-For, when that header
+     is trusted by the frontend (403 if not allowed);
+   - checks the path against the configured ones (404 if none matches);
+   - answers directly if the path matches an entry of the custom responses map;
+   - for GET queries without a body, decodes the payload from the query string
+     (400 if that fails);
+   - updates the GET / POST counters and passes the payload to handleQuery(),
+     sending an immediate response for the results that call for one.
+   Immediate responses are sent via sendResponse(), which reads the body from the
+   buffer of the stream; the as-yet-unverified assumption is that 'query' refers
+   to the entry stored for 'streamID' in the map of current streams. */
+void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::PendingQuery&& query, IncomingHTTP2Connection::StreamID streamID)
+{
+  /* sends 'code' right away, the body being 'response' if it is not empty
+     and 'reason' otherwise */
+  const auto handleImmediateResponse = [this, &query, streamID](uint16_t code, const std::string& reason, PacketBuffer&& response = PacketBuffer()) {
+    if (response.empty()) {
+      query.d_buffer.clear();
+      query.d_buffer.insert(query.d_buffer.begin(), reason.begin(), reason.end());
+    }
+    else {
+      query.d_buffer = std::move(response);
+    }
+    vinfolog("Sending an immediate %d response to incoming DoH query: %s", code, reason);
+    sendResponse(streamID, code, d_ci.cs->dohFrontend->d_customResponseHeaders);
+  };
+
+  ++d_ci.cs->dohFrontend->d_http2Stats.d_nbQueries;
+
+  if (d_ci.cs->dohFrontend->d_trustForwardedForHeader) {
+    processForwardedForHeader(query.d_headers, d_proxiedRemote);
+
+    /* second ACL lookup based on the updated address */
+    auto& holders = d_threadData.holders;
+    if (!holders.acl->match(d_proxiedRemote)) {
+      ++dnsdist::metrics::g_stats.aclDrops;
+      vinfolog("Query from %s (%s) (DoH) dropped because of ACL", d_ci.remote.toStringWithPort(), d_proxiedRemote.toStringWithPort());
+      handleImmediateResponse(403, "DoH query not allowed because of ACL");
+      return;
+    }
+
+    /* the headers are no longer needed unless the frontend asked to keep them */
+    if (!d_ci.cs->dohFrontend->d_keepIncomingHeaders) {
+      query.d_headers.reset();
+    }
+  }
+
+  if (d_ci.cs->dohFrontend->d_exactPathMatching) {
+    /* the path has to be exactly one of the configured ones */
+    if (d_ci.cs->dohFrontend->d_urls.count(query.d_path) == 0) {
+      handleImmediateResponse(404, "there is no endpoint configured for this path");
+      return;
+    }
+  }
+  else {
+    /* the path only has to start with one of the configured ones */
+    bool found = false;
+    for (const auto& path : d_ci.cs->dohFrontend->d_urls) {
+      if (boost::starts_with(query.d_path, path)) {
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      handleImmediateResponse(404, "there is no endpoint configured for this path");
+      return;
+    }
+  }
+
+  /* the responses map can be updated at runtime, so we need to take a copy of
+     the shared pointer, increasing the reference counter */
+  auto responsesMap = d_ci.cs->dohFrontend->d_responsesMap;
+  if (responsesMap) {
+    for (const auto& entry : *responsesMap) {
+      if (entry->matches(query.d_path)) {
+        const auto& customHeaders = entry->getHeaders();
+        query.d_buffer = entry->getContent();
+        if (entry->getStatusCode() >= 400 && query.d_buffer.size() >= 1) {
+          // legacy trailing 0 from the h2o era
+          query.d_buffer.pop_back();
+        }
+
+        /* the headers of the entry, if any, replace the ones of the frontend,
+           and no content type is added */
+        sendResponse(streamID, entry->getStatusCode(), customHeaders ? *customHeaders : d_ci.cs->dohFrontend->d_customResponseHeaders, std::string(), false);
+        return;
+      }
+    }
+  }
+
+  /* GET query without a body: the payload is in the query string */
+  if (query.d_buffer.empty() && query.d_method == PendingQuery::Method::Get && !query.d_queryString.empty()) {
+    auto payload = getPayloadFromPath(query.d_queryString);
+    if (payload) {
+      query.d_buffer = std::move(*payload);
+    }
+    else {
+      ++d_ci.cs->dohFrontend->d_badrequests;
+      handleImmediateResponse(400, "DoH unable to decode BASE64-URL");
+      return;
+    }
+  }
+
+  if (query.d_method == PendingQuery::Method::Get) {
+    ++d_ci.cs->dohFrontend->d_getqueries;
+  }
+  else if (query.d_method == PendingQuery::Method::Post) {
+    ++d_ci.cs->dohFrontend->d_postqueries;
+  }
+
+  try {
+    struct timeval now;
+    gettimeofday(&now, nullptr);
+    auto processingResult = handleQuery(std::move(query.d_buffer), now, streamID);
+
+    switch (processingResult) {
+    case QueryProcessingResult::TooSmall:
+      handleImmediateResponse(400, "DoH non-compliant query");
+      break;
+    case QueryProcessingResult::InvalidHeaders:
+      handleImmediateResponse(400, "DoH invalid headers");
+      break;
+    case QueryProcessingResult::Empty:
+      /* NOTE(review): query.d_buffer has been passed to handleQuery() via std::move
+         above, and is then passed as 'response', which the lambda assigns back
+         to query.d_buffer - confirm what the body is expected to be here */
+      handleImmediateResponse(200, "DoH empty query", std::move(query.d_buffer));
+      break;
+    case QueryProcessingResult::Dropped:
+      handleImmediateResponse(403, "DoH dropped query");
+      break;
+    case QueryProcessingResult::NoBackend:
+      handleImmediateResponse(502, "DoH no backend available");
+      return;
+    case QueryProcessingResult::Forwarded:
+    case QueryProcessingResult::Asynchronous:
+    case QueryProcessingResult::SelfAnswered:
+      /* nothing to send from here for these */
+      break;
+    }
+  }
+  catch (const std::exception& e) {
+    vinfolog("Exception while processing DoH query: %s", e.what());
+    handleImmediateResponse(400, "DoH non-compliant query");
+    return;
+  }
+}
+
+/* nghttp2 callback invoked when a frame has been received, 'user_data' being the
+   connection:
+   - GOAWAY: IO is stopped, and a write event is registered if the connection is
+     idle and nghttp2 still has something to write;
+   - HEADERS or DATA frame with the END_STREAM flag: the query is complete and
+     gets processed.
+   Returns 0, or NGHTTP2_ERR_CALLBACK_FAILURE if a complete query is received
+   for a stream we do not know about. */
+int IncomingHTTP2Connection::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
+{
+  IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+  /* debugging output, compiled out */
+#if 0
+  switch (frame->hd.type) {
+  case NGHTTP2_HEADERS:
+    cerr<<"got headers"<<endl;
+    if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) {
+      cerr<<"All headers received"<<endl;
+    }
+    if (frame->headers.cat == NGHTTP2_HCAT_REQUEST) {
+      cerr<<"All headers received - query"<<endl;
+    }
+    break;
+  case NGHTTP2_WINDOW_UPDATE:
+    cerr<<"got window update"<<endl;
+    break;
+  case NGHTTP2_SETTINGS:
+    cerr<<"got settings"<<endl;
+    cerr<<frame->settings.niv<<endl;
+    for (size_t idx = 0; idx < frame->settings.niv; idx++) {
+      cerr<<"- "<<frame->settings.iv[idx].settings_id<<" "<<frame->settings.iv[idx].value<<endl;
+    }
+    break;
+  case NGHTTP2_DATA:
+    cerr<<"got data"<<endl;
+    break;
+  }
+#endif
+
+  if (frame->hd.type == NGHTTP2_GOAWAY) {
+    conn->stopIO();
+    if (conn->isIdle()) {
+      if (nghttp2_session_want_write(conn->d_session.get())) {
+        /* NOTE(review): handleIO() registers its callbacks with a shared pointer to the
+           connection, whereas a raw pointer is passed here - confirm that the callback
+           accepts both */
+        conn->d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, conn, boost::none);
+      }
+    }
+  }
+
+  /* is this the last frame for this stream? */
+  else if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) {
+    auto streamID = frame->hd.stream_id;
+    auto stream = conn->d_currentStreams.find(streamID);
+    if (stream != conn->d_currentStreams.end()) {
+      /* the query is passed by reference, the entry stays in the map */
+      conn->handleIncomingQuery(std::move(stream->second), streamID);
+
+      if (conn->isIdle()) {
+        conn->watchForRemoteHostClosingConnection();
+      }
+    }
+    else {
+      vinfolog("Stream %d NOT FOUND", streamID);
+      return NGHTTP2_ERR_CALLBACK_FAILURE;
+    }
+  }
+
+  return 0;
+}
+
+/* nghttp2 callback invoked when a stream is closed.
+   A stream closed without error (error_code == 0) is left untouched by this
+   callback. A stream closed with an error is removed from d_currentStreams,
+   discarding its pending query, and the connection goes back to watching for
+   the remote end closing it if no stream is left. Always returns 0. */
+int IncomingHTTP2Connection::on_stream_close_callback(nghttp2_session* session, IncomingHTTP2Connection::StreamID stream_id, uint32_t error_code, void* user_data)
+{
+  auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+
+  if (error_code == 0) {
+    return 0;
+  }
+
+  auto stream = conn->d_currentStreams.find(stream_id);
+  if (stream == conn->d_currentStreams.end()) {
+    /* we don't care, then */
+    return 0;
+  }
+
+  /* erase via the iterator we already have, no need for a second lookup */
+  conn->d_currentStreams.erase(stream);
+
+  if (conn->isIdle()) {
+    conn->watchForRemoteHostClosingConnection();
+  }
+
+  return 0;
+}
+
+/* nghttp2 callback invoked when the reception of a header block starts.
+   For a request HEADERS frame, creates an empty PendingQuery for the stream in
+   d_currentStreams. If an entry already exists for that stream ID, the
+   connection is marked as dead and the session is terminated.
+   Returns 0, or NGHTTP2_ERR_CALLBACK_FAILURE if flushing the session
+   termination fails. */
+int IncomingHTTP2Connection::on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
+{
+  if (frame->hd.type != NGHTTP2_HEADERS || frame->headers.cat != NGHTTP2_HCAT_REQUEST) {
+    return 0;
+  }
+
+  auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+  auto insertPair = conn->d_currentStreams.insert({frame->hd.stream_id, PendingQuery()});
+  if (!insertPair.second) {
+    /* there is a stream ID collision, something is very wrong! */
+    /* the remote address is a string, hence %s (consistent with the message below) */
+    vinfolog("Stream ID collision (%d) on connection from %s", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort());
+    conn->d_connectionDied = true;
+    nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
+    auto ret = nghttp2_session_send(conn->d_session.get());
+    if (ret != 0) {
+      vinfolog("Error flushing HTTP response for stream %d from %s: %s", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret));
+      return NGHTTP2_ERR_CALLBACK_FAILURE;
+    }
+  }
+
+  return 0;
+}
+
+/* Returns the length of the path component of 'path': the offset of the first
+   '?' (where the query string starts), or the full length of 'path' when it
+   does not contain any '?'. */
+static std::string::size_type getLengthOfPathWithoutParameters(const std::string_view& path)
+{
+  const auto pos = path.find('?');
+  return pos == std::string_view::npos ? path.size() : pos;
+}
+
+/* nghttp2 callback invoked for every header name/value pair received.
+   Only headers belonging to a request HEADERS frame are processed:
+   - the name is validated, as well as the value, path and method when the
+     corresponding nghttp2 helpers were detected at build time;
+   - the path, query string, authority, scheme and method are recorded in the
+     PendingQuery of the stream;
+   - the header is additionally stored in d_headers if the frontend keeps the
+     incoming headers, or if it trusts the forwarded-for header and this header
+     matches s_xForwardedForHeaderName.
+   Returns NGHTTP2_ERR_CALLBACK_FAILURE on a validation failure, an unknown
+   stream or an unsupported method, and 0 otherwise. */
+int IncomingHTTP2Connection::on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t nameLen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data)
+{
+  IncomingHTTP2Connection* conn = reinterpret_cast<IncomingHTTP2Connection*>(user_data);
+
+  if (frame->hd.type == NGHTTP2_HEADERS && frame->headers.cat == NGHTTP2_HCAT_REQUEST) {
+    if (nghttp2_check_header_name(name, nameLen) == 0) {
+      vinfolog("Invalid header name");
+      return NGHTTP2_ERR_CALLBACK_FAILURE;
+    }
+
+#if HAVE_NGHTTP2_CHECK_HEADER_VALUE_RFC9113
+    if (nghttp2_check_header_value_rfc9113(value, valuelen) == 0) {
+      vinfolog("Invalid header value");
+      return NGHTTP2_ERR_CALLBACK_FAILURE;
+    }
+#endif /* HAVE_NGHTTP2_CHECK_HEADER_VALUE_RFC9113 */
+
+    /* exact, byte-wise comparison of the received header name against 'expected' */
+    auto headerMatches = [name, nameLen](const std::string& expected) -> bool {
+      return nameLen == expected.size() && memcmp(name, expected.data(), expected.size()) == 0;
+    };
+
+    auto stream = conn->d_currentStreams.find(frame->hd.stream_id);
+    if (stream == conn->d_currentStreams.end()) {
+      vinfolog("Unable to match the stream ID %d to a known one!", frame->hd.stream_id);
+      return NGHTTP2_ERR_CALLBACK_FAILURE;
+    }
+    auto& query = stream->second;
+    /* non-owning view over the value, copied into the query members below */
+    auto valueView = std::string_view(reinterpret_cast<const char*>(value), valuelen);
+    if (headerMatches(s_pathHeaderName)) {
+#if HAVE_NGHTTP2_CHECK_PATH
+      if (nghttp2_check_path(value, valuelen) == 0) {
+        vinfolog("Invalid path value");
+        return NGHTTP2_ERR_CALLBACK_FAILURE;
+      }
+#endif /* HAVE_NGHTTP2_CHECK_PATH */
+
+      /* split the value at the first '?': d_path gets what comes before it,
+         d_queryString what comes after, including the '?' itself */
+      auto pathLen = getLengthOfPathWithoutParameters(valueView);
+      query.d_path = valueView.substr(0, pathLen);
+      if (pathLen < valueView.size()) {
+        query.d_queryString = valueView.substr(pathLen);
+      }
+    }
+    else if (headerMatches(s_authorityHeaderName)) {
+      query.d_host = valueView;
+    }
+    else if (headerMatches(s_schemeHeaderName)) {
+      query.d_scheme = valueView;
+    }
+    else if (headerMatches(s_methodHeaderName)) {
+#if HAVE_NGHTTP2_CHECK_METHOD
+      if (nghttp2_check_method(value, valuelen) == 0) {
+        vinfolog("Invalid method value");
+        return NGHTTP2_ERR_CALLBACK_FAILURE;
+      }
+#endif /* HAVE_NGHTTP2_CHECK_METHOD */
+
+      /* only GET and POST are accepted, anything else fails the callback */
+      if (valueView == "GET") {
+        query.d_method = PendingQuery::Method::Get;
+      }
+      else if (valueView == "POST") {
+        query.d_method = PendingQuery::Method::Post;
+      }
+      else {
+        vinfolog("Unsupported method value");
+        return NGHTTP2_ERR_CALLBACK_FAILURE;
+      }
+    }
+
+    /* this applies to every header, including the ones handled above; the
+       headers map is only allocated once there is something to store */
+    if (conn->d_ci.cs->dohFrontend->d_keepIncomingHeaders || (conn->d_ci.cs->dohFrontend->d_trustForwardedForHeader && headerMatches(s_xForwardedForHeaderName))) {
+      if (!query.d_headers) {
+        query.d_headers = std::make_unique<HeadersMap>();
+      }
+      query.d_headers->insert({std::string(reinterpret_cast<const char*>(name), nameLen), std::string(valueView)});
+    }
+  }
+  return 0;
+}
+
+/* nghttp2 callback invoked for every chunk of DATA received on a stream.
+   The chunk is appended to the query buffer of the stream, as long as the
+   resulting buffer does not exceed the largest value of a uint16_t.
+   Returns 0 on success, and NGHTTP2_ERR_CALLBACK_FAILURE if the stream is
+   unknown or the data would not fit. */
+int IncomingHTTP2Connection::on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, IncomingHTTP2Connection::StreamID stream_id, const uint8_t* data, size_t len, void* user_data)
+{
+  auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+  const auto streamIt = conn->d_currentStreams.find(stream_id);
+  if (streamIt == conn->d_currentStreams.end()) {
+    vinfolog("Unable to match the stream ID %d to a known one!", stream_id);
+    return NGHTTP2_ERR_CALLBACK_FAILURE;
+  }
+
+  auto& queryBuffer = streamIt->second.d_buffer;
+  constexpr size_t maxQuerySize = std::numeric_limits<uint16_t>::max();
+  /* checking the chunk on its own first keeps the subtraction meaningful */
+  const bool tooLarge = len > maxQuerySize || (maxQuerySize - queryBuffer.size()) < len;
+  if (tooLarge) {
+    vinfolog("Data frame of size %d is too large for a DNS query (we already have %d)", len, queryBuffer.size());
+    return NGHTTP2_ERR_CALLBACK_FAILURE;
+  }
+
+  queryBuffer.insert(queryBuffer.end(), data, data + len);
+
+  return 0;
+}
+
+/* nghttp2 callback invoked when the library reports an error.
+   Logs the message, marks the connection as dead, terminates the session and
+   tries to flush what nghttp2 still has to send.
+   Returns 0, or NGHTTP2_ERR_CALLBACK_FAILURE if that flush fails. */
+int IncomingHTTP2Connection::on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data)
+{
+  auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+
+  /* the remote address is a string, hence %s (consistent with the message below) */
+  vinfolog("Error in HTTP/2 connection from %s: %s", conn->d_ci.remote.toStringWithPort(), std::string(msg, len));
+  conn->d_connectionDied = true;
+  nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
+  auto ret = nghttp2_session_send(conn->d_session.get());
+  if (ret != 0) {
+    vinfolog("Error flushing HTTP response on connection from %s: %s", conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret));
+    return NGHTTP2_ERR_CALLBACK_FAILURE;
+  }
+
+  return 0;
+}
+
+/* Reads data from the connection and feeds it to nghttp2.
+   Each iteration grows d_in by 512 bytes, reads into it, shrinks it to the
+   number of bytes received, and passes these bytes to
+   nghttp2_session_mem_recv(). The loop goes on while there are streams in
+   progress, and stops earlier when the read cannot complete or on error.
+   Any exception is logged and handled via handleIOError(). */
+void IncomingHTTP2Connection::readHTTPData()
+{
+  /* NOTE(review): presumably resets the I/O state on scope exit unless
+     release() has been called - confirm against IOStateGuard */
+  IOStateGuard ioGuard(d_ioState);
+  do {
+    size_t got = 0;
+    d_in.resize(d_in.size() + 512);
+    try {
+      /* NOTE(review): the last argument presumably allows a partial read - confirm against tryRead() */
+      IOState newState = d_handler.tryRead(d_in, got, d_in.size(), true);
+      d_in.resize(got);
+
+      if (got > 0) {
+        /* we got something */
+        auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size());
+        /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB,
+           all data should be consumed before returning */
+        if (readlen < 0 || static_cast<size_t>(readlen) < d_in.size()) {
+          /* NOTE(review): for a short (non-negative) read, readlen is not an
+             nghttp2 error code, so the text from nghttp2_strerror() may be misleading */
+          throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen)));
+        }
+
+        /* NOTE(review): the return value is ignored here, while other callers in this file check it */
+        nghttp2_session_send(d_session.get());
+      }
+
+      if (newState == IOState::Done) {
+        if (isIdle()) {
+          watchForRemoteHostClosingConnection();
+          ioGuard.release();
+          break;
+        }
+        /* not idle: fall through to the loop condition and read again */
+      }
+      else {
+        if (newState == IOState::NeedWrite) {
+          updateIO(IOState::NeedWrite, handleReadableIOCallback);
+        }
+        ioGuard.release();
+        break;
+      }
+    }
+    catch (const std::exception& e) {
+      /* NOTE(review): the message mentions a "backend" connection although this
+         class handles incoming connections */
+      vinfolog("Exception while trying to read from HTTP backend connection: %s", e.what());
+      handleIOError();
+      break;
+    }
+  } while (getConcurrentStreamsCount() > 0);
+}
+
+/* Multiplexer callback for read events: extracts the shared pointer to the
+   connection stored in 'param' and runs handleIO() on it. */
+void IncomingHTTP2Connection::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  /* the temporary shared pointer lives until handleIO() has returned */
+  boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param)->handleIO();
+}
+
+/* Multiplexer callback for write events: tries to send what is left in d_out.
+   - write complete: the output buffer is reset, then the connection either
+     watches for the remote end closing it (idle) or waits for more data to read;
+   - the write needs a read first: this callback is registered again for reading.
+   Any exception is logged and handled via handleIOError(). */
+void IncomingHTTP2Connection::handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto self = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
+  IOStateGuard guard(self->d_ioState);
+
+  try {
+    const IOState state = self->d_handler.tryWrite(self->d_out, self->d_outPos, self->d_out.size());
+    if (state == IOState::Done) {
+      self->d_out.clear();
+      self->d_outPos = 0;
+      if (self->isIdle()) {
+        self->watchForRemoteHostClosingConnection();
+      }
+      else {
+        self->updateIO(IOState::NeedRead, handleReadableIOCallback);
+      }
+    }
+    else if (state == IOState::NeedRead) {
+      self->updateIO(IOState::NeedRead, handleWritableIOCallback);
+    }
+    guard.release();
+  }
+  catch (const std::exception& e) {
+    vinfolog("Exception while trying to write (ready) to HTTP backend connection: %s", e.what());
+    self->handleIOError();
+  }
+}
+
+/* Returns true when no stream is currently tracked on this connection. */
+bool IncomingHTTP2Connection::isIdle() const
+{
+  const auto activeStreams = getConcurrentStreamsCount();
+  return activeStreams == 0;
+}
+
+/* Resets the I/O state handler of this connection.
+   NOTE(review): presumably this removes the descriptor from the multiplexer - confirm
+   against IOStateHandler::reset(). */
+void IncomingHTTP2Connection::stopIO()
+{
+  d_ioState->reset();
+}
+
+/* Returns the number of entries in d_currentStreams, narrowed to 32 bits. */
+uint32_t IncomingHTTP2Connection::getConcurrentStreamsCount() const
+{
+  return static_cast<uint32_t>(d_currentStreams.size());
+}
+
+/* Computes the deadline to use while waiting for data on an idle connection.
+   Returns boost::none if neither g_maxTCPConnectionDuration nor the idle
+   timeout of the frontend is set. Otherwise returns, based on 'now':
+   - 'now' itself if the maximum connection duration has already been reached;
+   - the end of the maximum connection duration, if there is no idle timeout
+     or if that end comes no later than the idle timeout;
+   - 'now' plus the idle timeout in every other case. */
+boost::optional<struct timeval> IncomingHTTP2Connection::getIdleClientReadTTD(struct timeval now) const
+{
+  const auto idleTimeout = d_ci.cs->dohFrontend->d_idleTimeout;
+  if (g_maxTCPConnectionDuration == 0 && idleTimeout == 0) {
+    return boost::none;
+  }
+
+  if (g_maxTCPConnectionDuration > 0) {
+    const auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
+    const bool expired = elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration);
+    if (expired) {
+      return now;
+    }
+
+    const auto remaining = g_maxTCPConnectionDuration - elapsed;
+    const bool durationLimitComesFirst = idleTimeout == 0 || remaining <= static_cast<size_t>(idleTimeout);
+    if (durationLimitComesFirst) {
+      now.tv_sec += remaining;
+      return now;
+    }
+  }
+
+  now.tv_sec += idleTimeout;
+  return now;
+}
+
+/* Updates the I/O state of this connection to 'newState' with 'callback',
+   passing a shared pointer to this connection as the callback parameter.
+   The deadline is the idle or regular client read deadline for NeedRead
+   (depending on isIdle()), and the client write deadline for NeedWrite.
+   Any other state is ignored. */
+void IncomingHTTP2Connection::updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback)
+{
+  auto self = std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this());
+  if (!self) {
+    return;
+  }
+
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+
+  boost::optional<struct timeval> deadline{boost::none};
+  if (newState == IOState::NeedRead) {
+    if (isIdle()) {
+      deadline = getIdleClientReadTTD(now);
+    }
+    else {
+      deadline = getClientReadTTD(now);
+    }
+  }
+  else if (newState == IOState::NeedWrite) {
+    deadline = getClientWriteTTD(now);
+  }
+  else {
+    /* nothing to register for the other states */
+    return;
+  }
+
+  d_ioState->update(newState, callback, self, deadline);
+}
+
+/* Waits for the connection to become readable, with handleReadableIOCallback()
+   as the handler. updateIO() picks the idle read deadline when no stream is
+   in progress. */
+void IncomingHTTP2Connection::watchForRemoteHostClosingConnection()
+{
+  updateIO(IOState::NeedRead, handleReadableIOCallback);
+}
+
+/* Tears the connection down after an I/O error: marks it as dead, asks nghttp2
+   to terminate the session with a protocol error, drops all the pending
+   streams and stops the I/O. */
+void IncomingHTTP2Connection::handleIOError()
+{
+  d_connectionDied = true;
+  /* NOTE(review): the return value is ignored, and nothing is flushed to the client from here */
+  nghttp2_session_terminate_session(d_session.get(), NGHTTP2_PROTOCOL_ERROR);
+  d_currentStreams.clear();
+  stopIO();
+}
+#endif /* HAVE_NGHTTP2 */
diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.hh b/pdns/dnsdistdist/dnsdist-nghttp2-in.hh
new file mode 100644 (file)
index 0000000..3ee1c96
--- /dev/null
@@ -0,0 +1,114 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "config.h"
+#ifdef HAVE_NGHTTP2
+#include <nghttp2/nghttp2.h>
+
+#include "dnsdist-tcp-upstream.hh"
+
+/* State of an incoming HTTP/2 (DoH) client connection handled via nghttp2.
+   Derives from IncomingTCPConnectionState and keeps one PendingQuery per
+   HTTP/2 stream in progress. */
+class IncomingHTTP2Connection : public IncomingTCPConnectionState
+{
+public:
+  /* HTTP/2 stream identifier */
+  using StreamID = int32_t;
+
+  /* What has been collected so far for the query carried by a stream */
+  class PendingQuery
+  {
+  public:
+    /* HTTP method of the request, Unknown until the method header has been seen */
+    enum class Method : uint8_t
+    {
+      Unknown,
+      Get,
+      Post
+    };
+
+    /* body received from the client, filled chunk by chunk */
+    PacketBuffer d_buffer;
+    PacketBuffer d_response;
+    /* path without the query string */
+    std::string d_path;
+    std::string d_scheme;
+    /* value of the authority header */
+    std::string d_host;
+    /* query string, including the leading '?' */
+    std::string d_queryString;
+    std::string d_sni;
+    std::string d_contentTypeOut;
+    /* incoming headers, only allocated when some have to be kept */
+    std::unique_ptr<HeadersMap> d_headers;
+    size_t d_queryPos{0};
+    uint32_t d_statusCode{0};
+    Method d_method{Method::Unknown};
+  };
+
+  IncomingHTTP2Connection(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now);
+  ~IncomingHTTP2Connection() = default;
+  void handleIO() override;
+  void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+  void notifyIOError(const struct timeval& now, TCPResponse&& response) override;
+  /* NOTE(review): takes a uint32_t stream ID while StreamID is an int32_t */
+  void restoreContext(uint32_t streamID, PendingQuery&& context);
+
+private:
+  /* callbacks registered with nghttp2, user_data being the connection */
+  static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data);
+  static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data);
+  static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, StreamID stream_id, const uint8_t* data, size_t len, void* user_data);
+  static int on_stream_close_callback(nghttp2_session* session, StreamID stream_id, uint32_t error_code, void* user_data);
+  static int on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data);
+  static int on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data);
+  static int on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data);
+  /* multiplexer callbacks, the parameter holding a shared pointer to the connection */
+  static void handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+  static void handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+
+  IOState sendResponse(const struct timeval& now, TCPResponse&& response) override;
+  bool forwardViaUDPFirst() const override
+  {
+    return true;
+  }
+  void restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&&) override;
+  std::unique_ptr<DOHUnitInterface> getDOHUnit(uint32_t streamID) override;
+
+  void stopIO();
+  /* true when no stream is in progress */
+  bool isIdle() const;
+  uint32_t getConcurrentStreamsCount() const;
+  void updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback);
+  void watchForRemoteHostClosingConnection();
+  void handleIOError();
+  bool sendResponse(StreamID streamID, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType = "", bool addContentType = true);
+  void handleIncomingQuery(PendingQuery&& query, StreamID streamID);
+  bool checkALPN();
+  void readHTTPData();
+  void handleConnectionReady();
+  boost::optional<struct timeval> getIdleClientReadTTD(struct timeval now) const;
+
+  /* nghttp2 session, released with nghttp2_session_del() */
+  std::unique_ptr<nghttp2_session, decltype(&nghttp2_session_del)> d_session{nullptr, nghttp2_session_del};
+  /* streams in progress, indexed by stream ID */
+  std::unordered_map<StreamID, PendingQuery> d_currentStreams;
+  /* data waiting to be written to the client, d_outPos being the current write position */
+  PacketBuffer d_out;
+  /* data read from the client, before it is passed to nghttp2 */
+  PacketBuffer d_in;
+  size_t d_outPos{0};
+  /* set once the connection has been terminated because of an error */
+  bool d_connectionDied{false};
+};
+
+/* Static helpers appending a header to a vector of nghttp2_nv.
+   NOTE(review): presumably 'nameKey' and 'valueKey' are keys into a table of
+   constant header names and values, and the value passed to the "dynamic"
+   variants has to outlive the resulting nghttp2_nv - confirm against the
+   implementation. */
+class NGHTTP2Headers
+{
+public:
+  static void addStaticHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& valueKey);
+  static void addDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string_view& value);
+  static void addCustomDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& name, const std::string_view& value);
+};
+
+#endif /* HAVE_NGHTTP2 */
index 39e60009e05d78552351054412600c7b1adc684f..692b73275771a32f3007ef235ca55d29d70a2911 100644 (file)
@@ -27,6 +27,7 @@
 #endif /* HAVE_NGHTTP2 */
 
 #include "dnsdist-nghttp2.hh"
+#include "dnsdist-nghttp2-in.hh"
 #include "dnsdist-tcp.hh"
 #include "dnsdist-tcp-downstream.hh"
 #include "dnsdist-downstream-connection.hh"
@@ -153,7 +154,11 @@ void DoHConnectionToBackend::handleResponse(PendingRequest&& request)
       }
     }
 
-    request.d_sender->handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this(), d_ds));
+    TCPResponse response(std::move(request.d_query));
+    response.d_buffer = std::move(request.d_buffer);
+    response.d_connection = shared_from_this();
+    response.d_ds = d_ds;
+    request.d_sender->handleResponse(now, std::move(response));
   }
   catch (const std::exception& e) {
     vinfolog("Got exception while handling response for cross-protocol DoH: %s", e.what());
@@ -167,7 +172,8 @@ void DoHConnectionToBackend::handleResponseError(PendingRequest&& request, const
       d_ds->reportTimeoutOrError();
     }
 
-    request.d_sender->notifyIOError(std::move(request.d_query.d_idstate), now);
+    TCPResponse response(PacketBuffer(), std::move(request.d_query.d_idstate), nullptr, nullptr);
+    request.d_sender->notifyIOError(now, std::move(response));
   }
   catch (const std::exception& e) {
     vinfolog("Got exception while handling response for cross-protocol DoH: %s", e.what());
@@ -230,45 +236,6 @@ bool DoHConnectionToBackend::isIdle() const
   return getConcurrentStreamsCount() == 0;
 }
 
-const std::unordered_map<std::string, std::string> DoHConnectionToBackend::s_constants = {
-  {"method-name", ":method"},
-  {"method-value", "POST"},
-  {"scheme-name", ":scheme"},
-  {"scheme-value", "https"},
-  {"accept-name", "accept"},
-  {"accept-value", "application/dns-message"},
-  {"content-type-name", "content-type"},
-  {"content-type-value", "application/dns-message"},
-  {"user-agent-name", "user-agent"},
-  {"user-agent-value", "nghttp2-" NGHTTP2_VERSION "/dnsdist"},
-  {"authority-name", ":authority"},
-  {"path-name", ":path"},
-  {"content-length-name", "content-length"},
-  {"x-forwarded-for-name", "x-forwarded-for"},
-  {"x-forwarded-port-name", "x-forwarded-port"},
-  {"x-forwarded-proto-name", "x-forwarded-proto"},
-  {"x-forwarded-proto-value-dns-over-udp", "dns-over-udp"},
-  {"x-forwarded-proto-value-dns-over-tcp", "dns-over-tcp"},
-  {"x-forwarded-proto-value-dns-over-tls", "dns-over-tls"},
-  {"x-forwarded-proto-value-dns-over-http", "dns-over-http"},
-  {"x-forwarded-proto-value-dns-over-https", "dns-over-https"},
-};
-
-void DoHConnectionToBackend::addStaticHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& valueKey)
-{
-  const auto& name = s_constants.at(nameKey);
-  const auto& value = s_constants.at(valueKey);
-
-  headers.push_back({const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.c_str())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
-}
-
-void DoHConnectionToBackend::addDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& value)
-{
-  const auto& name = s_constants.at(nameKey);
-
-  headers.push_back({const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.c_str())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
-}
-
 void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query)
 {
   auto payloadSize = std::to_string(query.d_buffer.size());
@@ -284,37 +251,37 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender,
   headers.reserve(8 + (addXForwarded ? 3 : 0));
 
   /* Pseudo-headers need to come first (rfc7540 8.1.2.1) */
-  addStaticHeader(headers, "method-name", "method-value");
-  addStaticHeader(headers, "scheme-name", "scheme-value");
-  addDynamicHeader(headers, "authority-name", d_ds->d_config.d_tlsSubjectName);
-  addDynamicHeader(headers, "path-name", d_ds->d_config.d_dohPath);
-  addStaticHeader(headers, "accept-name", "accept-value");
-  addStaticHeader(headers, "content-type-name", "content-type-value");
-  addStaticHeader(headers, "user-agent-name", "user-agent-value");
-  addDynamicHeader(headers, "content-length-name", payloadSize);
+  NGHTTP2Headers::addStaticHeader(headers, "method-name", "method-value");
+  NGHTTP2Headers::addStaticHeader(headers, "scheme-name", "scheme-value");
+  NGHTTP2Headers::addDynamicHeader(headers, "authority-name", d_ds->d_config.d_tlsSubjectName);
+  NGHTTP2Headers::addDynamicHeader(headers, "path-name", d_ds->d_config.d_dohPath);
+  NGHTTP2Headers::addStaticHeader(headers, "accept-name", "accept-value");
+  NGHTTP2Headers::addStaticHeader(headers, "content-type-name", "content-type-value");
+  NGHTTP2Headers::addStaticHeader(headers, "user-agent-name", "user-agent-value");
+  NGHTTP2Headers::addDynamicHeader(headers, "content-length-name", payloadSize);
   /* no need to add these headers for health-check queries */
   if (addXForwarded && query.d_idstate.origRemote.getPort() != 0) {
     remote = query.d_idstate.origRemote.toString();
     remotePort = std::to_string(query.d_idstate.origRemote.getPort());
-    addDynamicHeader(headers, "x-forwarded-for-name", remote);
-    addDynamicHeader(headers, "x-forwarded-port-name", remotePort);
+    NGHTTP2Headers::addDynamicHeader(headers, "x-forwarded-for-name", remote);
+    NGHTTP2Headers::addDynamicHeader(headers, "x-forwarded-port-name", remotePort);
     if (query.d_idstate.cs != nullptr) {
       if (query.d_idstate.cs->isUDP()) {
-        addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-udp");
+        NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-udp");
       }
       else if (query.d_idstate.cs->isDoH()) {
         if (query.d_idstate.cs->hasTLS()) {
-          addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-https");
+          NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-https");
         }
         else {
-          addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-http");
+          NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-http");
         }
       }
       else if (query.d_idstate.cs->hasTLS()) {
-        addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tls");
+        NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tls");
       }
       else {
-        addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tcp");
+        NGHTTP2Headers::addStaticHeader(headers, "x-forwarded-proto-name", "x-forwarded-proto-value-dns-over-tcp");
       }
     }
   }
@@ -920,7 +887,8 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par
     downstream->queueQuery(tqs, std::move(query));
   }
   catch (...) {
-    tqs->notifyIOError(std::move(query.d_idstate), now);
+    TCPResponse response(std::move(query));
+    tqs->notifyIOError(now, std::move(response));
   }
 }
 
index 6c6fcf222902d0bafc5e1aba580b168d09a7d7a9..43de71fc58f7751cfc3950d949ed2a21473c2a3a 100644 (file)
@@ -173,7 +173,7 @@ static uint32_t getSerialFromRawSOAContent(const std::vector<uint8_t>& raw)
 static bool getSerialFromIXFRQuery(TCPQuery& query)
 {
   try {
-    size_t proxyPayloadSize = query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0;
+    size_t proxyPayloadSize = query.d_proxyProtocolPayloadAdded ? query.d_idstate.d_proxyProtocolPayloadSize : 0;
     if (query.d_buffer.size() <= (proxyPayloadSize + sizeof(uint16_t))) {
       return false;
     }
@@ -232,24 +232,24 @@ static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState quer
     if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) {
       query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end());
       query.d_proxyProtocolPayloadAdded = true;
-      query.d_proxyProtocolPayloadAddedSize = query.d_proxyProtocolPayload.size();
+      query.d_idstate.d_proxyProtocolPayloadSize = query.d_proxyProtocolPayload.size();
     }
   }
   else if (connectionState == ConnectionState::proxySent) {
     if (query.d_proxyProtocolPayloadAdded) {
-      if (query.d_buffer.size() < query.d_proxyProtocolPayloadAddedSize) {
+      if (query.d_buffer.size() < query.d_idstate.d_proxyProtocolPayloadSize) {
         throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size()));
       }
-      query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayloadAddedSize);
+      query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_idstate.d_proxyProtocolPayloadSize);
       query.d_proxyProtocolPayloadAdded = false;
-      query.d_proxyProtocolPayloadAddedSize = 0;
+      query.d_idstate.d_proxyProtocolPayloadSize = 0;
     }
   }
   if (query.d_idstate.qclass == QClass::IN && query.d_idstate.qtype == QType::IXFR) {
     getSerialFromIXFRQuery(query);
   }
 
-  editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0, true);
+  editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_idstate.d_proxyProtocolPayloadSize : 0, true);
 }
 
 IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
@@ -433,7 +433,8 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
                 /* this one can't be restarted, sorry */
                 DEBUGLOG("A XFR for which a response has already been sent cannot be restarted");
                 try {
-                  pending.second.d_sender->notifyIOError(std::move(pending.second.d_query.d_idstate), now);
+                  TCPResponse response(std::move(pending.second.d_query));
+                  pending.second.d_sender->notifyIOError(now, std::move(response));
                 }
                 catch (const std::exception& e) {
                   vinfolog("Got an exception while notifying: %s", e.what());
@@ -608,7 +609,8 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F
       increaseCounters(d_currentQuery.d_query.d_idstate.cs);
       auto sender = d_currentQuery.d_sender;
       if (sender->active()) {
-        sender->notifyIOError(std::move(d_currentQuery.d_query.d_idstate), now);
+        TCPResponse response(std::move(d_currentQuery.d_query));
+        sender->notifyIOError(now, std::move(response));
       }
     }
 
@@ -616,7 +618,8 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F
       increaseCounters(query.d_query.d_idstate.cs);
       auto sender = query.d_sender;
       if (sender->active()) {
-        sender->notifyIOError(std::move(query.d_query.d_idstate), now);
+        TCPResponse response(std::move(query.d_query));
+        sender->notifyIOError(now, std::move(response));
       }
     }
 
@@ -624,7 +627,8 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F
       increaseCounters(response.second.d_query.d_idstate.cs);
       auto sender = response.second.d_sender;
       if (sender->active()) {
-        sender->notifyIOError(std::move(response.second.d_query.d_idstate), now);
+        TCPResponse tresp(std::move(response.second.d_query));
+        sender->notifyIOError(now, std::move(tresp));
       }
     }
   }
@@ -726,7 +730,8 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
   if (sender->active()) {
     DEBUGLOG("passing response to client connection for "<<ids.qname);
     // make sure that we still exist after calling handleResponse()
-    sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn, conn->d_ds));
+    TCPResponse response(std::move(d_responseBuffer), std::move(ids), conn, conn->d_ds);
+    sender->handleResponse(now, std::move(response));
   }
 
   if (!d_pendingQueries.empty()) {
index b668c2f9eb7cad21c5001d9f72cc33d969c24382..4318892659f6233b8dfc172f9d93bba040e24601 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "dolog.hh"
 #include "dnsdist-tcp.hh"
+#include "dnsdist-tcp-downstream.hh"
 
 struct TCPCrossProtocolResponse;
 
@@ -26,7 +27,10 @@ public:
 class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
 {
 public:
-  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
+  enum class QueryProcessingResult : uint8_t { Forwarded, TooSmall, InvalidHeaders, Empty, Dropped, SelfAnswered, NoBackend, Asynchronous };
+  enum class ProxyProtocolResult : uint8_t { Reading, Done, Error };
+
+  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? d_ci.cs->dohFrontend->d_tlsContext.getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
   {
     d_origDest.reset();
     d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
@@ -46,7 +50,7 @@ public:
   IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
   IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
 
-  ~IncomingTCPConnectionState();
+  virtual ~IncomingTCPConnectionState();
 
   void resetForNewQuery();
 
@@ -118,24 +122,27 @@ public:
 
   static size_t clearAllDownstreamConnections();
 
-  static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& conn, const struct timeval& now);
   static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
   static void handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param);
   static void updateIO(std::shared_ptr<IncomingTCPConnectionState>& state, IOState newState, const struct timeval& now);
 
-  static IOState sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
   static void queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
-static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
+  static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
+
+  virtual void handleIO();
 
-  /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */
-  void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+  QueryProcessingResult handleQuery(PacketBuffer&& query, const struct timeval& now, std::optional<int32_t> streamID);
+  virtual void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+  virtual void notifyIOError(const struct timeval& now, TCPResponse&& response) override;
   void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override;
-  void notifyIOError(InternalQueryState&& query, const struct timeval& now) override;
 
+  virtual IOState sendResponse(const struct timeval& now, TCPResponse&& response);
+  void handleResponseSent(TCPResponse& currentResponse);
+  void handleHandshakeDone(const struct timeval& now);
+  ProxyProtocolResult handleProxyProtocolPayload();
   void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response);
 
   void terminateClientConnection();
-  void queueQuery(TCPQuery&& query);
 
   bool canAcceptNewQueries(const struct timeval& now);
 
@@ -143,6 +150,20 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
   {
     return d_ioState != nullptr;
   }
+  virtual bool forwardViaUDPFirst() const
+  {
+    return false;
+  }
+  virtual std::unique_ptr<DOHUnitInterface> getDOHUnit(uint32_t streamID)
+  {
+    throw std::runtime_error("Getting a DOHUnit state from a generic TCP/DoT connection is not supported");
+  }
+  virtual void restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&&)
+  {
+    throw std::runtime_error("Restoring a DOHUnit state to a generic TCP/DoT connection is not supported");
+  }
+
+  std::unique_ptr<CrossProtocolQuery> getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& ds);
 
   std::string toString() const
   {
@@ -151,6 +172,8 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
     return o.str();
   }
 
+  dnsdist::Protocol getProtocol() const;
+
   enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
 
   TCPResponse d_currentResponse;
index d5f2edb0d1e895bfa67bee48cbc48ab994fbd086..aef6cf6ec38af31cf0a95b59379d8cd00a994abc 100644 (file)
@@ -21,6 +21,7 @@
  */
 #pragma once
 
+#include <optional>
 #include <unistd.h>
 #include "channel.hh"
 #include "iputils.hh"
@@ -100,7 +101,6 @@ public:
   InternalQueryState d_idstate;
   std::string d_proxyProtocolPayload;
   PacketBuffer d_buffer;
-  uint32_t d_proxyProtocolPayloadAddedSize{0};
   uint32_t d_ixfrQuerySerial{0};
   uint32_t d_xfrMasterSerial{0};
   uint32_t d_xfrSerialCount{0};
@@ -133,6 +133,17 @@ struct TCPResponse : public TCPQuery
     }
   }
 
+  TCPResponse(TCPQuery&& query) :
+    TCPQuery(std::move(query))
+  {
+    if (d_buffer.size() >= sizeof(dnsheader)) {
+      memcpy(&d_cleartextDH, reinterpret_cast<const dnsheader*>(d_buffer.data()), sizeof(d_cleartextDH));
+    }
+    else {
+      memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
+    }
+  }
+
   bool isAsync() const
   {
     return d_async;
@@ -154,7 +165,7 @@ public:
   virtual bool active() const = 0;
   virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0;
   virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0;
-  virtual void notifyIOError(InternalQueryState&& query, const struct timeval& now) = 0;
+  virtual void notifyIOError(const struct timeval& now, TCPResponse&& response) = 0;
 
   /* whether the connection should be automatically released to the pool after handleResponse()
      has been called */
@@ -199,7 +210,6 @@ struct CrossProtocolQuery
 
   InternalQuery query;
   std::shared_ptr<DownstreamState> downstream{nullptr};
-  size_t proxyProtocolPayloadSize{0};
   bool d_isResponse{false};
 };
 
index 91dcd9ad767418df8fe93538bcd662783ec50da9..a2747400f47bf91e213cfa8effa38fbcac9159e2 100644 (file)
@@ -534,8 +534,9 @@ public:
     return handleResponse(now, std::move(response));
   }
 
-  void notifyIOError(InternalQueryState&& query, const struct timeval& now) override
+  void notifyIOError(const struct timeval& now, TCPResponse&& response) override
   {
+    auto& query = response.d_idstate;
     if (!query.du) {
       return;
     }
@@ -1041,7 +1042,7 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req)
     if (!holders.acl->match(remote)) {
       ++dnsdist::metrics::g_stats.aclDrops;
       vinfolog("Query from %s (DoH) dropped because of ACL", remote.toStringWithPort());
-      h2o_send_error_403(req, "Forbidden", "dns query not allowed because of ACL", 0);
+      h2o_send_error_403(req, "Forbidden", "DoH query not allowed because of ACL", 0);
       return 0;
     }
 
@@ -1344,6 +1345,13 @@ static void on_accept(h2o_socket_t *listener, const char *err)
     return;
   }
 
+  if (dsc->df->d_earlyACLDrop && !dsc->df->d_trustForwardedForHeader && !dsc->holders.acl->match(remote)) {
+    ++dnsdist::metrics::g_stats.aclDrops;
+      vinfolog("Dropping DoH connection from %s because of ACL", remote.toStringWithPort());
+      h2o_socket_close(sock);
+      return;
+    }
+
   if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) {
     vinfolog("Dropping DoH connection from %s because we have too many from this client already", remote.toStringWithPort());
     h2o_socket_close(sock);
index 876a21890f2ba3da91f5fc7784d1ee66c98883bf..baf9118e67296fc499f925d989e7a6f6f455de58 100644 (file)
@@ -1,7 +1,7 @@
 AC_DEFUN([DNSDIST_ENABLE_DNS_OVER_HTTPS], [
   AC_MSG_CHECKING([whether to enable incoming DNS over HTTPS (DoH) support])
   AC_ARG_ENABLE([dns-over-https],
-    AS_HELP_STRING([--enable-dns-over-https], [enable incoming DNS over HTTPS (DoH) support (requires libh2o) @<:@default=no@:>@]),
+    AS_HELP_STRING([--enable-dns-over-https], [enable incoming DNS over HTTPS (DoH) support (requires libh2o or nghttp2) @<:@default=no@:>@]),
     [enable_dns_over_https=$enableval],
     [enable_dns_over_https=no]
   )
index 8305b2b906737c33298e735981a18bde58b5bf7c..273385cf242d8ced19cc8832eb0f39d30a97edf5 100644 (file)
@@ -13,6 +13,13 @@ AC_DEFUN([PDNS_WITH_NGHTTP2], [
       PKG_CHECK_MODULES([NGHTTP2], [libnghttp2], [
         [HAVE_NGHTTP2=1]
         AC_DEFINE([HAVE_NGHTTP2], [1], [Define to 1 if you have nghttp2])
+        save_CFLAGS=$CFLAGS
+        save_LIBS=$LIBS
+        CFLAGS="$NGHTTP2_CFLAGS $CFLAGS"
+        LIBS="$NGHTTP2_LIBS $LIBS"
+        AC_CHECK_FUNCS([nghttp2_check_header_value_rfc9113 nghttp2_check_method nghttp2_check_path])
+        CFLAGS=$save_CFLAGS
+        LIBS=$save_LIBS
       ], [ : ])
     ])
   ])
index b1fbebca053b8b483e523f875ceebfa51614747b..65a9c4a53fba11cd2af6dc20aacc380c28793bf2 100644 (file)
@@ -44,7 +44,7 @@ public:
   {
   }
 
-  void notifyIOError(InternalQueryState&&, const struct timeval&) override
+  void notifyIOError(const struct timeval&, TCPResponse&&) override
   {
     errorRaised = true;
   }
index 9d437578f77fb48fe8e71844841e38ddbf141e27..19cafc004ce2ed220de8884e321eeefe1bbefdd7 100644 (file)
@@ -34,7 +34,6 @@ std::vector<std::unique_ptr<ClientState>> g_frontends;
 /* add stub implementations, we don't want to include the corresponding object files
    and their dependencies */
 
-// NOLINTNEXTLINE(readability-convert-member-functions-to-static): this is a stub, the real one is not that simple..
 bool TLSFrontend::setupTLS()
 {
   return true;
index 41d9992cda3bf8330a3546c7c15a876cf7393c9c..c43297b53e5f810f8715e63d92f048bae47d221d 100644 (file)
@@ -626,11 +626,11 @@ public:
     d_valid = true;
   }
 
-  void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
+  void handleXFRResponse(const struct timeval&, TCPResponse&&) override
   {
   }
 
-  void notifyIOError(InternalQueryState&& query, const struct timeval& now) override
+  void notifyIOError(const struct timeval&, TCPResponse&&) override
   {
     d_error = true;
   }
index 2aa5adbe2063b782e577df4af7baaa8a9f6d88d2..22e137c24b3b1629f606342ed8533df93b537629 100644 (file)
@@ -500,7 +500,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
   }
 
@@ -523,7 +523,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
     BOOST_CHECK(s_writeBuffer == query);
   }
@@ -558,7 +558,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
       threadData.mplexer->run(&now);
     }
@@ -582,7 +582,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
   }
 
@@ -610,7 +610,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * count);
 #endif
   }
@@ -636,7 +636,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0);
     struct timeval later = now;
     later.tv_sec += g_tcpRecvTimeout + 1;
@@ -672,7 +672,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0);
     struct timeval later = now;
     later.tv_sec += g_tcpRecvTimeout + 1;
@@ -705,7 +705,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_SelfAnswered)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
   }
 }
@@ -766,7 +766,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered)
     dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0);
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * 2U);
   }
@@ -793,7 +793,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
 
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
   }
@@ -823,7 +823,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered)
     dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(-1);
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(threadData.mplexer->run(&now), 0);
     struct timeval later = now;
     later.tv_sec += g_tcpRecvTimeout + 1;
@@ -903,7 +903,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
     BOOST_CHECK(s_writeBuffer == query);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
@@ -943,7 +943,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
     BOOST_CHECK(s_backendWriteBuffer == query);
@@ -982,7 +982,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
     BOOST_CHECK(s_backendWriteBuffer == query);
@@ -1025,7 +1025,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
     BOOST_CHECK(s_backendWriteBuffer == query);
@@ -1052,7 +1052,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
@@ -1090,7 +1090,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
     BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
@@ -1160,7 +1160,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     /* set the incoming descriptor as ready! */
     dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
       threadData.mplexer->run(&now);
     }
@@ -1221,7 +1221,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
@@ -1257,7 +1257,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     struct timeval later = now;
     later.tv_sec += backend->d_config.tcpSendTimeout + 1;
     auto expiredWriteConns = threadData.mplexer->getTimeouts(later, true);
@@ -1303,7 +1303,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     struct timeval later = now;
     later.tv_sec += backend->d_config.tcpRecvTimeout + 1;
     auto expiredConns = threadData.mplexer->getTimeouts(later, false);
@@ -1360,7 +1360,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
@@ -1416,7 +1416,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
     BOOST_CHECK(s_writeBuffer == query);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
@@ -1475,7 +1475,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
@@ -1527,7 +1527,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size() * backend->d_config.d_retries);
     BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
@@ -1587,7 +1587,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size());
     BOOST_CHECK(s_writeBuffer == query);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size() * backend->d_config.d_retries);
@@ -1628,7 +1628,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), 0U);
     BOOST_CHECK_EQUAL(s_backendWriteBuffer.size(), query.size());
     BOOST_CHECK(s_backendWriteBuffer == query);
@@ -1690,7 +1690,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(s_writeBuffer.size(), query.size() * count);
     BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
 
@@ -1732,7 +1732,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     BOOST_CHECK_EQUAL(backend->outstanding.load(), 0U);
 
     /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */
@@ -1916,7 +1916,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
       threadData.mplexer->run(&now);
     }
@@ -2048,7 +2048,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
 
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
@@ -2228,7 +2228,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
 
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
@@ -2304,7 +2304,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
@@ -2387,7 +2387,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while ((threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
@@ -2504,7 +2504,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
@@ -2656,7 +2656,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
       threadData.mplexer->run(&now);
     }
@@ -2863,7 +2863,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
@@ -3037,7 +3037,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
@@ -3301,7 +3301,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
@@ -3427,7 +3427,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
       threadData.mplexer->run(&now);
     }
@@ -3512,7 +3512,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
       threadData.mplexer->run(&now);
     }
@@ -3577,7 +3577,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
@@ -3768,7 +3768,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
       threadData.mplexer->run(&now);
     }
@@ -3853,7 +3853,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
@@ -4085,7 +4085,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
       threadData.mplexer->run(&now);
     }
@@ -4137,7 +4137,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR)
     };
 
     auto state = std::make_shared<IncomingTCPConnectionState>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
-    IncomingTCPConnectionState::handleIO(state, now);
+    state->handleIO();
     while (!timeout && (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0)) {
       threadData.mplexer->run(&now);
     }
index 58a26f16918f6e33da6ecd2ed728800c25e324c3..c482b7a0fa2a405ddcebb99793df73abd224ad20 100644 (file)
@@ -51,7 +51,7 @@ public:
   size_t getTicketsKeysCount() override;
 };
 
-void dohThread(ClientState* clientState);
+void dohThread(ClientState* cs);
 
 #endif /* HAVE_LIBH2OEVLOOP */
 #endif /* HAVE_DNS_OVER_HTTPS  */
index 1b7018c028ba3d906fc727b8ff90f628a10a598a..78f23f9df414a15629f60aa83dd201d31859478f 100644 (file)
@@ -1850,6 +1850,7 @@ bool TLSFrontend::setupTLS()
     newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
   }
 #endif /* HAVE_LIBSSL */
+
   if (!newCtx) {
 #ifdef HAVE_LIBSSL
     newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
@@ -1874,7 +1875,7 @@ bool TLSFrontend::setupTLS()
 
 std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameters& params)
 {
-#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS)
+#ifdef HAVE_DNS_OVER_TLS
   /* get the "best" available provider */
   if (!params.d_provider.empty()) {
 #ifdef HAVE_GNUTLS
@@ -1897,6 +1898,6 @@ std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameter
 #endif /* HAVE_GNUTLS */
 #endif /* HAVE_LIBSSL */
 
-#endif /* HAVE_DNS_OVER_TLS || HAVE_DNS_OVER_HTTPS */
+#endif /* HAVE_DNS_OVER_TLS */
   return nullptr;
 }
index 29b59a01f9a418fe3478d5d1ebfadd92bfd2f94b..5e1d23e737c6dd613f8d701e95a1f64e323d15d8 100644 (file)
@@ -138,7 +138,7 @@ class TLSFrontend
 public:
   enum class ALPN : uint8_t { Unset, DoT, DoH };
 
-  TLSFrontend(ALPN alpn) : d_alpn(alpn)
+  TLSFrontend(ALPN alpn): d_alpn(alpn)
   {
   }
 
@@ -233,7 +233,6 @@ protected:
 class TCPIOHandler
 {
 public:
-  enum class Type : uint8_t { Client, Server };
 
   TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<TLSCtx> ctx): d_socket(socket)
   {
index c51a930c04f8566c09a959e927374133f4a11a5a..850273eb8a678b4b886ecc34d97bb8fa6c0f15bf 100644 (file)
@@ -56,7 +56,7 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs
 
 bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
 {
-  return true;
+  return false;
 }
 
 namespace dnsdist {
index 156ba192198ff5b406e92e5991ead2330f3a71b2..6bc56cdb7a71a5f8ff0a4539f0b9c34229f57695 100644 (file)
@@ -624,7 +624,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return sock
 
     @classmethod
-    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None):
+    def openTLSConnection(cls, port, serverName, caCert=None, timeout=None, alpn=[]):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
@@ -633,6 +633,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         # 2.7.9+
         if hasattr(ssl, 'create_default_context'):
             sslctx = ssl.create_default_context(cafile=caCert)
+            if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
+              sslctx.set_alpn_protocols(alpn)
             sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
         else:
             sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
@@ -992,6 +994,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         #conn.setopt(pycurl.VERBOSE, True)
         conn.setopt(pycurl.URL, url)
         conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
         if useHTTPS:
             conn.setopt(pycurl.SSL_VERIFYPEER, 1)
             conn.setopt(pycurl.SSL_VERIFYHOST, 2)
@@ -1036,6 +1040,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         #conn.setopt(pycurl.VERBOSE, True)
         conn.setopt(pycurl.URL, url)
         conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
         if useHTTPS:
             conn.setopt(pycurl.SSL_VERIFYPEER, 1)
             conn.setopt(pycurl.SSL_VERIFYHOST, 2)
index 5999021f5dd80633a9fc09445ed74f9f26382395..d4d6606faf013f16d2cd176940ba8d6b2534b089 100644 (file)
@@ -573,7 +573,7 @@ class TestDOHSubPaths(DNSDistDOHTest):
         # this path is not in the URLs map and should lead to a 404
         (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL + "NotPowerDNS", query, caFile=self._caCert, useQueue=False, rawResponse=True)
         self.assertTrue(receivedResponse)
-        self.assertEqual(receivedResponse, b'not found')
+        self.assertIn(receivedResponse, [b'there is no endpoint configured for this path', b'not found'])
         self.assertEqual(self._rcode, 404)
 
         # this path is below one in the URLs map and exactPathMatching is false, so we should be good
@@ -1116,7 +1116,7 @@ class TestDOHForwardedFor(DNSDistDOHTest):
         (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 127.0.0.1:42, 127.0.0.1'])
 
         self.assertEqual(self._rcode, 403)
-        self.assertEqual(receivedResponse, b'dns query not allowed because of ACL')
+        self.assertEqual(receivedResponse, b'DoH query not allowed because of ACL')
 
 class TestDOHForwardedForNoTrusted(DNSDistDOHTest):
 
@@ -1130,7 +1130,7 @@ class TestDOHForwardedForNoTrusted(DNSDistDOHTest):
     newServer{address="127.0.0.1:%s"}
 
     setACL('192.0.2.1/32')
-    addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" })
+    addDOHLocal("127.0.0.1:%s", "%s", "%s", { "/" }, {earlyACLDrop=true})
     """
     _config_params = ['_testServerPort', '_dohServerPort', '_serverCert', '_serverKey']
 
@@ -1151,10 +1151,15 @@ class TestDOHForwardedForNoTrusted(DNSDistDOHTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 192.0.2.1:4200'])
+        dropped = False
+        try:
+            (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=False, rawResponse=True, customHeaders=['x-forwarded-for: 192.0.2.1:4200'])
+            self.assertEqual(self._rcode, 403)
+            self.assertEqual(receivedResponse, b'DoH query not allowed because of ACL')
+        except pycurl.error as e:
+            dropped = True
 
-        self.assertEqual(self._rcode, 403)
-        self.assertEqual(receivedResponse, b'dns query not allowed because of ACL')
+        self.assertTrue(dropped)
 
 class TestDOHFrontendLimits(DNSDistDOHTest):
 
@@ -1190,7 +1195,7 @@ class TestDOHFrontendLimits(DNSDistDOHTest):
 
         for idx in range(self._maxTCPConnsPerDOHFrontend + 1):
             try:
-                conns.append(self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert))
+                conns.append(self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert, alpn=['h2']))
             except:
                 conns.append(None)
 
index e4f0b0232c17e7bf0917a7b5c037bbb72be5e7c5..092cdef7453429235aa811041460e737f6b5d922 100644 (file)
@@ -546,7 +546,6 @@ class TestProtobufMetaDOH(DNSDistProtobufTest):
             elif method == "sendDOHQueryWrapper":
                 pbMessageType = dnsmessage_pb2.PBDNSMessage.DOH
 
-            print(method)
             self.checkProtobufQuery(msg, pbMessageType, query, dns.rdataclass.IN, dns.rdatatype.A, name)
             self.assertEqual(len(msg.meta), 5)
             tags = {}