]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add unit and regression tests for incoming DoH w/ nghttp2
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 30 Jun 2023 15:49:35 +0000 (17:49 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Sep 2023 08:22:03 +0000 (10:22 +0200)
It is quite likely that the underlying TLS layer has buffered some
data already, so we need to consume it before trying to poll the
socket.

12 files changed:
pdns/dnsdist-doh-common.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-nghttp2-in.cc
pdns/dnsdistdist/dnsdist-nghttp2-in.hh
pdns/dnsdistdist/dnsdist-nghttp2.cc
pdns/dnsdistdist/test-dnsdistnghttp2-in_cc.cc [new file with mode: 0644]
pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc
pdns/dnsdistdist/test-dnsdistnghttp2_common.hh [new file with mode: 0644]
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/test-dnsdist_cc.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_DOH.py

index 41166de9f3023bc2908d96f7bdae260b7576be96..f0a1adc76744649bf79807d5392f22e3cf36e07e 100644 (file)
@@ -77,6 +77,10 @@ struct DOHFrontend
   DOHFrontend()
   {
   }
+  DOHFrontend(std::shared_ptr<TLSCtx> tlsCtx):
+    d_tlsContext(std::move(tlsCtx))
+  {
+  }
 
   virtual ~DOHFrontend()
   {
index e4f30eaa83ab7dbb629b902cc03c51e590583248..0c07520108a94325287036ce25b822b02c309f4e 100644 (file)
@@ -332,7 +332,9 @@ testrunner_SOURCES = \
        test-dnsdistkvs_cc.cc \
        test-dnsdistlbpolicies_cc.cc \
        test-dnsdistluanetwork.cc \
+       test-dnsdistnghttp2-in_cc.cc \
        test-dnsdistnghttp2_cc.cc \
+       test-dnsdistnghttp2_common.hh \
        test-dnsdistpacketcache_cc.cc \
        test-dnsdistrings_cc.cc \
        test-dnsdistrules_cc.cc \
index 21098ec96ec57899207f2cbc809d639b26578320..b71601a916d4746198ca12f5036f96a75824bb4b 100644 (file)
@@ -142,7 +142,7 @@ public:
     d_query.d_contentTypeOut = contentType;
   }
 
-  void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& downstream) override
+  void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& downstream_) override
   {
     std::unique_ptr<DOHUnitInterface> unit(this);
     auto conn = d_connection.lock();
@@ -153,7 +153,7 @@ public:
 
     state.du = std::move(unit);
     TCPResponse resp(std::move(response), std::move(state), nullptr, nullptr);
-    resp.d_ds = downstream;
+    resp.d_ds = downstream_;
     struct timeval now
     {
     };
@@ -263,7 +263,7 @@ IncomingHTTP2Connection::IncomingHTTP2Connection(ConnectionInfo&& connectionInfo
 bool IncomingHTTP2Connection::checkALPN()
 {
   constexpr std::array<uint8_t, 2> h2ALPN{'h', '2'};
-  auto protocols = d_handler.getNextProtocol();
+  const auto protocols = d_handler.getNextProtocol();
   if (protocols.size() == h2ALPN.size() && memcmp(protocols.data(), h2ALPN.data(), h2ALPN.size()) == 0) {
     return true;
   }
@@ -285,6 +285,11 @@ void IncomingHTTP2Connection::handleConnectionReady()
   }
 }
 
+bool IncomingHTTP2Connection::hasPendingWrite() const
+{
+  return d_pendingWrite;
+}
+
 void IncomingHTTP2Connection::handleIO()
 {
   IOState iostate = IOState::Done;
@@ -297,7 +302,7 @@ void IncomingHTTP2Connection::handleIO()
     if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
       vinfolog("Terminating DoH connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
       stopIO();
-      d_connectionDied = true;
+      d_connectionClosing = true;
       return;
     }
 
@@ -341,56 +346,94 @@ void IncomingHTTP2Connection::handleIO()
       }
     }
 
-    if (d_state == State::waitingForQuery || d_state == State::idle) {
-      readHTTPData();
+    if (active() && !d_connectionClosing && (d_state == State::waitingForQuery || d_state == State::idle)) {
+      do {
+        iostate = readHTTPData();
+      } while (active() && !d_connectionClosing && iostate == IOState::Done);
     }
 
-    if (!d_connectionDied) {
-      auto shared = std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this());
+    if (!active()) {
+      stopIO();
+      return;
+    }
+    /*
+      So:
+      - if we have a pending write, we need to wait until the socket becomes writable
+        and then call handleWritableCallback
+      - if we have NeedWrite but no pending write, we need to wait until the socket
+        becomes writable but for handleReadableIOCallback
+      - if we have NeedRead, or nghttp2_session_want_read, wait until the socket
+        becomes readable and call handleReadableIOCallback
+    */
+    if (hasPendingWrite()) {
+      updateIO(IOState::NeedWrite, handleWritableIOCallback);
+    }
+    else if (iostate == IOState::NeedWrite) {
+      updateIO(IOState::NeedWrite, handleReadableIOCallback);
+    }
+    else if (!d_connectionClosing) {
       if (nghttp2_session_want_read(d_session.get()) != 0) {
-        d_ioState->add(IOState::NeedRead, &handleReadableIOCallback, shared, boost::none);
+        updateIO(IOState::NeedRead, handleReadableIOCallback);
       }
-      if (nghttp2_session_want_write(d_session.get()) != 0) {
-        d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, shared, boost::none);
+      else {
+        if (isIdle()) {
+          watchForRemoteHostClosingConnection();
+        }
       }
     }
   }
   catch (const std::exception& e) {
-    vinfolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+    infolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
     d_connectionDied = true;
     stopIO();
   }
 }
 
-ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+void IncomingHTTP2Connection::writeToSocket(bool socketReady)
 {
-  auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
-  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): nghttp2 API
-  conn->d_out.insert(conn->d_out.end(), data, data + length);
-
-  if (conn->d_connectionDied || conn->d_needFlush) {
-    try {
-      conn->d_needFlush = false;
-      auto state = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
-      if (state == IOState::Done) {
-        conn->d_out.clear();
-        conn->d_outPos = 0;
-        if (!conn->isIdle()) {
-          conn->updateIO(IOState::NeedRead, handleReadableIOCallback);
+  try {
+    d_needFlush = false;
+    IOState newState = d_handler.tryWrite(d_out, d_outPos, d_out.size());
+
+    if (newState == IOState::Done) {
+      d_pendingWrite = false;
+      d_out.clear();
+      d_outPos = 0;
+      if (active() && !d_connectionClosing) {
+        if (!isIdle()) {
+          updateIO(IOState::NeedRead, handleReadableIOCallback);
         }
         else {
-          conn->watchForRemoteHostClosingConnection();
+          watchForRemoteHostClosingConnection();
         }
       }
       else {
-        conn->updateIO(state, handleWritableIOCallback);
+        stopIO();
       }
     }
-    catch (const std::exception& e) {
-      vinfolog("Exception while trying to write (send) to incoming HTTP connection to %s: %s", conn->d_ci.remote.toStringWithPort(), e.what());
-      conn->handleIOError();
+    else {
+      updateIO(newState, handleWritableIOCallback);
+      d_pendingWrite = true;
     }
   }
+  catch (const std::exception& e) {
+    vinfolog("Exception while trying to write (%s) to HTTP client connection to %s: %s", (socketReady ? "ready" : "send"), d_ci.remote.toStringWithPort(), e.what());
+    handleIOError();
+  }
+}
+
+ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+{
+  auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+  if (conn->d_connectionDied) {
+    return static_cast<ssize_t>(length);
+  }
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): nghttp2 API
+  conn->d_out.insert(conn->d_out.end(), data, data + length);
+
+  if (conn->d_connectionClosing || conn->d_needFlush) {
+    conn->writeToSocket(false);
+  }
 
   return static_cast<ssize_t>(length);
 }
@@ -471,7 +514,7 @@ IOState IncomingHTTP2Connection::sendResponse(const struct timeval& now, TCPResp
   sendResponse(response.d_idstate.d_streamID, context, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType);
   handleResponseSent(response);
 
-  return IOState::Done;
+  return hasPendingWrite() ? IOState::NeedWrite : IOState::Done;
 }
 
 void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPResponse&& response)
@@ -748,6 +791,12 @@ void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::Pendi
     sendResponse(streamID, query, code, d_ci.cs->dohFrontend->d_customResponseHeaders);
   };
 
+  if (query.d_method == PendingQuery::Method::Unknown ||
+      query.d_method == PendingQuery::Method::Unsupported) {
+    handleImmediateResponse(400, "DoH query not allowed because of unsupported HTTP method");
+    return;
+  }
+
   ++d_ci.cs->dohFrontend->d_http2Stats.d_nbQueries;
 
   if (d_ci.cs->dohFrontend->d_trustForwardedForHeader) {
@@ -864,44 +913,8 @@ void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::Pendi
 int IncomingHTTP2Connection::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
 {
   auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
-#if 0
-  switch (frame->hd.type) {
-  case NGHTTP2_HEADERS:
-    cerr<<"got headers"<<endl;
-    if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) {
-      cerr<<"All headers received"<<endl;
-    }
-    if (frame->headers.cat == NGHTTP2_HCAT_REQUEST) {
-      cerr<<"All headers received - query"<<endl;
-    }
-    break;
-  case NGHTTP2_WINDOW_UPDATE:
-    cerr<<"got window update"<<endl;
-    break;
-  case NGHTTP2_SETTINGS:
-    cerr<<"got settings"<<endl;
-    cerr<<frame->settings.niv<<endl;
-    for (size_t idx = 0; idx < frame->settings.niv; idx++) {
-      cerr<<"- "<<frame->settings.iv[idx].settings_id<<" "<<frame->settings.iv[idx].value<<endl;
-    }
-    break;
-  case NGHTTP2_DATA:
-    cerr<<"got data"<<endl;
-    break;
-  }
-#endif
-
-  if (frame->hd.type == NGHTTP2_GOAWAY) {
-    conn->stopIO();
-    if (conn->isIdle()) {
-      if (nghttp2_session_want_write(conn->d_session.get()) != 0) {
-        conn->d_ioState->add(IOState::NeedWrite, &handleWritableIOCallback, conn, boost::none);
-      }
-    }
-  }
-
   /* is this the last frame for this stream? */
-  else if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) != 0) {
+  if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) != 0) {
     auto streamID = frame->hd.stream_id;
     auto stream = conn->d_currentStreams.find(streamID);
     if (stream != conn->d_currentStreams.end()) {
@@ -959,7 +972,8 @@ int IncomingHTTP2Connection::on_begin_headers_callback(nghttp2_session* session,
   if (!insertPair.second) {
     /* there is a stream ID collision, something is very wrong! */
     vinfolog("Stream ID collision (%d) on connection from %d", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort());
-    conn->d_connectionDied = true;
+    conn->d_connectionClosing = true;
+    conn->d_needFlush = true;
     nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
     auto ret = nghttp2_session_send(conn->d_session.get());
     if (ret != 0) {
@@ -1047,8 +1061,9 @@ int IncomingHTTP2Connection::on_header_callback(nghttp2_session* session, const
         query.d_method = PendingQuery::Method::Post;
       }
       else {
+        query.d_method = PendingQuery::Method::Unsupported;
         vinfolog("Unsupported method value");
-        return NGHTTP2_ERR_CALLBACK_FAILURE;
+        return 0;
       }
     }
 
@@ -1087,7 +1102,8 @@ int IncomingHTTP2Connection::on_error_callback(nghttp2_session* session, int lib
   auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
 
   vinfolog("Error in HTTP/2 connection from %d: %s", conn->d_ci.remote.toStringWithPort(), std::string(msg, len));
-  conn->d_connectionDied = true;
+  conn->d_connectionClosing = true;
+  conn->d_needFlush = true;
   nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
   auto ret = nghttp2_session_send(conn->d_session.get());
   if (ret != 0) {
@@ -1098,55 +1114,35 @@ int IncomingHTTP2Connection::on_error_callback(nghttp2_session* session, int lib
   return 0;
 }
 
-void IncomingHTTP2Connection::readHTTPData()
+IOState IncomingHTTP2Connection::readHTTPData()
 {
   IOState newState = IOState::Done;
-  IOStateGuard ioGuard(d_ioState);
-  do {
-    size_t got = 0;
-    if (d_in.size() < 128) {
-      d_in.resize(std::max(static_cast<size_t>(128U), d_in.capacity()));
-    }
-    try {
-      newState = d_handler.tryRead(d_in, got, d_in.size(), true);
-      d_in.resize(got);
-
-      if (got > 0) {
-        /* we got something */
-        auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size());
-        /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB,
-           all data should be consumed before returning */
-        if (readlen < 0 || static_cast<size_t>(readlen) < d_in.size()) {
-          throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen)));
-        }
-
-        nghttp2_session_send(d_session.get());
+  size_t got = 0;
+  if (d_in.size() < s_initialReceiveBufferSize) {
+    d_in.resize(std::max(s_initialReceiveBufferSize, d_in.capacity()));
+  }
+  try {
+    newState = d_handler.tryRead(d_in, got, d_in.size(), true);
+    d_in.resize(got);
+
+    if (got > 0) {
+      /* we got something */
+      auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size());
+      /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB,
+         all data should be consumed before returning */
+      if (readlen < 0 || static_cast<size_t>(readlen) < d_in.size()) {
+        throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen)));
       }
 
-      if (newState == IOState::Done) {
-        if (nghttp2_session_want_read(d_session.get()) != 0) {
-          continue;
-        }
-        if (isIdle()) {
-          watchForRemoteHostClosingConnection();
-          ioGuard.release();
-          break;
-        }
-      }
-      else {
-        if (newState == IOState::NeedWrite) {
-          updateIO(IOState::NeedWrite, handleReadableIOCallback);
-        }
-        ioGuard.release();
-        break;
-      }
+      nghttp2_session_send(d_session.get());
     }
-    catch (const std::exception& e) {
-      vinfolog("Exception while trying to read from HTTP client connection to %s: %s", d_ci.remote.toStringWithPort(), e.what());
-      handleIOError();
-      break;
-    }
-  } while (newState == IOState::Done || !isIdle());
+  }
+  catch (const std::exception& e) {
+    vinfolog("Exception while trying to read from HTTP client connection to %s: %s", d_ci.remote.toStringWithPort(), e.what());
+    handleIOError();
+    return IOState::Done;
+  }
+  return newState;
 }
 
 void IncomingHTTP2Connection::handleReadableIOCallback([[maybe_unused]] int descriptor, FDMultiplexer::funcparam_t& param)
@@ -1158,29 +1154,7 @@ void IncomingHTTP2Connection::handleReadableIOCallback([[maybe_unused]] int desc
 void IncomingHTTP2Connection::handleWritableIOCallback([[maybe_unused]] int descriptor, FDMultiplexer::funcparam_t& param)
 {
   auto conn = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
-  IOStateGuard ioGuard(conn->d_ioState);
-
-  try {
-    IOState newState = conn->d_handler.tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
-    if (newState == IOState::NeedRead) {
-      conn->updateIO(IOState::NeedRead, handleWritableIOCallback);
-    }
-    else if (newState == IOState::Done) {
-      conn->d_out.clear();
-      conn->d_outPos = 0;
-      if (!conn->isIdle()) {
-        conn->updateIO(IOState::NeedRead, handleReadableIOCallback);
-      }
-      else {
-        conn->watchForRemoteHostClosingConnection();
-      }
-    }
-    ioGuard.release();
-  }
-  catch (const std::exception& e) {
-    vinfolog("Exception while trying to write (ready) to HTTP client connection to %s: %s", conn->d_ci.remote.toStringWithPort(), e.what());
-    conn->handleIOError();
-  }
+  conn->writeToSocket(true);
 }
 
 bool IncomingHTTP2Connection::isIdle() const
@@ -1250,14 +1224,31 @@ void IncomingHTTP2Connection::updateIO(IOState newState, const FDMultiplexer::ca
 
 void IncomingHTTP2Connection::watchForRemoteHostClosingConnection()
 {
-  updateIO(IOState::NeedRead, handleReadableIOCallback);
+  if (d_connectionDied) {
+    return;
+  }
+
+  if (hasPendingWrite()) {
+    updateIO(IOState::NeedWrite, &handleWritableIOCallback);
+  }
+  else if (!d_connectionClosing) {
+    updateIO(IOState::NeedRead, handleReadableIOCallback);
+  }
 }
 
 void IncomingHTTP2Connection::handleIOError()
 {
   d_connectionDied = true;
+  d_out.clear();
+  d_outPos = 0;
   nghttp2_session_terminate_session(d_session.get(), NGHTTP2_PROTOCOL_ERROR);
   d_currentStreams.clear();
   stopIO();
 }
+
+bool IncomingHTTP2Connection::active() const
+{
+  return !d_connectionDied && d_ioState != nullptr;
+}
+
 #endif /* HAVE_NGHTTP2 */
index e68d2142085e7487cf4210812a0fbba5467129b3..3db7473a8e3d12a716b3db79109f568418c04d81 100644 (file)
@@ -39,7 +39,8 @@ public:
     {
       Unknown,
       Get,
-      Post
+      Post,
+      Unsupported
     };
 
     PacketBuffer d_buffer;
@@ -61,6 +62,7 @@ public:
   void handleIO() override;
   void handleResponse(const struct timeval& now, TCPResponse&& response) override;
   void notifyIOError(const struct timeval& now, TCPResponse&& response) override;
+  bool active() const override;
 
 private:
   static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data);
@@ -73,6 +75,8 @@ private:
   static void handleReadableIOCallback(int descriptor, FDMultiplexer::funcparam_t& param);
   static void handleWritableIOCallback(int descriptor, FDMultiplexer::funcparam_t& param);
 
+  static constexpr size_t s_initialReceiveBufferSize{256U};
+
   IOState sendResponse(const struct timeval& now, TCPResponse&& response) override;
   bool forwardViaUDPFirst() const override
   {
@@ -90,8 +94,10 @@ private:
   bool sendResponse(StreamID streamID, PendingQuery& context, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType = "", bool addContentType = true);
   void handleIncomingQuery(PendingQuery&& query, StreamID streamID);
   bool checkALPN();
-  void readHTTPData();
+  IOState readHTTPData();
   void handleConnectionReady();
+  bool hasPendingWrite() const;
+  void writeToSocket(bool socketReady);
   boost::optional<struct timeval> getIdleClientReadTTD(struct timeval now) const;
 
   std::unique_ptr<nghttp2_session, decltype(&nghttp2_session_del)> d_session{nullptr, nghttp2_session_del};
@@ -99,8 +105,18 @@ private:
   PacketBuffer d_out;
   PacketBuffer d_in;
   size_t d_outPos{0};
+  /* this connection is done, the remote end has closed the connection
+     or something like that. We do not want to try to write to it. */
   bool d_connectionDied{false};
+  /* we are done reading from this connection, but we might still want to
+     write to it to close it properly */
+  bool d_connectionClosing{false};
+  /* Whether we are still waiting for more data to be buffered
+     before writing to the socket (false) or not. */
   bool d_needFlush{false};
+  /* Whether we have data that we want to write to the socket,
+     but the socket is full. */
+  bool d_pendingWrite{false};
 };
 
 class NGHTTP2Headers
index 03fa0bfca5a6bc993915de49ba3f52686e8701fe..9b774f0ed4c17a67cc29a0fefc526b86fbece704 100644 (file)
@@ -300,10 +300,9 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender,
    */
   nghttp2_data_provider data_provider;
 
-  /* we will not use this pointer */
   data_provider.source.ptr = this;
   data_provider.read_callback = [](nghttp2_session* session, int32_t stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* user_data) -> ssize_t {
-    auto conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
+    auto conn = static_cast<DoHConnectionToBackend*>(user_data);
     auto& request = conn->d_currentStreams.at(stream_id);
     size_t toCopy = 0;
     if (request.d_queryPos < request.d_query.d_buffer.size()) {
diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2-in_cc.cc b/pdns/dnsdistdist/test-dnsdistnghttp2-in_cc.cc
new file mode 100644 (file)
index 0000000..0ac62b3
--- /dev/null
@@ -0,0 +1,727 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_NO_MAIN
+
+#include <boost/test/unit_test.hpp>
+
+#include "dnswriter.hh"
+#include "dnsdist.hh"
+#include "dnsdist-proxy-protocol.hh"
+#include "dnsdist-nghttp2-in.hh"
+
+#ifdef HAVE_NGHTTP2
+#include <nghttp2/nghttp2.h>
+
+extern std::function<ProcessQueryResult(DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
+
+BOOST_AUTO_TEST_SUITE(test_dnsdistnghttp2_in_cc)
+
+struct ExpectedStep
+{
+public:
+  enum class ExpectedRequest
+  {
+    handshakeClient,
+    readFromClient,
+    writeToClient,
+    closeClient,
+  };
+
+  ExpectedStep(ExpectedRequest r, IOState n, size_t b = 0, std::function<void(int descriptor)> fn = nullptr) :
+    cb(fn), request(r), nextState(n), bytes(b)
+  {
+  }
+
+  std::function<void(int descriptor)> cb{nullptr};
+  ExpectedRequest request;
+  IOState nextState;
+  size_t bytes{0};
+};
+
+struct ExpectedData
+{
+  PacketBuffer d_proxyProtocolPayload;
+  std::vector<PacketBuffer> d_queries;
+  std::vector<PacketBuffer> d_responses;
+  std::vector<uint16_t> d_responseCodes;
+};
+
+class DOHConnection;
+
+static std::deque<ExpectedStep> s_steps;
+static std::map<uint64_t, ExpectedData> s_connectionContexts;
+static std::map<int, std::unique_ptr<DOHConnection>> s_connectionBuffers;
+static uint64_t s_connectionID{0};
+
+std::ostream& operator<<(std::ostream& os, const ExpectedStep::ExpectedRequest d);
+
+std::ostream& operator<<(std::ostream& os, const ExpectedStep::ExpectedRequest d)
+{
+  static const std::vector<std::string> requests = {"handshake with client", "read from client", "write to client", "close connection to client", "connect to the backend", "read from the backend", "write to the backend", "close connection to backend"};
+  os << requests.at(static_cast<size_t>(d));
+  return os;
+}
+
+class DOHConnection
+{
+public:
+  DOHConnection(uint64_t connectionID) :
+    d_session(std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)>(nullptr, nghttp2_session_del)), d_connectionID(connectionID)
+  {
+    const auto& context = s_connectionContexts.at(connectionID);
+    d_clientOutBuffer.insert(d_clientOutBuffer.begin(), context.d_proxyProtocolPayload.begin(), context.d_proxyProtocolPayload.end());
+    
+    nghttp2_session_callbacks* cbs = nullptr;
+    nghttp2_session_callbacks_new(&cbs);
+    std::unique_ptr<nghttp2_session_callbacks, void (*)(nghttp2_session_callbacks*)> callbacks(cbs, nghttp2_session_callbacks_del);
+    cbs = nullptr;
+    nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback);
+    nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback);
+    nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback);
+    nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback);
+    nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback);
+    nghttp2_session* sess = nullptr;
+    nghttp2_session_client_new(&sess, callbacks.get(), this);
+    d_session = std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)>(sess, nghttp2_session_del);
+
+    nghttp2_settings_entry iv[] = {
+      /* rfc7540 section-8.2.2:
+         "Advertising a SETTINGS_MAX_CONCURRENT_STREAMS value of zero disables
+         server push by preventing the server from creating the necessary
+         streams."
+      */
+      {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 0},
+      {NGHTTP2_SETTINGS_ENABLE_PUSH, 0},
+      /* we might want to make the initial window size configurable, but 16M is a large enough default */
+      {NGHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, 16 * 1024 * 1024}};
+    /* client 24 bytes magic string will be sent by nghttp2 library */
+    auto result = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv) / sizeof(*iv));
+    if (result != 0) {
+      throw std::runtime_error("Error submitting settings:" + std::string(nghttp2_strerror(result)));
+    }
+
+    const std::string host("unit-tests");
+    const std::string path("/dns-query");
+    for (const auto& query : context.d_queries) {
+      const auto querySize = std::to_string(query.size());
+      std::vector<nghttp2_nv> headers;
+      /* Pseudo-headers need to come first (rfc7540 8.1.2.1) */
+      NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::METHOD_NAME, NGHTTP2Headers::HeaderConstantIndexes::METHOD_VALUE);
+      NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::SCHEME_NAME, NGHTTP2Headers::HeaderConstantIndexes::SCHEME_VALUE);
+      NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::AUTHORITY_NAME, host);
+      NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::PATH_NAME, path);
+      NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::ACCEPT_NAME, NGHTTP2Headers::HeaderConstantIndexes::ACCEPT_VALUE);
+      NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_NAME, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_VALUE);
+      NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::USER_AGENT_NAME, NGHTTP2Headers::HeaderConstantIndexes::USER_AGENT_VALUE);
+      NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_LENGTH_NAME, querySize);
+
+      d_position = 0;
+      d_currentQuery = query;
+      nghttp2_data_provider data_provider;
+      data_provider.source.ptr = this;
+      data_provider.read_callback = [](nghttp2_session* session, int32_t stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* user_data) -> ssize_t {
+        auto* conn = static_cast<DOHConnection*>(user_data);
+        auto& pos = conn->d_position;
+        const auto& currentQuery = conn->d_currentQuery;
+        size_t toCopy = 0;
+        if (pos < currentQuery.size()) {
+          size_t remaining = currentQuery.size() - pos;
+          toCopy = length > remaining ? remaining : length;
+          memcpy(buf, &currentQuery.at(pos), toCopy);
+          pos += toCopy;
+        }
+
+        if (pos >= currentQuery.size()) {
+          *data_flags |= NGHTTP2_DATA_FLAG_EOF;
+        }
+        return toCopy;
+      };
+
+      auto newStreamId = nghttp2_submit_request(d_session.get(), nullptr, headers.data(), headers.size(), &data_provider, this);
+      if (newStreamId < 0) {
+        throw std::runtime_error("Error submitting HTTP request:" + std::string(nghttp2_strerror(newStreamId)));
+      }
+
+      result = nghttp2_session_send(d_session.get());
+      if (result != 0) {
+        throw std::runtime_error("Error in nghttp2_session_send:" + std::to_string(result));
+      }
+    }
+  }
+
+  std::map<int32_t, PacketBuffer> d_responses;
+  std::map<int32_t, uint16_t> d_responseCodes;
+  std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)> d_session;
+  PacketBuffer d_currentQuery;
+  PacketBuffer d_clientOutBuffer;
+  uint64_t d_connectionID{0};
+  size_t d_position{0};
+
+  size_t submitIncoming(const PacketBuffer& data, size_t pos, size_t toWrite)
+  {
+    ssize_t readlen = nghttp2_session_mem_recv(d_session.get(), &data.at(pos), toWrite);
+    if (readlen < 0) {
+      throw("Fatal error while submitting line " + std::to_string(__LINE__) + ": " + std::string(nghttp2_strerror(static_cast<int>(readlen))));
+    }
+
+    /* just in case, see if we have anything to send */
+    int rv = nghttp2_session_send(d_session.get());
+    if (rv != 0) {
+      throw("Fatal error while sending: " + std::string(nghttp2_strerror(rv)));
+    }
+
+    return readlen;
+  }
+
+private:
+  static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+  {
+    DOHConnection* conn = static_cast<DOHConnection*>(user_data);
+    conn->d_clientOutBuffer.insert(conn->d_clientOutBuffer.end(), data, data + length);
+    return static_cast<ssize_t>(length);
+  }
+
+  static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
+  {
+    DOHConnection* conn = static_cast<DOHConnection*>(user_data);
+    if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) {
+      const auto& response = conn->d_responses.at(frame->hd.stream_id);
+      if (conn->d_responseCodes.at(frame->hd.stream_id) != 200U) {
+        return 0;
+      }
+
+      BOOST_REQUIRE_GT(response.size(), sizeof(dnsheader));
+      const auto* dh = reinterpret_cast<const dnsheader*>(response.data());
+      uint16_t id = ntohs(dh->id);
+
+      const auto& expected = s_connectionContexts.at(conn->d_connectionID).d_responses.at(id);
+      BOOST_REQUIRE_EQUAL(expected.size(), response.size());
+      for (size_t idx = 0; idx < response.size(); idx++) {
+        if (expected.at(idx) != response.at(idx)) {
+          cerr << "Mismatch at offset " << idx << ", expected " << std::to_string(response.at(idx)) << " got " << std::to_string(expected.at(idx)) << endl;
+          BOOST_CHECK(false);
+        }
+      }
+    }
+
+    return 0;
+  }
+
+  static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data)
+  {
+    DOHConnection* conn = static_cast<DOHConnection*>(user_data);
+    auto& response = conn->d_responses[stream_id];
+    response.insert(response.end(), data, data + len);
+    return 0;
+  }
+
+  static int on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data)
+  {
+    DOHConnection* conn = static_cast<DOHConnection*>(user_data);
+
+    const std::string status(":status");
+    if (frame->hd.type == NGHTTP2_HEADERS && frame->headers.cat == NGHTTP2_HCAT_RESPONSE) {
+      if (namelen == status.size() && memcmp(status.data(), name, status.size()) == 0) {
+        try {
+          uint16_t responseCode{0};
+          auto expected = s_connectionContexts.at(conn->d_connectionID).d_responseCodes.at((frame->hd.stream_id - 1) / 2);
+          pdns::checked_stoi_into(responseCode, std::string(reinterpret_cast<const char*>(value), valuelen));
+          conn->d_responseCodes[frame->hd.stream_id] = responseCode;
+          if (responseCode != expected) {
+            cerr << "Mismatch response code, expected " << std::to_string(expected) << " got " << std::to_string(responseCode) << endl;
+            BOOST_CHECK(false);
+          }
+        }
+        catch (const std::exception& e) {
+          infolog("Error parsing the status header for stream ID %d: %s", frame->hd.stream_id, e.what());
+          return NGHTTP2_ERR_CALLBACK_FAILURE;
+        }
+      }
+    }
+    return 0;
+  }
+
+  static int on_stream_close_callback(nghttp2_session* session, int32_t stream_id, uint32_t error_code, void* user_data)
+  {
+    return 0;
+  }
+};
+
+class MockupTLSConnection : public TLSConnection
+{
+public:
+  MockupTLSConnection(int descriptor, [[maybe_unused]] bool client = false, [[maybe_unused]] bool needProxyProtocol = false) :
+    d_descriptor(descriptor)
+  {
+    auto connectionID = s_connectionID++;
+    auto conn = std::make_unique<DOHConnection>(connectionID);
+    s_connectionBuffers[d_descriptor] = std::move(conn);
+  }
+
+  ~MockupTLSConnection() {}
+
+  IOState tryHandshake() override
+  {
+    auto step = getStep();
+    BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::handshakeClient);
+
+    return step.nextState;
+  }
+
+  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
+  {
+    auto& conn = s_connectionBuffers.at(d_descriptor);
+    auto step = getStep();
+    BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::writeToClient);
+
+    if (step.bytes == 0) {
+      if (step.nextState == IOState::NeedWrite) {
+        return step.nextState;
+      }
+      throw std::runtime_error("Remote host closed the connection");
+    }
+
+    toWrite -= pos;
+    BOOST_REQUIRE_GE(buffer.size(), pos + toWrite);
+
+    if (step.bytes < toWrite) {
+      toWrite = step.bytes;
+    }
+
+    conn->submitIncoming(buffer, pos, toWrite);
+    pos += toWrite;
+
+    return step.nextState;
+  }
+
+  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete = false) override
+  {
+    auto& conn = s_connectionBuffers.at(d_descriptor);
+    auto step = getStep();
+    BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::readFromClient);
+
+    if (step.bytes == 0) {
+      if (step.nextState == IOState::NeedRead) {
+        return step.nextState;
+      }
+      throw std::runtime_error("Remote host closed the connection");
+    }
+
+    auto& externalBuffer = conn->d_clientOutBuffer;
+    toRead -= pos;
+
+    if (step.bytes < toRead) {
+      toRead = step.bytes;
+    }
+    if (allowIncomplete) {
+      if (toRead > externalBuffer.size()) {
+        toRead = externalBuffer.size();
+      }
+    }
+    else {
+      BOOST_REQUIRE_GE(externalBuffer.size(), toRead);
+    }
+
+    BOOST_REQUIRE_GE(buffer.size(), toRead);
+
+    std::copy(externalBuffer.begin(), externalBuffer.begin() + toRead, buffer.begin() + pos);
+    pos += toRead;
+    externalBuffer.erase(externalBuffer.begin(), externalBuffer.begin() + toRead);
+
+    return step.nextState;
+  }
+
+  IOState tryConnect(bool fastOpen, const ComboAddress& remote) override
+  {
+    throw std::runtime_error("Should not happen");
+  }
+
+  void close() override
+  {
+    auto step = getStep();
+    BOOST_REQUIRE_EQUAL(step.request, ExpectedStep::ExpectedRequest::closeClient);
+  }
+
+  bool hasBufferedData() const override
+  {
+    return false;
+  }
+
+  bool isUsable() const override
+  {
+    return true;
+  }
+
+  std::string getServerNameIndication() const override
+  {
+    return "";
+  }
+
+  std::vector<uint8_t> getNextProtocol() const override
+  {
+    return std::vector<uint8_t>{'h', '2'};
+  }
+
+  LibsslTLSVersion getTLSVersion() const override
+  {
+    return LibsslTLSVersion::TLS13;
+  }
+
+  bool hasSessionBeenResumed() const override
+  {
+    return false;
+  }
+
+  std::vector<std::unique_ptr<TLSSession>> getSessions() override
+  {
+    return {};
+  }
+
+  void setSession(std::unique_ptr<TLSSession>& session) override
+  {
+  }
+
+  std::vector<int> getAsyncFDs() override
+  {
+    return {};
+  }
+
+  /* unused in that context, don't bother */
+  void doHandshake() override
+  {
+  }
+
+  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) override
+  {
+  }
+
+  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0, 0}, bool allowIncomplete = false) override
+  {
+    return 0;
+  }
+
+  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) override
+  {
+    return 0;
+  }
+
+private:
+  ExpectedStep getStep() const
+  {
+    BOOST_REQUIRE(!s_steps.empty());
+    auto step = s_steps.front();
+    s_steps.pop_front();
+
+    if (step.cb) {
+      step.cb(d_descriptor);
+    }
+
+    return step;
+  }
+
+  const int d_descriptor;
+};
+
+#include "test-dnsdistnghttp2_common.hh"
+
+struct TestFixture
+{
+  TestFixture()
+  {
+    s_steps.clear();
+    s_connectionContexts.clear();
+    s_connectionBuffers.clear();
+    s_connectionID = 0;
+    s_mplexer = std::make_unique<MockupFDMultiplexer>();
+  }
+  ~TestFixture()
+  {
+    s_steps.clear();
+    s_connectionContexts.clear();
+    s_connectionBuffers.clear();
+    s_connectionID = 0;
+    s_mplexer.reset();
+  }
+};
+
+BOOST_FIXTURE_TEST_CASE(test_IncomingConnection_SelfAnswered, TestFixture)
+{
+  auto local = getBackendAddress("1", 80);
+  ClientState localCS(local, true, false, false, "", {});
+  localCS.dohFrontend = std::make_shared<DOHFrontend>(std::make_shared<MockupTLSCtx>());
+  localCS.dohFrontend->d_urls.insert("/dns-query");
+
+  TCPClientThreadData threadData;
+  threadData.mplexer = std::make_unique<MockupFDMultiplexer>();
+
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+
+  size_t counter = 0;
+  DNSName name("powerdns.com.");
+  PacketBuffer query;
+  GenericDNSPacketWriter<PacketBuffer> pwQ(query, name, QType::A, QClass::IN, 0);
+  pwQ.getHeader()->rd = 1;
+  pwQ.getHeader()->id = htons(counter);
+
+  PacketBuffer response;
+  GenericDNSPacketWriter<PacketBuffer> pwR(response, name, QType::A, QClass::IN, 0);
+  pwR.getHeader()->qr = 1;
+  pwR.getHeader()->rd = 1;
+  pwR.getHeader()->ra = 1;
+  pwR.getHeader()->id = htons(counter);
+  pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER);
+  pwR.xfr32BitInt(0x01020304);
+  pwR.commit();
+
+  {
+    /* dnsdist drops the query right away after receiving it, client closes the connection */
+    s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {403U}};
+    s_steps = {
+      /* opening */
+      { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+      /* settings server -> client */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+      /* settings + headers + data client -> server.. */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+      /* .. continued */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+      /* headers + data */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits<size_t>::max() },
+      /* wait for next query, but the client closes the connection */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
+      /* server close */
+      { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+    };
+
+    auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+    state->handleIO();
+  }
+
+  {
+    /* client closes the connection right in the middle of sending the query */
+    s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, { 403U }};
+    s_steps = {
+      /* opening */
+      { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+      /* settings server -> client */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+      /* client sends one byte */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 1 },
+      /* then closes the connection */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
+      /* server close */
+      { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+    };
+
+    /* mark the incoming FD as always ready */
+    dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
+
+    auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+    state->handleIO();
+    while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+      threadData.mplexer->run(&now);
+    }
+  }
+
+  {
+    /* dnsdist sends a response right away, client closes the connection after getting the response */
+    s_processQuery = [response](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      /* self answered */
+      dq.getMutableData() = response;
+      return ProcessQueryResult::SendAnswer;
+    };
+
+    s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}};
+
+    s_steps = {
+      /* opening */
+      { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+      /* settings server -> client */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+      /* settings + headers + data client -> server.. */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+      /* .. continued */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+      /* headers + data */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits<size_t>::max() },
+      /* wait for next query, but the client closes the connection */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
+      /* server close */
+      { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+    };
+
+    auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+    state->handleIO();
+  }
+
+  {
+    /* dnsdist sends a response right away, but the client closes the connection without even reading the response */
+    s_processQuery = [response](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      /* self answered */
+      dq.getMutableData() = response;
+      return ProcessQueryResult::SendAnswer;
+    };
+
+    s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}};
+
+    s_steps = {
+      /* opening */
+      { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+      /* settings server -> client */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+      /* settings + headers + data client -> server.. */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+      /* .. continued */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+      /* we want to send the response but the client closes the connection */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 0 },
+      /* server close */
+      { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+    };
+
+    /* mark the incoming FD as always ready */
+    dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
+
+    auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+    state->handleIO();
+    while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+      threadData.mplexer->run(&now);
+    }
+  }
+
+  {
+    /* dnsdist sends a response right away, client closes the connection while getting the response */
+    s_processQuery = [response](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      /* self answered */
+      dq.getMutableData() = response;
+      return ProcessQueryResult::SendAnswer;
+    };
+
+    s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {200U}};
+
+    s_steps = {
+      /* opening */
+      { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+      /* settings server -> client */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+      /* settings + headers + data client -> server.. */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+      /* .. continued */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+      /* headers + data (partial write) */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::NeedWrite, 1 },
+      /* nothing to read after that */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0 },
+      /* then the client closes the connection before we are done  */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 0 },
+      /* server close */
+      { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+    };
+
+    /* mark the incoming FD as always ready */
+    dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(-1);
+
+    auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+    state->handleIO();
+    while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+      threadData.mplexer->run(&now);
+    }
+  }
+}
+
+BOOST_FIXTURE_TEST_CASE(test_IncomingConnection_BackendTimeout, TestFixture)
+{
+  auto local = getBackendAddress("1", 80);
+  ClientState localCS(local, true, false, false, "", {});
+  localCS.dohFrontend = std::make_shared<DOHFrontend>(std::make_shared<MockupTLSCtx>());
+  localCS.dohFrontend->d_urls.insert("/dns-query");
+
+  TCPClientThreadData threadData;
+  threadData.mplexer = std::make_unique<MockupFDMultiplexer>();
+
+  auto backend = std::make_shared<DownstreamState>(getBackendAddress("42", 53));
+
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+
+  size_t counter = 0;
+  DNSName name("powerdns.com.");
+  PacketBuffer query;
+  GenericDNSPacketWriter<PacketBuffer> pwQ(query, name, QType::A, QClass::IN, 0);
+  pwQ.getHeader()->rd = 1;
+  pwQ.getHeader()->id = htons(counter);
+
+  PacketBuffer response;
+  GenericDNSPacketWriter<PacketBuffer> pwR(response, name, QType::A, QClass::IN, 0);
+  pwR.getHeader()->qr = 1;
+  pwR.getHeader()->rd = 1;
+  pwR.getHeader()->ra = 1;
+  pwR.getHeader()->id = htons(counter);
+  pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER);
+  pwR.xfr32BitInt(0x01020304);
+  pwR.commit();
+
+  {
+    /* dnsdist forwards the query to the backend, which does not answer -> timeout */
+    s_processQuery = [backend](DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
+      selectedBackend = backend;
+      return ProcessQueryResult::PassToBackend;
+    };
+    s_connectionContexts[counter++] = ExpectedData{{}, {query}, {response}, {502U}};
+    s_steps = {
+      /* opening */
+      { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
+      /* settings server -> client */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, 15 },
+      /* settings + headers + data client -> server.. */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 128 },
+      /* .. continued */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 60 },
+        /* trying to read a new request while processing the first one */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead },
+      /* headers + data */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, std::numeric_limits<size_t>::max(), [&threadData](int desc) {
+          /* set the incoming descriptor as ready */
+          dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setReady(desc);
+        }
+      },
+      /* wait for next query, but the client closes the connection */
+      { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
+      /* server close */
+      { ExpectedStep::ExpectedRequest::closeClient, IOState::Done },
+    };
+
+    auto state = std::make_shared<IncomingHTTP2Connection>(ConnectionInfo(&localCS, getBackendAddress("84", 4242)), threadData, now);
+    state->handleIO();
+    TCPResponse resp;
+    resp.d_idstate.d_streamID = 1;
+    state->notifyIOError(now, std::move(resp));
+    while (threadData.mplexer->getWatchedFDCount(false) != 0 || threadData.mplexer->getWatchedFDCount(true) != 0) {
+      threadData.mplexer->run(&now);
+    }
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
+#endif /* HAVE_NGHTTP2 */
index d10e85ef13f58d3ffc0c627547a9ff2de25e62e7..3e5bb1631221ad4c1582c4ca17bcf3ac54fb9439 100644 (file)
@@ -486,110 +486,7 @@ private:
   bool d_client{false};
 };
 
-class MockupTLSCtx : public TLSCtx
-{
-public:
-  ~MockupTLSCtx()
-  {
-  }
-
-  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
-  {
-    return std::make_unique<MockupTLSConnection>(socket);
-  }
-
-  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
-  {
-    return std::make_unique<MockupTLSConnection>(socket, true, d_needProxyProtocol);
-  }
-
-  void rotateTicketsKey(time_t now) override
-  {
-  }
-
-  size_t getTicketsKeysCount() override
-  {
-    return 0;
-  }
-
-  std::string getName() const override
-  {
-    return "Mockup TLS";
-  }
-
-  bool d_needProxyProtocol{false};
-};
-
-class MockupFDMultiplexer : public FDMultiplexer
-{
-public:
-  MockupFDMultiplexer()
-  {
-  }
-
-  ~MockupFDMultiplexer()
-  {
-  }
-
-  int run(struct timeval* tv, int timeout = 500) override
-  {
-    int ret = 0;
-
-    gettimeofday(tv, nullptr); // MANDATORY
-
-    /* 'ready' might be altered by a callback while we are iterating */
-    const auto readyFDs = ready;
-    for (const auto fd : readyFDs) {
-      {
-        const auto& it = d_readCallbacks.find(fd);
-
-        if (it != d_readCallbacks.end()) {
-          it->d_callback(it->d_fd, it->d_parameter);
-        }
-      }
-
-      {
-        const auto& it = d_writeCallbacks.find(fd);
-
-        if (it != d_writeCallbacks.end()) {
-          it->d_callback(it->d_fd, it->d_parameter);
-        }
-      }
-    }
-
-    return ret;
-  }
-
-  void getAvailableFDs(std::vector<int>& fds, int timeout) override
-  {
-  }
-
-  void addFD(int fd, FDMultiplexer::EventKind kind) override
-  {
-  }
-
-  void removeFD(int fd, FDMultiplexer::EventKind) override
-  {
-  }
-
-  string getName() const override
-  {
-    return "mockup";
-  }
-
-  void setReady(int fd)
-  {
-    ready.insert(fd);
-  }
-
-  void setNotReady(int fd)
-  {
-    ready.erase(fd);
-  }
-
-private:
-  std::set<int> ready;
-};
+#include "test-dnsdistnghttp2_common.hh"
 
 class MockupQuerySender : public TCPQuerySender
 {
@@ -641,36 +538,6 @@ public:
   bool d_error{false};
 };
 
-static bool isIPv6Supported()
-{
-  try {
-    ComboAddress addr("[2001:db8:53::1]:53");
-    auto socket = std::make_unique<Socket>(addr.sin4.sin_family, SOCK_STREAM, 0);
-    socket->setNonBlocking();
-    int res = SConnectWithTimeout(socket->getHandle(), addr, timeval{0, 0});
-    if (res == 0 || res == EINPROGRESS) {
-      return true;
-    }
-    return false;
-  }
-  catch (const std::exception& e) {
-    return false;
-  }
-}
-
-static ComboAddress getBackendAddress(const std::string& lastDigit, uint16_t port)
-{
-  static const bool useV6 = isIPv6Supported();
-
-  if (useV6) {
-    return ComboAddress("2001:db8:53::" + lastDigit, port);
-  }
-
-  return ComboAddress("192.0.2." + lastDigit, port);
-}
-
-static std::unique_ptr<FDMultiplexer> s_mplexer;
-
 struct TestFixture
 {
   TestFixture()
diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2_common.hh b/pdns/dnsdistdist/test-dnsdistnghttp2_common.hh
new file mode 100644 (file)
index 0000000..5c79679
--- /dev/null
@@ -0,0 +1,157 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+class MockupTLSCtx : public TLSCtx
+{
+public:
+  ~MockupTLSCtx()
+  {
+  }
+
+  std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
+  {
+    return std::make_unique<MockupTLSConnection>(socket);
+  }
+
+  std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
+  {
+    return std::make_unique<MockupTLSConnection>(socket, true, d_needProxyProtocol);
+  }
+
+  void rotateTicketsKey(time_t now) override
+  {
+  }
+
+  size_t getTicketsKeysCount() override
+  {
+    return 0;
+  }
+
+  std::string getName() const override
+  {
+    return "Mockup TLS";
+  }
+
+  bool d_needProxyProtocol{false};
+};
+
+class MockupFDMultiplexer : public FDMultiplexer
+{
+public:
+  MockupFDMultiplexer()
+  {
+  }
+
+  ~MockupFDMultiplexer()
+  {
+  }
+
+  int run(struct timeval* tv, int timeout = 500) override
+  {
+    int ret = 0;
+
+    gettimeofday(tv, nullptr); // MANDATORY
+
+    /* 'ready' might be altered by a callback while we are iterating */
+    const auto readyFDs = ready;
+    for (const auto fd : readyFDs) {
+      {
+        const auto& it = d_readCallbacks.find(fd);
+
+        if (it != d_readCallbacks.end()) {
+          it->d_callback(it->d_fd, it->d_parameter);
+        }
+      }
+
+      {
+        const auto& it = d_writeCallbacks.find(fd);
+
+        if (it != d_writeCallbacks.end()) {
+          it->d_callback(it->d_fd, it->d_parameter);
+        }
+      }
+    }
+
+    return ret;
+  }
+
+  void getAvailableFDs(std::vector<int>& fds, int timeout) override
+  {
+  }
+
+  void addFD(int fd, FDMultiplexer::EventKind kind) override
+  {
+  }
+
+  void removeFD(int fd, FDMultiplexer::EventKind) override
+  {
+  }
+
+  string getName() const override
+  {
+    return "mockup";
+  }
+
+  void setReady(int fd)
+  {
+    ready.insert(fd);
+  }
+
+  void setNotReady(int fd)
+  {
+    ready.erase(fd);
+  }
+
+private:
+  std::set<int> ready;
+};
+
+static bool isIPv6Supported()
+{
+  try {
+    ComboAddress addr("[2001:db8:53::1]:53");
+    auto socket = std::make_unique<Socket>(addr.sin4.sin_family, SOCK_STREAM, 0);
+    socket->setNonBlocking();
+    int res = SConnectWithTimeout(socket->getHandle(), addr, timeval{0, 0});
+    if (res == 0 || res == EINPROGRESS) {
+      return true;
+    }
+    return false;
+  }
+  catch (const std::exception& e) {
+    return false;
+  }
+}
+
+static ComboAddress getBackendAddress(const std::string& lastDigit, uint16_t port)
+{
+  static const bool useV6 = isIPv6Supported();
+
+  if (useV6) {
+    return ComboAddress("2001:db8:53::" + lastDigit, port);
+  }
+
+  return ComboAddress("192.0.2." + lastDigit, port);
+}
+
+static std::unique_ptr<FDMultiplexer> s_mplexer;
index 22e137c24b3b1629f606342ed8533df93b537629..dedfd97d2b330cc2f19e6f144f32a54a42e0e54d 100644 (file)
@@ -62,7 +62,7 @@ void handleResponseSent(const InternalQueryState& ids, double udiff, const Combo
 {
 }
 
-static std::function<ProcessQueryResult(DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
+std::function<ProcessQueryResult(DNSQuestion& dq, std::shared_ptr<DownstreamState>& selectedBackend)> s_processQuery;
 
 ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
 {
index 850273eb8a678b4b886ecc34d97bb8fa6c0f15bf..c51a930c04f8566c09a959e927374133f4a11a5a 100644 (file)
@@ -56,7 +56,7 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs
 
 bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
 {
-  return false;
+  return true;
 }
 
 namespace dnsdist {
index 1e7968aa3d2d98a78cb65a8209f1113c8f6444cb..75068bea027d1a637482a8dc8846b21bf0677d77 100644 (file)
@@ -987,21 +987,32 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return conn
 
     @classmethod
-    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
+    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, useProxyProtocol=False, conn=None):
         url = cls.getDOHGetURL(baseurl, query, rawQuery)
-        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
-        response_headers = BytesIO()
-        #conn.setopt(pycurl.VERBOSE, True)
-        conn.setopt(pycurl.URL, url)
-        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
-        # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
-        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
+
+        if not conn:
+            print('creating a new connection')
+            conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+            # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
+            conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
+
         if useHTTPS:
+            print("disabling verify")
             conn.setopt(pycurl.SSL_VERIFYPEER, 1)
             conn.setopt(pycurl.SSL_VERIFYHOST, 2)
             if caFile:
                 conn.setopt(pycurl.CAINFO, caFile)
 
+        if useProxyProtocol:
+            print('enabling PP')
+            # 274 is CURLOPT_HAPROXYPROTOCOL
+            conn.setopt(274, 1)
+
+        response_headers = BytesIO()
+        #conn.setopt(pycurl.VERBOSE, True)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+
         conn.setopt(pycurl.HTTPHEADER, customHeaders)
         conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
 
@@ -1014,6 +1025,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         receivedQuery = None
         message = None
         cls._response_headers = ''
+        print('performing')
         data = conn.perform_rb()
         cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
         if cls._rcode == 200 and not rawResponse:
@@ -1076,8 +1088,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         cls._response_headers = response_headers.getvalue()
         return (receivedQuery, message)
 
-    def sendDOHQueryWrapper(self, query, response, useQueue=True):
-        return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
+    def sendDOHQueryWrapper(self, query, response, useQueue=True, useProxyProtocol=False):
+        return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, useProxyProtocol=useProxyProtocol)
 
     def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
         return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
index ae6aac46a44fdd1c8a17eb873186187470024bf1..f9fce6be567dbfb172223b8861b3f1d386f56b23 100644 (file)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 
+import base64
 import dns
 import os
 import time
@@ -32,6 +33,7 @@ class DOHTests(object):
     addAction(HTTPPathRegexRule("^/PowerDNS-[0-9]"), SpoofAction("6.7.8.9"))
     addAction("http-status-action.doh.tests.powerdns.com.", HTTPStatusAction(200, "Plaintext answer", "text/plain"))
     addAction("http-status-action-redirect.doh.tests.powerdns.com.", HTTPStatusAction(307, "https://doh.powerdns.org"))
+    addAction("no-backend.doh.tests.powerdns.com.", PoolAction('this-pool-has-no-backend'))
 
     function dohHandler(dq)
       if dq:getHTTPScheme() == 'https' and dq:getHTTPHost() == '%s:%d' and dq:getHTTPPath() == '/' and dq:getHTTPQueryString() == '' then
@@ -235,9 +237,133 @@ class DOHTests(object):
         (_, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, caFile=self._caCert, query=query, response=None, useQueue=False)
         self.assertEqual(receivedResponse, expectedResponse)
 
+    def testDOHWithoutQuery(self):
+        """
+        DOH: Empty GET query
+        """
+        name = 'empty-get.doh.tests.powerdns.com.'
+        url = self._dohBaseURL
+        conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+        conn.setopt(pycurl.CAINFO, self._caCert)
+        data = conn.perform_rb()
+        rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        self.assertEqual(rcode, 400)
+
+    def testDOHShortPath(self):
+        """
+        DOH: Short path in GET query
+        """
+        name = 'short-path-get.doh.tests.powerdns.com.'
+        url = self._dohBaseURL + '/AA'
+        conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+        conn.setopt(pycurl.CAINFO, self._caCert)
+        data = conn.perform_rb()
+        rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        self.assertEqual(rcode, 404)
+
+    def testDOHQueryNoParameter(self):
+        """
+        DOH: No parameter GET query
+        """
+        name = 'no-parameter-get.doh.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        wire = query.to_wire()
+        b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+        url = self._dohBaseURL + '?not-dns=' + b64
+        conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+        conn.setopt(pycurl.CAINFO, self._caCert)
+        data = conn.perform_rb()
+        rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        self.assertEqual(rcode, 400)
+
+    def testDOHQueryInvalidBase64(self):
+        """
+        DOH: Invalid Base64 GET query
+        """
+        name = 'invalid-b64-get.doh.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        wire = query.to_wire()
+        url = self._dohBaseURL + '?dns=' + '_-~~~~-_'
+        conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+        conn.setopt(pycurl.CAINFO, self._caCert)
+        data = conn.perform_rb()
+        rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        self.assertEqual(rcode, 400)
+
+    def testDOHInvalidDNSHeaders(self):
+        """
+        DOH: Invalid DNS headers
+        """
+        name = 'invalid-dns-headers.doh.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        query.flags |= dns.flags.QR
+        wire = query.to_wire()
+        b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+        url = self._dohBaseURL + '?dns=' + b64
+        conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+        conn.setopt(pycurl.CAINFO, self._caCert)
+        data = conn.perform_rb()
+        rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        self.assertEqual(rcode, 400)
+
+    def testDOHQueryInvalidMethod(self):
+        """
+        DOH: Invalid method
+        """
+        if self._dohLibrary == 'h2o':
+            raise unittest.SkipTest('h2o does not check the HTTP method')
+        name = 'invalid-method.doh.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        wire = query.to_wire()
+        b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+        url = self._dohBaseURL + '?dns=' + b64
+        conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
+        conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+        conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+        conn.setopt(pycurl.CAINFO, self._caCert)
+        conn.setopt(pycurl.CUSTOMREQUEST, 'PATCH')
+        data = conn.perform_rb()
+        rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        self.assertEqual(rcode, 400)
+
+    def testDOHQueryInvalidALPN(self):
+        """
+        DOH: Invalid ALPN
+        """
+        alpn = ['bogus-alpn']
+        conn = self.openTLSConnection(self._dohServerPort, self._serverName, self._caCert, alpn=alpn)
+        try:
+            conn.send('AAAA')
+            response = conn.recv(65535)
+            self.assertFalse(response)
+        except:
+            pass
+
     def testDOHInvalid(self):
         """
-        DOH: Invalid query
+        DOH: Invalid DNS query
         """
         name = 'invalid.doh.tests.powerdns.com.'
         invalidQuery = dns.message.make_query(name, 'A', 'IN', use_edns=False)
@@ -268,13 +394,43 @@ class DOHTests(object):
         self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
         self.assertEqual(response, receivedResponse)
 
-    def testDOHWithoutQuery(self):
+    def testDOHInvalidHeaderName(self):
         """
-        DOH: Empty GET query
+        DOH: Invalid HTTP header name query
         """
-        name = 'empty-get.doh.tests.powerdns.com.'
-        url = self._dohBaseURL
-        conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2.0)
+        name = 'invalid-header-name.doh.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        query.id = 0
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery.id = 0
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        # this header is invalid, see rfc9113 section 8.2.1. Field Validity
+        customHeaders = ['{}: test']
+        try:
+            (receivedQuery, receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, customHeaders=customHeaders)
+            self.assertFalse(receivedQuery)
+            self.assertFalse(receivedResponse)
+        except pycurl.error:
+            pass
+
+    def testDOHNoBackend(self):
+        """
+        DOH: No backend
+        """
+        if self._dohLibrary == 'h2o':
+            raise unittest.SkipTest('h2o does not check the HTTP method')
+        name = 'no-backend.doh.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        wire = query.to_wire()
+        b64 = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+        url = self._dohBaseURL + '?dns=' + b64
+        conn = self.openDOHConnection(self._dohServerPort, self._caCert, timeout=2)
         conn.setopt(pycurl.URL, url)
         conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (self._serverName, self._dohServerPort)])
         conn.setopt(pycurl.SSL_VERIFYPEER, 1)
@@ -282,7 +438,7 @@ class DOHTests(object):
         conn.setopt(pycurl.CAINFO, self._caCert)
         data = conn.perform_rb()
         rcode = conn.getinfo(pycurl.RESPONSE_CODE)
-        self.assertEqual(rcode, 400)
+        self.assertEqual(rcode, 403)
 
     def testDOHEmptyPOST(self):
         """