git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Better downstream DoH support, better DoT/DoH ALPN handling
author Remi Gacogne <remi.gacogne@powerdns.com>
Fri, 6 Aug 2021 15:01:03 +0000 (17:01 +0200)
committer Remi Gacogne <remi.gacogne@powerdns.com>
Mon, 13 Sep 2021 13:28:27 +0000 (15:28 +0200)
16 files changed:
m4/pdns_with_gnutls.m4
m4/pdns_with_libssl.m4
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-nghttp2.cc
pdns/dnsdistdist/dnsdist-nghttp2.hh
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.hh
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/tcpiohandler-mplexer.hh
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/libssl.cc
pdns/libssl.hh
pdns/tcpiohandler.cc
pdns/tcpiohandler.hh

index c693dff81fdb7f29044845de0e678ac6d1a0e044..b6ad100bbbd67199aba9b04c789031f6d66be3d9 100644 (file)
@@ -18,7 +18,7 @@ AC_DEFUN([PDNS_WITH_GNUTLS], [
         save_LIBS=$LIBS
         CFLAGS="$GNUTLS_CFLAGS $CFLAGS"
         LIBS="$GNUTLS_LIBS $LIBS"
-        AC_CHECK_FUNCS([gnutls_memset gnutls_session_set_verify_cert gnutls_session_get_verify_cert_status])
+        AC_CHECK_FUNCS([gnutls_memset gnutls_session_set_verify_cert gnutls_session_get_verify_cert_status gnutls_alpn_set_protocols])
         CFLAGS=$save_CFLAGS
         LIBS=$save_LIBS
 
index c42905fd1dbabd755fa78a02dd48841c76a28c57..3e32bc40864f0116cf7e01bc89736a210e6b9f35 100644 (file)
@@ -17,7 +17,7 @@ AC_DEFUN([PDNS_WITH_LIBSSL], [
         save_LIBS=$LIBS
         CFLAGS="$LIBSSL_CFLAGS $CFLAGS"
         LIBS="$LIBSSL_LIBS -lcrypto $LIBS"
-        AC_CHECK_FUNCS([SSL_CTX_set_ciphersuites OCSP_basic_sign SSL_CTX_set_num_tickets SSL_CTX_set_keylog_callback SSL_CTX_get0_privatekey SSL_CTX_set_min_proto_version SSL_set_hostflags])
+        AC_CHECK_FUNCS([SSL_CTX_set_ciphersuites OCSP_basic_sign SSL_CTX_set_num_tickets SSL_CTX_set_keylog_callback SSL_CTX_get0_privatekey SSL_CTX_set_min_proto_version SSL_set_hostflags SSL_CTX_set_alpn_protos SSL_CTX_set_next_proto_select_cb SSL_get0_alpn_selected SSL_get0_next_proto_negotiated SSL_CTX_set_alpn_select_cb])
         CFLAGS=$save_CFLAGS
         LIBS=$save_LIBS
 
index b1f8465b1b7944680948e76d5cbfc55dadba3321..c63486621d47fe135d37c3bbc25f7270e9c5eed4 100644 (file)
@@ -40,6 +40,7 @@
 #ifdef LUAJIT_VERSION
 #include "dnsdist-lua-ffi.hh"
 #endif /* LUAJIT_VERSION */
+#include "dnsdist-nghttp2.hh"
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-rings.hh"
 #include "dnsdist-secpoll.hh"
@@ -528,10 +529,16 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         }
 
         ret->d_tlsCtx = getTLSContext(tlsParams);
-      }
 
-      if (vars.count("dohPath")) {
-        ret->d_dohPath = boost::get<string>(vars.at("dohPath"));
+        if (vars.count("dohPath")) {
+          ret->d_dohPath = boost::get<string>(vars.at("dohPath"));
+          if (ret->d_tlsCtx) {
+            setupDoHClientProtocolNegotiation(ret->d_tlsCtx);
+          }
+        }
+        else {
+          setupDoTProtocolNegotiation(ret->d_tlsCtx);
+        }
       }
 
       /* this needs to be done _AFTER_ the order has been set,
index 4e0fdec163c05c7ffe32e871704fc6eff9b9e7b8..2ca6320c7f0d07f4a42b6b963acef1b2b4677c33 100644 (file)
@@ -592,6 +592,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
 
   prependSizeToTCPQuery(state->d_buffer, 0);
 
+#warning FIXME: handle DoH backends here
   auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
 
   bool proxyProtocolPayloadAdded = false;
@@ -784,7 +785,15 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
           /* the state might have been updated in the meantime, we don't want to override it
              in that case */
           if (state->active() && state->d_state != IncomingTCPConnectionState::State::idle) {
-            iostate = state->d_ioState->getState();
+            if (state->d_ioState->isWaitingForRead()) {
+              iostate = IOState::NeedRead;
+            }
+            else if (state->d_ioState->isWaitingForWrite()) {
+              iostate = IOState::NeedWrite;
+            }
+            else {
+              iostate = IOState::Done;
+            }
           }
         }
         else {
@@ -860,9 +869,9 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
         ++state->d_ci.cs->tcpDiedSendingResponse;
       }
 
-      if (state->d_ioState->getState() == IOState::NeedWrite || state->d_queriesCount == 0) {
+      if (state->d_ioState->isWaitingForWrite() || state->d_queriesCount == 0) {
         DEBUGLOG("Got an exception while handling TCP query: "<<e.what());
-        vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_ioState->getState() == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+        vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_ioState->isWaitingForRead() ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
       }
       else {
         vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
@@ -1018,15 +1027,19 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par
     delete tmp;
     tmp = nullptr;
 
-    auto downstream = DownstreamConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now);
+    try {
+      auto downstream = DownstreamConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now);
 
-    prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
-    downstream->queueQuery(tqs, std::move(query));
+      prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
+      downstream->queueQuery(tqs, std::move(query));
+    }
+    catch (...) {
+      tqs->notifyIOError(std::move(query.d_idstate), now);
+    }
   }
   catch (...) {
     delete tmp;
     tmp = nullptr;
-    throw;
   }
 }
 
index 592a21f6d4d64dbcf7adf5068d9de6c076536b9a..b248a98a1e7c3976028002acb6abeecd528dd973 100644 (file)
@@ -34,9 +34,6 @@
 #include "threadname.hh"
 #include "sstuff.hh"
 
-#warning remove me
-#include "dnswriter.hh"
-
 std::atomic<uint64_t> g_dohStatesDumpRequested{0};
 std::unique_ptr<DoHClientCollection> g_dohClientThreads{nullptr};
 
@@ -45,28 +42,24 @@ class DoHConnectionToBackend: public TCPConnectionToBackend
 public:
   DoHConnectionToBackend(std::shared_ptr<DownstreamState> ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now);
 
-  void handleTimeout(const struct timeval& now, bool write) override
-  {
-#warning FIXME: we should notify the owners of pending queries / responses
-  }
-
+  void handleTimeout(const struct timeval& now, bool write) override;
   void queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query) override;
 
   std::string toString() const override
   {
     ostringstream o;
-    //o << "DoH connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses";
-    o << "DoH connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket");
+    o << "DoH connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", "<<getConcurrentStreamsCount()<<" streams";
     return o.str();
   }
 
+  bool canBeReused() const override;
+
 private:
   static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data);
   static int on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data);
   static int on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data);
   static int on_stream_close_callback(nghttp2_session* session, int32_t stream_id, uint32_t error_code, void* user_data);
   static int on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data);
-  static int on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data);
   static int on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data);
   static void handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param);
   static void handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param);
@@ -78,31 +71,103 @@ private:
     std::shared_ptr<TCPQuerySender> d_sender{nullptr};
     TCPQuery d_query;
     PacketBuffer d_buffer;
+    uint16_t d_responseCode{0};
     bool d_finished{false};
   };
+  void addToIOState(IOState state, FDMultiplexer::callbackfunc_t callback);
   void updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback);
   void stopIO();
   void handleResponse(PendingRequest&& request);
+  void handleResponseError(PendingRequest&& request, const struct timeval& now);
+  uint32_t getConcurrentStreamsCount() const;
 
-  //std::deque<TCPQuery> d_pendingQueries;
+  size_t getUsageCount() const
+  {
+    auto ref = shared_from_this();
+    return ref.use_count();
+  }
+
+  static const std::unordered_map<std::string, std::string> s_constants;
 
   std::unique_ptr<nghttp2_session, void(*)(nghttp2_session*)> d_session{nullptr, nghttp2_session_del};
   std::unordered_map<int32_t, PendingRequest> d_currentStreams;
   PacketBuffer d_out;
   PacketBuffer d_in;
+  size_t d_queryPos{0};
   size_t d_outPos{0};
   size_t d_inPos{0};
+  uint32_t d_highestStreamID{0};
 };
 
+class DownstreamDoHConnectionsManager
+{
+public:
+  static std::shared_ptr<DoHConnectionToBackend> getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<DownstreamState>& ds, const struct timeval& now);
+  static void releaseDownstreamConnection(std::shared_ptr<DoHConnectionToBackend>&& conn);
+  static void cleanupClosedConnections(struct timeval now);
+  static size_t clear();
+
+  static void setMaxCachedConnectionsPerDownstream(size_t max)
+  {
+    s_maxCachedConnectionsPerDownstream = max;
+  }
+
+  static void setCleanupInterval(uint16_t interval)
+  {
+    s_cleanupInterval = interval;
+  }
+
+private:
+  static thread_local map<boost::uuids::uuid, std::deque<std::shared_ptr<DoHConnectionToBackend>>> t_downstreamConnections;
+  static size_t s_maxCachedConnectionsPerDownstream;
+  static time_t s_nextCleanup;
+  static uint16_t s_cleanupInterval;
+};
+
+uint32_t DoHConnectionToBackend::getConcurrentStreamsCount() const
+{
+  return d_currentStreams.size();
+}
 
 void DoHConnectionToBackend::handleResponse(PendingRequest&& request)
 {
-  cerr<<"handle response!"<<endl;
   struct timeval now;
   gettimeofday(&now, nullptr);
   request.d_sender->handleResponse(now, TCPResponse(std::move(request.d_buffer), std::move(request.d_query.d_idstate), shared_from_this()));
 }
 
+void DoHConnectionToBackend::handleResponseError(PendingRequest&& request, const struct timeval& now)
+{
+  request.d_sender->notifyIOError(std::move(request.d_query.d_idstate), now);
+}
+
+void DoHConnectionToBackend::handleTimeout(const struct timeval& now, bool write)
+{
+  d_connectionDied = true;
+  for (auto& request : d_currentStreams) {
+    handleResponseError(std::move(request.second), now);
+  }
+  d_currentStreams.clear();
+}
+
+bool DoHConnectionToBackend::canBeReused() const
+{
+  if (d_connectionDied) {
+    return false;
+  }
+  const uint32_t maximumStreamID = (static_cast<uint32_t>(1) << 31) - 1;
+  if (d_highestStreamID == maximumStreamID) {
+    return false;
+  }
+
+  //cerr<<"Got "<<getConcurrentStreamsCount()<<" concurrent streams, max is "<<nghttp2_session_get_remote_settings(d_session.get(), NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS)<<endl;
+  if (nghttp2_session_get_remote_settings(d_session.get(), NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS) <= getConcurrentStreamsCount()) {
+    return false;
+  }
+
+  return true;
+}
+
 #define MAKE_NV(NAME, VALUE, VALUELEN)                                         \
   {                                                                            \
     (uint8_t *)NAME, (uint8_t *)VALUE, sizeof(NAME) - 1, VALUELEN,             \
@@ -115,6 +180,11 @@ void DoHConnectionToBackend::handleResponse(PendingRequest&& request)
         NGHTTP2_NV_FLAG_NONE                                                   \
   }
 
+const std::unordered_map<std::string, std::string> DoHConnectionToBackend::s_constants = {
+  { "method-name", ":method" },
+  { "method-value", "POST" },
+};
+
 void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query)
 {
   /* we could use nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_NAME and nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_VALUE
@@ -122,8 +192,9 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender,
      and that it is already lowercased. */
   auto payloadSize = std::to_string(query.d_buffer.size());
   d_currentQuery = std::move(query);
+  d_queryPos = 0;
   const nghttp2_nv hdrs[] = {
-      MAKE_NV2(":method", "POST"),
+    { const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(s_constants.at("method-name").c_str())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(s_constants.at("method-value").c_str())), s_constants.at("method-name").size(), s_constants.at("method-value").size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE },
       MAKE_NV2(":scheme", "https"),
       MAKE_NV(":authority", d_ds->d_tlsSubjectName.c_str(), d_ds->d_tlsSubjectName.size()),
       MAKE_NV(":path", d_ds->d_dohPath.c_str(), d_ds->d_dohPath.size()),
@@ -135,36 +206,36 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender,
 
   /* if data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key in nva (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set
    */
-  cerr<<"Remote size window is "<<nghttp2_session_get_remote_window_size(d_session.get())<<endl;
 
   nghttp2_data_provider data_provider;
   data_provider.source.ptr = this;
   data_provider.read_callback = [](nghttp2_session* session, int32_t stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* user_data) -> ssize_t
   {
-    cerr<<"in data provider"<<endl;
     auto userData = reinterpret_cast<DoHConnectionToBackend*>(user_data);
-    if (userData->d_inPos >= userData->d_currentQuery.d_buffer.size()) {
+    size_t toCopy = 0;
+    if (userData->d_queryPos < userData->d_currentQuery.d_buffer.size()) {
+      size_t remaining = userData->d_currentQuery.d_buffer.size()- userData->d_queryPos;
+      toCopy = length > remaining ? remaining : length;
+      memcpy(buf, &userData->d_currentQuery.d_buffer.at(userData->d_queryPos), toCopy);
+      userData->d_queryPos += toCopy;
+    }
+
+    if (userData->d_queryPos >= userData->d_currentQuery.d_buffer.size()) {
        *data_flags |= NGHTTP2_DATA_FLAG_EOF;
-       cerr<<"EOF"<<endl;
-       return 0;
-    }
-    size_t remaining = userData->d_currentQuery.d_buffer.size()- userData->d_inPos;
-    size_t toCopy = length > remaining ? remaining : length;
-    memcpy(buf, &userData->d_currentQuery.d_buffer.at(userData->d_inPos), toCopy);
-    userData->d_inPos += toCopy;
-    cerr<<toCopy<<" written"<<endl;
+    }
     return toCopy;
   };
 
   auto stream_id = nghttp2_submit_request(d_session.get(), nullptr, hdrs, sizeof(hdrs)/sizeof(*hdrs), &data_provider, this);
   if (stream_id < 0) {
-    cerr<<"Could not submit HTTP request: "<<nghttp2_strerror(stream_id)<<endl;
-    return;
+    d_connectionDied = true;
+    throw std::runtime_error("Error submitting HTTP request:" + std::string(nghttp2_strerror(stream_id)));
   }
-  cerr<<"stream ID is "<<stream_id<<endl;
+  //cerr<<"stream ID is "<<stream_id<<" for a query of size "<<payloadSize<<endl;
+
   auto rv = nghttp2_session_send(d_session.get());
-  cerr<<"nghttp2_session_send returned "<<rv<<endl;
   if (rv != 0) {
+    d_connectionDied = true;
     throw std::runtime_error("Error in nghttp2_session_send:" + std::to_string(rv));
   }
   PendingRequest request;
@@ -172,11 +243,13 @@ void DoHConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender,
   request.d_sender = std::move(sender);
   auto insertPair = d_currentStreams.insert({stream_id, std::move(request)});
   if (!insertPair.second) {
-    cerr<<"collision!!"<<endl;
     /* there is a stream ID collision, something is very wrong! */
+    d_connectionDied = true;
     nghttp2_session_terminate_session(d_session.get(), NGHTTP2_NO_ERROR);
     throw std::runtime_error("Stream ID collision");
   }
+
+  d_highestStreamID = stream_id;
 }
 
 class DoHClientThreadData
@@ -195,7 +268,6 @@ void DoHConnectionToBackend::handleIO(std::shared_ptr<DoHConnectionToBackend>& c
 
 void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
 {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<", param is "<<param.type().name()<<endl;
   auto conn = boost::any_cast<std::shared_ptr<DoHConnectionToBackend>>(param);
   if (fd != conn->getHandle()) {
     throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle()));
@@ -205,23 +277,22 @@ void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::fun
   do {
     conn->d_inPos = 0;
     conn->d_in.resize(conn->d_in.size() + 512);
-    cerr<<"trying to read "<<conn->d_in.size()<<endl;
+    //cerr<<"trying to read "<<conn->d_in.size()<<endl;
     try {
       IOState newState = conn->d_handler->tryRead(conn->d_in, conn->d_inPos, conn->d_in.size(), true);
       // userData.d_handler->tryRead(userData.d_in, pos, userData.d_in.size());
-      cerr<<"got a "<<(int)newState<<" state and "<<conn->d_inPos<<" bytes"<<endl;
+      //cerr<<"got a "<<(int)newState<<" state and "<<conn->d_inPos<<" bytes"<<endl;
       conn->d_in.resize(conn->d_inPos);
       if (newState == IOState::Done) {
         auto readlen = nghttp2_session_mem_recv(conn->d_session.get(), conn->d_in.data(), conn->d_inPos);
-        cerr<<"nghttp2_session_mem_recv returned "<<readlen<<endl;
+        //cerr<<"nghttp2_session_mem_recv returned "<<readlen<<endl;
         /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB,
            all data should be consumed before returning */
         if (readlen > 0 && static_cast<size_t>(readlen) < conn->d_inPos) {
           cerr<<"Fatal error: "<<nghttp2_strerror((int)readlen)<<endl;
           return;
         }
-        int rv = nghttp2_session_send(conn->d_session.get());
-        cerr<<"nghttp2_session_send returned "<<rv<<endl;
+        nghttp2_session_send(conn->d_session.get());
       }
       else {
         if (newState == IOState::NeedWrite) {
@@ -232,34 +303,30 @@ void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::fun
       }
     }
     catch (const std::exception& e) {
-      cerr<<"got exception "<<e.what()<<endl;
+      cerr<<"Exception while trying to read from HTTP backend connection: "<<e.what()<<endl;
       break;
     }
   }
-  while (true);
-
-  //struct timeval now;
-  //gettimeofday(&now, nullptr);
-  //handleIO(conn, now);
+  while (conn->getConcurrentStreamsCount() > 0);
 }
 
 void DoHConnectionToBackend::handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
 {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<", param is "<<param.type().name()<<endl;
   auto conn = boost::any_cast<std::shared_ptr<DoHConnectionToBackend>>(param);
   if (fd != conn->getHandle()) {
     throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle()));
   }
   IOStateGuard ioGuard(conn->d_ioState);
 
-  cerr<<"trying to write "<<conn->d_out.size()-conn->d_outPos<<endl;
+  //cerr<<"trying to write "<<conn->d_out.size()-conn->d_outPos<<endl;
   try {
     IOState newState = conn->d_handler->tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
-    cerr<<"got a "<<(int)newState<<" state, "<<conn->d_out.size()-conn->d_inPos<<" bytes remaining"<<endl;
+    //cerr<<"got a "<<(int)newState<<" state, "<<conn->d_out.size()-conn->d_outPos<<" bytes remaining"<<endl;
     if (newState == IOState::NeedRead) {
       conn->updateIO(IOState::NeedRead, handleWritableIOCallback);
     }
     else if (newState == IOState::Done) {
+      ++conn->d_queries;
       conn->d_out.clear();
       conn->d_outPos = 0;
       conn->stopIO();
@@ -268,12 +335,8 @@ void DoHConnectionToBackend::handleWritableIOCallback(int fd, FDMultiplexer::fun
     ioGuard.release();
   }
   catch (const std::exception& e) {
-    cerr<<"got exception "<<e.what()<<endl;
+    cerr<<"Exception while trying to write (ready) to HTTP backend connection: "<<e.what()<<endl;
   }
-
-  //struct timeval now;
-  //gettimeofday(&now, nullptr);
-  //handleIO(conn, now);
 }
 
 void DoHConnectionToBackend::stopIO()
@@ -308,27 +371,53 @@ void DoHConnectionToBackend::updateIO(IOState newState, FDMultiplexer::callbackf
   }
 }
 
+void DoHConnectionToBackend::addToIOState(IOState state, FDMultiplexer::callbackfunc_t callback)
+{
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+  boost::optional<struct timeval> ttd{boost::none};
+  if (state == IOState::NeedRead) {
+    ttd = getBackendReadTTD(now);
+  }
+  else if (isFresh() && d_queries == 0) {
+    /* first write just after the non-blocking connect */
+    ttd = getBackendConnectTTD(now);
+  }
+  else {
+    ttd = getBackendWriteTTD(now);
+  }
+
+  auto shared = std::dynamic_pointer_cast<DoHConnectionToBackend>(shared_from_this());
+  if (shared) {
+    if (state == IOState::NeedRead) {
+      d_ioState->add(state, callback, shared, ttd);
+    }
+    else if (state == IOState::NeedWrite) {
+      d_ioState->add(state, callback, shared, ttd);
+    }
+  }
+}
+
 ssize_t DoHConnectionToBackend::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data) {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
-  cerr<<"asked to send "<<length<<" bytes"<<endl;
-  DoHConnectionToBackend* userData = reinterpret_cast<DoHConnectionToBackend*>(user_data);
-  bool bufferWasEmpty = userData->d_out.empty();
-  userData->d_out.insert(userData->d_out.end(), data, data + length);
+  DoHConnectionToBackend* conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
+  bool bufferWasEmpty = conn->d_out.empty();
+  conn->d_out.insert(conn->d_out.end(), data, data + length);
 
   if (bufferWasEmpty) {
-    auto state = userData->d_handler->tryWrite(userData->d_out, userData->d_outPos, userData->d_out.size());
-    if (state == IOState::Done) {
-      userData->d_out.clear();
-#warning FIXME from now on we need to read, as we might get an answer
-      cerr<<"FIXME now we need to read!"<<endl;
-      //if (currentIOState does not have NeedRead) {
-      //  userData->addToIOState(IOState::NeedRead);
-      //}
+    try {
+      auto state = conn->d_handler->tryWrite(conn->d_out, conn->d_outPos, conn->d_out.size());
+      if (state == IOState::Done) {
+        ++conn->d_queries;
+        conn->d_out.clear();
+        conn->d_outPos = 0;
+        conn->addToIOState(IOState::NeedRead, handleReadableIOCallback);
+      }
+      else {
+        conn->updateIO(state, handleWritableIOCallback);
+      }
     }
-    else {
-#warning write me should be addIO() instead, perhaps?
-      cerr<<"now we need to wait for a writable (or readable) socket"<<endl;
-      userData->updateIO(state, handleWritableIOCallback);
+    catch (const std::exception& e) {
+      cerr<<"Exception while trying to write (send) to HTTP backend connection: "<<e.what()<<endl;
     }
   }
 
@@ -336,9 +425,9 @@ ssize_t DoHConnectionToBackend::send_callback(nghttp2_session* session, const ui
 }
 
 int DoHConnectionToBackend::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
   DoHConnectionToBackend* conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
-  cerr<<"Frame type is "<<std::to_string(frame->hd.type)<<endl;
+  //cerr<<"Frame type is "<<std::to_string(frame->hd.type)<<endl;
+#if 0
   switch (frame->hd.type) {
   case NGHTTP2_HEADERS:
     cerr<<"got headers"<<endl;
@@ -359,27 +448,36 @@ int DoHConnectionToBackend::on_frame_recv_callback(nghttp2_session* session, con
   case NGHTTP2_DATA:
     cerr<<"got data"<<endl;
     break;
-  case NGHTTP2_PRIORITY:
-    cerr<<"got priority"<<endl;
-    break;
-  case NGHTTP2_GOAWAY:
-    cerr<<"Got GO AWAY"<<endl;
-    break;
+  case NGHTTP2_GOAWAY;
   }
+#endif
 
   /* is this the last frame for this stream? */
   if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && frame->hd.flags & NGHTTP2_FLAG_END_STREAM) {
     auto stream = conn->d_currentStreams.find(frame->hd.stream_id);
     if (stream != conn->d_currentStreams.end()) {
-      cerr<<"Stream "<<frame->hd.stream_id<<" is now finished"<<endl;
+      //cerr<<"Stream "<<frame->hd.stream_id<<" is now finished"<<endl;
       stream->second.d_finished = true;
 
       auto request = std::move(stream->second);
       conn->d_currentStreams.erase(stream->first);
-      conn->handleResponse(std::move(request));
+      if (request.d_responseCode == 200U) {
+        conn->handleResponse(std::move(request));
+      } else {
+        vinfolog("HTTP response has a non-200 status code: %d", request.d_responseCode);
+        struct timeval now;
+        gettimeofday(&now, nullptr);
+
+        conn->handleResponseError(std::move(request), now);
+      }
+      if (conn->getConcurrentStreamsCount() == 0) {
+        conn->stopIO();
+      }
     }
     else {
-      cerr<<"Stream "<<frame->hd.stream_id<<" NOT FOUND"<<endl;
+      vinfolog("Stream %d NOT FOUND", frame->hd.stream_id);
+      conn->d_connectionDied = true;
+      return NGHTTP2_ERR_CALLBACK_FAILURE;
     }
   }
 
@@ -387,270 +485,134 @@ int DoHConnectionToBackend::on_frame_recv_callback(nghttp2_session* session, con
 }
 
 int DoHConnectionToBackend::on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, int32_t stream_id, const uint8_t* data, size_t len, void* user_data) {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
   DoHConnectionToBackend* conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
-  cerr<<"Got data of size "<<len<<" for stream "<<stream_id<<endl;
+  //cerr<<"Got data of size "<<len<<" for stream "<<stream_id<<endl;
   auto stream = conn->d_currentStreams.find(stream_id);
   if (stream == conn->d_currentStreams.end()) {
-    cerr<<"Unable to match the stream ID "<<stream_id<<" to a known one!"<<endl;
+    vinfolog("Unable to match the stream ID %d to a known one!", stream_id);
+    conn->d_connectionDied = true;
+    return NGHTTP2_ERR_CALLBACK_FAILURE;
+  }
+  if (len > std::numeric_limits<uint16_t>::max() || (std::numeric_limits<uint16_t>::max() - stream->second.d_buffer.size()) < len) {
+    vinfolog("Data frame of size %d is too large for a DNS response (we already have %d)", len, stream->second.d_buffer.size());
+    conn->d_connectionDied = true;
     return NGHTTP2_ERR_CALLBACK_FAILURE;
   }
+
   stream->second.d_buffer.insert(stream->second.d_buffer.end(), data, data + len);
   if (stream->second.d_finished) {
-    cerr<<"we now have the full response!"<<endl;
+    //cerr<<"we now have the full response!"<<endl;
+    //cerr<<std::string(reinterpret_cast<const char*>(data), len)<<endl;
+
     auto request = std::move(stream->second);
     conn->d_currentStreams.erase(stream->first);
-    conn->handleResponse(std::move(request));
-    cerr<<std::string(reinterpret_cast<const char*>(data), len)<<endl;
+    if (request.d_responseCode == 200U) {
+      conn->handleResponse(std::move(request));
+    } else {
+      vinfolog("HTTP response has a non-200 status code: %d", request.d_responseCode);
+      struct timeval now;
+      gettimeofday(&now, nullptr);
+
+      conn->handleResponseError(std::move(request), now);
+    }
+    if (conn->getConcurrentStreamsCount() == 0) {
+      conn->stopIO();
+    }
   }
   else {
-    cerr<<"but the stream is not finished yet"<<endl;
+    //cerr<<"but the stream is not finished yet"<<endl;
   }
 
   return 0;
 }
 
 int DoHConnectionToBackend::on_stream_close_callback(nghttp2_session* session, int32_t stream_id, uint32_t error_code, void* user_data) {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
-  //DoHConnectionToBackend* userData = reinterpret_cast<DoHConnectionToBackend*>(user_data);
+  DoHConnectionToBackend* conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
+
+  if (error_code == 0) {
+    return 0;
+  }
 
   cerr<<"Stream "<<stream_id<<" closed with error_code="<<error_code<<endl;
-  auto rv = nghttp2_session_terminate_session(session, NGHTTP2_NO_ERROR);
-  if (rv != 0) {
-    return NGHTTP2_ERR_CALLBACK_FAILURE;
+  conn->d_connectionDied = true;
+
+  auto stream = conn->d_currentStreams.find(stream_id);
+  if (stream == conn->d_currentStreams.end()) {
+    /* we don't care, then */
+    cerr<<"we don't care"<<endl;
+    return 0;
   }
 
-  return 0;
-}
+  struct timeval now;
+  gettimeofday(&now, nullptr);
+  auto request = std::move(stream->second);
+  conn->d_currentStreams.erase(stream->first);
 
-int DoHConnectionToBackend::on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data) {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
-  //DoHConnectionToBackend* userData = reinterpret_cast<DoHConnectionToBackend*>(user_data);
+  //cerr<<"in "<<__PRETTY_FUNCTION__<<", looking for a connection to send a query of size "<<request.d_query.d_buffer.size()<<endl;
+  auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(conn->d_mplexer, conn->d_ds, now);
+  downstream->queueQuery(request.d_sender, std::move(request.d_query));
 
-  switch (frame->hd.type) {
-  case NGHTTP2_HEADERS:
-    if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) {
-      /* Print response headers for the initiated request. */
-      cerr<<"got header for "<<frame->hd.stream_id<<":"<<endl;
-      cerr<<"- "<<std::string(reinterpret_cast<const char*>(name), namelen)<<endl;
-      cerr<<"- "<<std::string(reinterpret_cast<const char*>(value), valuelen)<<endl;
-      break;
-    }
+  //cerr<<"we now have "<<conn->getConcurrentStreamsCount()<<" concurrent connections"<<endl;
+  if (conn->getConcurrentStreamsCount() == 0) {
+    //cerr<<"stopping IO"<<endl;
+    conn->stopIO();
+    //cerr<<"our current refcnt is now "<<conn->getUsageCount()<<endl;
   }
+
   return 0;
 }
 
-int DoHConnectionToBackend::on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data) {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
-  //DoHConnectionToBackend* userData = reinterpret_cast<DoHConnectionToBackend*>(user_data);
+int DoHConnectionToBackend::on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t namelen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data) {
+  DoHConnectionToBackend* conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
 
+  const std::string status(":status");
   switch (frame->hd.type) {
   case NGHTTP2_HEADERS:
     if (frame->headers.cat == NGHTTP2_HCAT_RESPONSE) {
-      cerr<<"Response headers for stream ID="<<frame->hd.stream_id<<endl;
-    }
-    break;
-  }
-  return 0;
-}
-
-int DoHConnectionToBackend::on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data) {
-  cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
-  cerr<<"Error is "<<std::string(msg, len)<<endl;
-  //DoHConnectionToBackend* userData = reinterpret_cast<DoHConnectionToBackend*>(user_data);
-
-  return 0;
-}
-
-#if 0
-static void doReadData(DoHConnectionToBackend& userData)
-{
-  do {
-    size_t pos = 0;
-    userData.d_in.resize(512);
-    cerr<<"trying to read "<<userData.d_in.size()<<endl;
-    try {
-      pos = userData.d_handler->read(userData.d_in.data(), userData.d_in.size(), timeval{2, 0}, timeval{2, 0}, true);
-      // userData.d_handler->tryRead(userData.d_in, pos, userData.d_in.size());
-      cerr<<"got "<<pos<<endl;
-      userData.d_in.resize(pos);
-      if (pos > 0) {
-        auto readlen = nghttp2_session_mem_recv(userData.d_session.get(), userData.d_in.data(), pos);
-        cerr<<"nghttp2_session_mem_recv returned "<<readlen<<endl;
-        if (readlen < 0) {
-          cerr<<"Fatal error: "<<nghttp2_strerror((int)readlen)<<endl;
-          return;
+      //cerr<<"got header for "<<frame->hd.stream_id<<":"<<endl;
+      //cerr<<"- "<<std::string(reinterpret_cast<const char*>(name), namelen)<<endl;
+      //cerr<<"- "<<std::string(reinterpret_cast<const char*>(value), valuelen)<<endl;
+      if (namelen == status.size() && memcmp(status.data(), name, status.size()) == 0) {
+        auto stream = conn->d_currentStreams.find(frame->hd.stream_id);
+        if (stream == conn->d_currentStreams.end()) {
+          vinfolog("Unable to match the stream ID %d to a known one!", frame->hd.stream_id);
+          conn->d_connectionDied = true;
+          return NGHTTP2_ERR_CALLBACK_FAILURE;
+        }
+        try {
+          stream->second.d_responseCode = pdns_stou(std::string(reinterpret_cast<const char*>(value), valuelen));
+        }
+        catch (...) {
+          vinfolog("Error parsing the status header for stream ID %d", frame->hd.stream_id);
+          conn->d_connectionDied = true;
+          return NGHTTP2_ERR_CALLBACK_FAILURE;
         }
-        int rv = nghttp2_session_send(userData.d_session.get());
-        cerr<<"nghttp2_session_send returned "<<rv<<endl;
-      }
-      else {
-        break;
       }
-    }
-    catch (const std::exception& e) {
-      cerr<<"got exception "<<e.what()<<endl;
+
       break;
     }
   }
-  while (true);
+  return 0;
 }
 
-void sendHTTP2Query()
-{
-  auto remote = ComboAddress("9.9.9.11:443");
-  std::string host("dns11.quad9.net");
-  std::string path("/dns-query");
-  struct TLSContextParameters tlsParams;
-  tlsParams.d_provider = "openssl";
-  std::shared_ptr<TLSCtx> tlsCtx = getTLSContext(tlsParams);
-
-  Socket sock(remote.sin4.sin_family, SOCK_STREAM);
-  // FIXME
-  auto fd = sock.getHandle();
-  setTCPNoDelay(fd);
-  DoHConnectionToBackend userData;
-  userData.d_handler = std::make_unique<TCPIOHandler>(host, sock.releaseHandle(), timeval{2, 0}, tlsCtx, time(nullptr));
-  userData.d_handler->connect(true, remote, timeval{2, 0});
-
-  /* check ALPN:
-SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen);
-#if OPENSSL_VERSION_NUMBER >= 0x10002000L
-    if (alpn == NULL) {
-      SSL_get0_alpn_selected(ssl, &alpn, &alpnlen);
-    }
-#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L
-
-    if (alpn == NULL || alpnlen != 2 || memcmp("h2", alpn, 2) != 0) {
-      fprintf(stderr, "h2 is not negotiated\n");
-      delete_http2_session_data(session_data);
-      return;
-    }
-  */
-
-  nghttp2_session_callbacks* cbs = nullptr;
-  if (nghttp2_session_callbacks_new(&cbs) != 0) {
-    cerr<<"unable to create a callback object for a new HTTP/2 session"<<endl;
-    return;
-  }
-  std::unique_ptr<nghttp2_session_callbacks, void(*)(nghttp2_session_callbacks*)> callbacks(cbs, nghttp2_session_callbacks_del);
-  cbs = nullptr;
-
-  nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback);
-  nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback);
-  nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback);
-  nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback);
-  nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback);
-  nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback);
-
-  nghttp2_session* sess = nullptr;
-  if (nghttp2_session_client_new(&sess, callbacks.get(), &userData) != 0) {
-    cerr<<"Coult not allocate a new HTTP/2 session"<<endl;
-    return;
-  }
-
-  userData.d_session = std::unique_ptr<nghttp2_session, void(*)(nghttp2_session*)>(sess, nghttp2_session_del);
-  sess = nullptr;
-
-  callbacks.reset();
-
-#warning we should make the 100 configurable here, as we might want a lower number before receiving the one actually supported by the server
-#warning we should also make the window size configurable, but 16M is a nice default
-  nghttp2_settings_entry iv[] = {
-    {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100},
-    {NGHTTP2_SETTINGS_ENABLE_PUSH, 0},
-    {NGHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, 16*1024*1024}
-  };
-   /* client 24 bytes magic string will be sent by nghttp2 library */
-  int rv = nghttp2_submit_settings(userData.d_session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv)/sizeof(*iv));
-  if (rv != 0) {
-    cerr<<"Could not submit SETTINGS: "<<nghttp2_strerror(rv)<<endl;
-    return;
-  }
-
-  GenericDNSPacketWriter<PacketBuffer> pw(userData.d_in, DNSName("doh.dnsdist.org."), QType::A, QClass::IN, 0);
-  pw.getHeader()->rd = 1;
-  pw.commit();
-
-  /* we could use nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_NAME and nghttp2_nv_flag.NGHTTP2_NV_FLAG_NO_COPY_VALUE
-     to avoid a copy and lowercasing as long as we take care of making sure that the data will outlive the request
-     and that it is already lowercased. */
-  auto payloadSize = std::to_string(userData.d_in.size());
-  const nghttp2_nv hdrs[] = {
-      MAKE_NV2(":method", "POST"),
-      MAKE_NV2(":scheme", "https"),
-      MAKE_NV(":authority", host.c_str(), host.size()),
-      MAKE_NV(":path", path.c_str(), path.size()),
-      MAKE_NV2("accept", "application/dns-message"),
-      MAKE_NV2("content-type", "application/dns-message"),
-      MAKE_NV("content-length", payloadSize.c_str(), payloadSize.size()),
-      MAKE_NV2("user-agent", "nghttp2-" NGHTTP2_VERSION "/dnsdist")
-  };
-
-  /* f data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key in nva (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set
-   */
-  cerr<<"Remote size window is "<<nghttp2_session_get_remote_window_size(userData.d_session.get())<<endl;
-
-  nghttp2_data_provider data_provider;
-  data_provider.source.ptr = &userData;
-  data_provider.read_callback = [](nghttp2_session* session, int32_t stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* user_data) -> ssize_t
-  {
-    cerr<<"in data provider"<<endl;
-    auto userData = reinterpret_cast<DoHConnectionToBackend*>(user_data);
-    if (userData->d_inPos >= userData->d_in.size()) {
-       *data_flags |= NGHTTP2_DATA_FLAG_EOF;
-       cerr<<"EOF"<<endl;
-       return 0;
-    }
-    size_t remaining = userData->d_in.size()- userData->d_inPos;
-    size_t toCopy = length > remaining ? remaining : length;
-    memcpy(buf, &userData->d_in.at(userData->d_inPos), toCopy);
-    userData->d_inPos += toCopy;
-    cerr<<toCopy<<" written"<<endl;
-    return toCopy;
-  };
-
-  auto stream_id = nghttp2_submit_request(userData.d_session.get(), nullptr, hdrs, sizeof(hdrs)/sizeof(*hdrs), &data_provider, &userData);
-  if (stream_id < 0) {
-    cerr<<"Could not submit HTTP request: "<<nghttp2_strerror(stream_id)<<endl;
-    return;
-  }
-  rv = nghttp2_session_send(userData.d_session.get());
+int DoHConnectionToBackend::on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data) {
+  vinfolog("Error in HTTP/2 connection: %s", std::string(msg, len));
 
-  setNonBlocking(fd);
+  DoHConnectionToBackend* conn = reinterpret_cast<DoHConnectionToBackend*>(user_data);
+  conn->d_connectionDied = true;
 
-  doReadData(userData);
-  cerr<<"After reading data, remote size window is "<<nghttp2_session_get_remote_window_size(userData.d_session.get())<<endl;
-  cerr<<"Max number of streams from remote is "<<nghttp2_session_get_remote_settings(userData.d_session.get(), NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS)<<endl;
-  cerr<<"our own is "<<nghttp2_session_get_local_settings(userData.d_session.get(), NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS)<<endl;
-  // min(nghttp2_session_get_stream_remote_window_size(), nghttp2_session_get_remote_window_size())
-#warning for later: how do we know how many streams are left? the window size?
+  return 0;
 }
-#endif
 
 DoHConnectionToBackend::DoHConnectionToBackend(std::shared_ptr<DownstreamState> ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now): TCPConnectionToBackend(ds, mplexer, now)
 {
   // inherit most of the stuff from the TCPConnectionToBackend()
-
-  /* check ALPN:
-SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen);
-#if OPENSSL_VERSION_NUMBER >= 0x10002000L
-    if (alpn == NULL) {
-      SSL_get0_alpn_selected(ssl, &alpn, &alpnlen);
-    }
-#endif // OPENSSL_VERSION_NUMBER >= 0x10002000L
-
-    if (alpn == NULL || alpnlen != 2 || memcmp("h2", alpn, 2) != 0) {
-      fprintf(stderr, "h2 is not negotiated\n");
-      delete_http2_session_data(session_data);
-      return;
-    }
-  */
   d_ioState = make_unique<IOStateHandler>(*d_mplexer, d_handler->getDescriptor());
 
   nghttp2_session_callbacks* cbs = nullptr;
   if (nghttp2_session_callbacks_new(&cbs) != 0) {
-    cerr<<"unable to create a callback object for a new HTTP/2 session"<<endl;
+    d_connectionDied = true;
+    vinfolog("Unable to create a callback object for a new HTTP/2 session");
     return;
   }
   std::unique_ptr<nghttp2_session_callbacks, void(*)(nghttp2_session_callbacks*)> callbacks(cbs, nghttp2_session_callbacks_del);
@@ -661,12 +623,12 @@ SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen);
   nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback);
   nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback);
   nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback);
-  nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback);
   nghttp2_session_callbacks_set_error_callback2(callbacks.get(), on_error_callback);
 
   nghttp2_session* sess = nullptr;
   if (nghttp2_session_client_new(&sess, callbacks.get(), this) != 0) {
-    cerr<<"Coult not allocate a new HTTP/2 session"<<endl;
+    d_connectionDied = true;
+    vinfolog("Coult not allocate a new HTTP/2 session");
     return;
   }
 
@@ -675,46 +637,26 @@ SSL_get0_next_proto_negotiated(ssl, &alpn, &alpnlen);
 
   callbacks.reset();
 
-#warning we should make the 100 configurable here, as we might want a lower number before receiving the one actually supported by the server
-#warning we should also make the window size configurable, but 16M is a nice default
   nghttp2_settings_entry iv[] = {
-    {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100},
+    /* rfc7540 section-8.2.2:
+       "Advertising a SETTINGS_MAX_CONCURRENT_STREAMS value of zero disables
+       server push by preventing the server from creating the necessary
+       streams."
+    */
+    {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 0},
     {NGHTTP2_SETTINGS_ENABLE_PUSH, 0},
+    /* we might want to make the initial window size configurable, but 16M is a large enough default */
     {NGHTTP2_SETTINGS_INITIAL_WINDOW_SIZE, 16*1024*1024}
   };
    /* client 24 bytes magic string will be sent by nghttp2 library */
   int rv = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, iv, sizeof(iv)/sizeof(*iv));
   if (rv != 0) {
-    cerr<<"Could not submit SETTINGS: "<<nghttp2_strerror(rv)<<endl;
+    d_connectionDied = true;
+    vinfolog("Could not submit SETTINGS: %s", nghttp2_strerror(rv));
     return;
   }
 }
 
-class DownstreamDoHConnectionsManager
-{
-public:
-  static std::shared_ptr<DoHConnectionToBackend> getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<DownstreamState>& ds, const struct timeval& now);
-  static void releaseDownstreamConnection(std::shared_ptr<DoHConnectionToBackend>&& conn);
-  static void cleanupClosedConnections(struct timeval now);
-  static size_t clear();
-
-  static void setMaxCachedConnectionsPerDownstream(size_t max)
-  {
-    s_maxCachedConnectionsPerDownstream = max;
-  }
-
-  static void setCleanupInterval(uint16_t interval)
-  {
-    s_cleanupInterval = interval;
-  }
-
-private:
-  static thread_local map<boost::uuids::uuid, std::deque<std::shared_ptr<DoHConnectionToBackend>>> t_downstreamConnections;
-  static size_t s_maxCachedConnectionsPerDownstream;
-  static time_t s_nextCleanup;
-  static uint16_t s_cleanupInterval;
-};
-
 struct DoHClientCollection::DoHWorkerThread
 {
   DoHWorkerThread()
@@ -778,9 +720,94 @@ bool DoHClientCollection::passCrossProtocolQueryToThread(std::unique_ptr<CrossPr
   return true;
 }
 
+thread_local map<boost::uuids::uuid, std::deque<std::shared_ptr<DoHConnectionToBackend>>> DownstreamDoHConnectionsManager::t_downstreamConnections;
+size_t DownstreamDoHConnectionsManager::s_maxCachedConnectionsPerDownstream{10};
+time_t DownstreamDoHConnectionsManager::s_nextCleanup{0};
+uint16_t DownstreamDoHConnectionsManager::s_cleanupInterval{60};
+
+/* Walks the thread-local connection cache, evicts every connection whose
+   socket is reported as no longer usable, then forgets the backends whose
+   list of connections has become empty. */
+void DownstreamDoHConnectionsManager::cleanupClosedConnections(struct timeval now)
+{
+  /* connections that received data during the last second are not probed */
+  struct timeval recentThreshold = now;
+  recentThreshold.tv_sec -= 1;
+
+  auto backendIt = t_downstreamConnections.begin();
+  while (backendIt != t_downstreamConnections.end()) {
+    auto& connections = backendIt->second;
+
+    auto entryIt = connections.begin();
+    while (entryIt != connections.end()) {
+      const auto& connection = *entryIt;
+      /* empty slots and recently active connections are kept as they are,
+         only the remaining ones have their socket checked */
+      const bool keep = !connection || recentThreshold < connection->getLastDataReceivedTime() || isTCPSocketUsable(connection->getHandle());
+      if (keep) {
+        ++entryIt;
+      }
+      else {
+        entryIt = connections.erase(entryIt);
+      }
+    }
+
+    if (connections.empty()) {
+      backendIt = t_downstreamConnections.erase(backendIt);
+    }
+    else {
+      ++backendIt;
+    }
+  }
+}
+
 std::shared_ptr<DoHConnectionToBackend> DownstreamDoHConnectionsManager::getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
 {
-  return std::make_shared<DoHConnectionToBackend>(ds, mplexer, now);
+  std::shared_ptr<DoHConnectionToBackend> result;
+  struct timeval freshCutOff = now;
+  freshCutOff.tv_sec -= 1;
+
+  auto backendId = ds->getID();
+
+  if (s_cleanupInterval > 0 && (s_nextCleanup == 0 || s_nextCleanup <= now.tv_sec)) {
+    s_nextCleanup = now.tv_sec + s_cleanupInterval;
+    //cerr<<"cleaning up"<<endl;
+    cleanupClosedConnections(now);
+  }
+
+  {
+    //cerr<<"looking for existing connection"<<endl;
+    const auto& it = t_downstreamConnections.find(backendId);
+    if (it != t_downstreamConnections.end()) {
+      auto& list = it->second;
+      for (auto listIt = list.begin(); listIt != list.end(); ) {
+        auto& entry = *listIt;
+        if (!entry->canBeReused()) {
+          listIt = list.erase(listIt);
+          continue;
+        }
+        entry->setReused();
+        /* for connections that have not been used very recently,
+           check whether they have been closed in the meantime */
+        if (freshCutOff < entry->getLastDataReceivedTime()) {
+          /* used recently enough, skip the check */
+          ++ds->tcpReusedConnections;
+          return entry;
+        }
+
+        if (isTCPSocketUsable(entry->getHandle())) {
+          ++ds->tcpReusedConnections;
+          return entry;
+        }
+
+        /* otherwise let's try the next one, if any */
+        ++listIt;
+      }
+    }
+
+    auto newConnection = std::make_shared<DoHConnectionToBackend>(ds, mplexer, now);
+    t_downstreamConnections[backendId].push_back(newConnection);
+    return newConnection;
+  }
 }
 
 static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param)
@@ -812,15 +839,19 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par
     delete tmp;
     tmp = nullptr;
 
-    auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now);
-    
+    try {
+      auto downstream = DownstreamDoHConnectionsManager::getConnectionToDownstream(threadData->mplexer, downstreamServer, now);
+
 #warning what about the proxy protocol payload, here, do we need to remove it? we likely need to handle forward-for headers?
-    downstream->queueQuery(tqs, std::move(query));
+      downstream->queueQuery(tqs, std::move(query));
+    }
+    catch (...) {
+      tqs->notifyIOError(std::move(query.d_idstate), now);
+    }
   }
   catch (...) {
     delete tmp;
     tmp = nullptr;
-    throw;
   }
 }
 
@@ -962,7 +993,27 @@ void DoHClientCollection::addThread()
 bool initDoHWorkers()
 {
 #warning FIXME: number of DoH threads
-  g_dohClientThreads = std::make_unique<DoHClientCollection>(1);
+  g_dohClientThreads = std::make_unique<DoHClientCollection>(4);
   g_dohClientThreads->addThread();
   return true;
 }
+
+/* Next-protocol selection hook: delegates the choice to nghttp2 and reports
+   success only when it returns a strictly positive value. Any other result
+   is logged as the backend not advertising HTTP/2. */
+static bool select_next_proto_callback(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen) {
+  if (nghttp2_select_next_protocol(out, outlen, in, inlen) > 0) {
+    return true;
+  }
+
+  vinfolog("The remote DoH backend did not advertise " NGHTTP2_PROTO_VERSION_ID);
+  return false;
+}
+
+/* Sets the ALPN protocol list of the given TLS context to the single entry
+   "h2" and installs the next-protocol selection callback on it.
+   Returns false, without doing anything, when no context is supplied. */
+bool setupDoHClientProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
+{
+  if (!ctx) {
+    return false;
+  }
+
+  /* we want to set the ALPN to h2, if only to mitigate the ALPACA attack */
+  const std::vector<std::vector<uint8_t>> offeredProtocols = {{'h', '2'}};
+  ctx->setALPNProtos(offeredProtocols);
+  ctx->setNextProtocolSelectCallback(select_next_proto_callback);
+  return true;
+}
index 97137358236bd0b5067ce23dca65480de49dcdd3..0775898e81818fed4dcfd8b9547a0d858047ad47 100644 (file)
@@ -60,4 +60,7 @@ private:
 extern std::unique_ptr<DoHClientCollection> g_dohClientThreads;
 extern std::atomic<uint64_t> g_dohStatesDumpRequested;
 
+class TLSCtx;
+
 bool initDoHWorkers();
+bool setupDoHClientProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);
index 2201f2df4c18a46ce43a64a61cf5420ad6ceaefd..9370a8a0073262a62654af4f1bde73413dc01d1b 100644 (file)
@@ -73,7 +73,7 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptr<TCPConnectionToBackend
   if (conn->d_currentQuery.d_proxyProtocolPayloadAdded) {
     conn->d_proxyProtocolPayloadSent = true;
   }
-  conn->incQueries();
+  ++conn->d_queries;
   conn->d_currentPos = 0;
 
   DEBUGLOG("adding a pending response for ID "<<ntohs(conn->d_currentQuery.d_idstate.origID)<<" and QNAME "<<conn->d_currentQuery.d_idstate.qname);
index 9301ad70d89c2a12632a7b8f0ae526754e1a2115..aafbd3dd90df80d20d4f9bea61d0d2d30b3d6963 100644 (file)
@@ -46,11 +46,6 @@ public:
     return d_fresh;
   }
 
-  void incQueries()
-  {
-    ++d_queries;
-  }
-
   void setReused()
   {
     d_fresh = false;
@@ -86,7 +81,7 @@ public:
   }
 
   /* whether a connection can be reused for a different client */
-  bool canBeReused() const
+  virtual bool canBeReused() const
   {
     if (d_connectionDied) {
       return false;
@@ -126,7 +121,7 @@ public:
   virtual std::string toString() const
   {
     ostringstream o;
-    o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses, linked to "<<(d_sender ? " a client" : "no client");
+    o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses, linked to "<<(d_sender ? " a client" : "no client");
     return o.str();
   }
 
index 9d392006c958bddceb18d88eebb394a01b540214..f698d9d4353538956939b88883f7019764e069af 100644 (file)
@@ -140,7 +140,7 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
   std::string toString() const
   {
     ostringstream o;
-    o << "Incoming TCP connection from "<<d_ci.remote.toStringWithPort()<<" over FD "<<d_handler.getDescriptor()<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? std::to_string((int)d_ioState->getState()) : "empty")<<", queries count is "<<d_queriesCount<<", current queries count is "<<d_currentQueriesCount<<", "<<d_queuedResponses.size()<<" queued responses, "<<d_activeConnectionsToBackend.size()<<" active connections to a backend";
+    o << "Incoming TCP connection from "<<d_ci.remote.toStringWithPort()<<" over FD "<<d_handler.getDescriptor()<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<<d_queriesCount<<", current queries count is "<<d_currentQueriesCount<<", "<<d_queuedResponses.size()<<" queued responses, "<<d_activeConnectionsToBackend.size()<<" active connections to a backend";
     return o.str();
   }
 
index d8f8e6d8c2f5cb7c14ee3438b975f5eedd430029..6662e7d3dedfadb6dd8467be08948938055fa14d 100644 (file)
@@ -627,7 +627,7 @@ static int processDOHQuery(DOHUnit* du)
       du->ids.cs = &cs;
       setIDStateFromDNSQuestion(du->ids, dq, std::move(qname));
 
-      if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
+      if (du->downstream->passCrossProtocolQuery(std::move(cpq))) {
         return 0;
       }
       else {
index c8d98b8b987c09f2bc70305567901686e9d1b449..2fbe4b5147e4c48f32d3142904a2e92ef4d89119 100644 (file)
 class IOStateHandler
 {
 public:
-  IOStateHandler(FDMultiplexer& mplexer, const int fd): d_mplexer(mplexer), d_fd(fd), d_currentState(IOState::Done)
+  IOStateHandler(FDMultiplexer& mplexer, const int fd): d_mplexer(mplexer), d_fd(fd)
   {
   }
 
-  IOStateHandler(FDMultiplexer& mplexer): d_mplexer(mplexer), d_fd(-1), d_currentState(IOState::Done)
+  IOStateHandler(FDMultiplexer& mplexer): d_mplexer(mplexer), d_fd(-1)
   {
   }
 
@@ -36,9 +36,14 @@ public:
     }
   }
 
-  IOState getState() const
+  bool isWaitingForRead() const
   {
-    return d_currentState;
+    return d_isWaitingForRead;
+  }
+
+  bool isWaitingForWrite() const
+  {
+    return d_isWaitingForWrite;
   }
 
   void setSocket(int fd)
@@ -54,22 +59,66 @@ public:
     update(IOState::Done);
   }
 
+  /* Two-character summary of what the descriptor is currently waiting for:
+     the first character is 'R' when waiting for read, the second is 'W' when
+     waiting for write, and '-' stands for "not waiting" in either position. */
+  std::string getState() const
+  {
+    std::string flags;
+    flags.reserve(2);
+    flags.push_back(isWaitingForRead() ? 'R' : '-');
+    flags.push_back(isWaitingForWrite() ? 'W' : '-');
+    return flags;
+  }
+
+  /* Registers the descriptor with the multiplexer for the requested direction
+     (NeedRead or NeedWrite), leaving the flag of the other direction untouched.
+     When the descriptor is already registered for that direction, only its TTD
+     is refreshed, and only if one was supplied. Any other state is ignored. */
+  void add(IOState iostate, FDMultiplexer::callbackfunc_t callback, FDMultiplexer::funcparam_t callbackData, boost::optional<struct timeval> ttd)
+  {
+    DEBUGLOG("in "<<__PRETTY_FUNCTION__<<" for fd "<<d_fd<<", last state was "<<getState()<<", adding "<<(int)iostate);
+
+    switch (iostate) {
+    case IOState::NeedRead:
+      if (isWaitingForRead()) {
+        if (ttd) {
+          /* already registered: refresh the TTD, passing 0 because we already have one */
+          d_mplexer.setReadTTD(d_fd, *ttd, 0);
+        }
+        break;
+      }
+
+      d_mplexer.addReadFD(d_fd, callback, callbackData, ttd ? &*ttd : nullptr);
+      DEBUGLOG(__PRETTY_FUNCTION__<<": add read FD "<<d_fd);
+      d_isWaitingForRead = true;
+      break;
+
+    case IOState::NeedWrite:
+      if (isWaitingForWrite()) {
+        if (ttd) {
+          /* already registered: refresh the TTD, passing 0 because we already have one */
+          d_mplexer.setWriteTTD(d_fd, *ttd, 0);
+        }
+        break;
+      }
+
+      d_mplexer.addWriteFD(d_fd, callback, callbackData, ttd ? &*ttd : nullptr);
+      DEBUGLOG(__PRETTY_FUNCTION__<<": add write FD "<<d_fd);
+      d_isWaitingForWrite = true;
+      break;
+
+    default:
+      /* nothing to register for the remaining states */
+      break;
+    }
+  }
+
   void update(IOState iostate, FDMultiplexer::callbackfunc_t callback = FDMultiplexer::callbackfunc_t(), FDMultiplexer::funcparam_t callbackData = boost::any(), boost::optional<struct timeval> ttd = boost::none)
   {
-    DEBUGLOG("in "<<__PRETTY_FUNCTION__<<" for fd "<<d_fd<<", last state was "<<(int)d_currentState<<", new state is "<<(int)iostate);
-    if (d_currentState == IOState::NeedRead && iostate == IOState::Done) {
+    DEBUGLOG("in "<<__PRETTY_FUNCTION__<<" for fd "<<d_fd<<", last state was "<<getState()<<" , new state is "<<(int)iostate);
+    if (isWaitingForRead() && iostate == IOState::Done) {
       DEBUGLOG(__PRETTY_FUNCTION__<<": remove read FD "<<d_fd);
       d_mplexer.removeReadFD(d_fd);
-      d_currentState = IOState::Done;
+      d_isWaitingForRead = false;
     }
-    else if (d_currentState == IOState::NeedWrite && iostate == IOState::Done) {
+    if (isWaitingForWrite() && iostate == IOState::Done) {
       DEBUGLOG(__PRETTY_FUNCTION__<<": remove write FD "<<d_fd);
       d_mplexer.removeWriteFD(d_fd);
-      d_currentState = IOState::Done;
+      d_isWaitingForWrite = false;
     }
 
     if (iostate == IOState::NeedRead) {
-      if (d_currentState == IOState::NeedRead) {
+      if (isWaitingForRead()) {
         if (ttd) {
           /* let's update the TTD ! */
           d_mplexer.setReadTTD(d_fd, *ttd, /* we pass 0 here because we already have a TTD */0);
@@ -77,7 +126,8 @@ public:
         return;
       }
 
-      if (d_currentState == IOState::NeedWrite) {
+      if (isWaitingForWrite()) {
+        d_isWaitingForWrite = false;
         d_mplexer.alterFDToRead(d_fd, callback, callbackData, ttd ? &*ttd : nullptr);
         DEBUGLOG(__PRETTY_FUNCTION__<<": alter from write to read FD "<<d_fd);
       }
@@ -86,11 +136,10 @@ public:
         DEBUGLOG(__PRETTY_FUNCTION__<<": add read FD "<<d_fd);
       }
 
-      d_currentState = IOState::NeedRead;
-
+      d_isWaitingForRead = true;
     }
     else if (iostate == IOState::NeedWrite) {
-      if (d_currentState == IOState::NeedWrite) {
+      if (isWaitingForWrite()) {
         if (ttd) {
           /* let's update the TTD ! */
           d_mplexer.setWriteTTD(d_fd, *ttd, /* we pass 0 here because we already have a TTD */0);
@@ -98,7 +147,8 @@ public:
         return;
       }
 
-      if (d_currentState == IOState::NeedRead) {
+      if (isWaitingForRead()) {
+        d_isWaitingForRead = false;
         d_mplexer.alterFDToWrite(d_fd, callback, callbackData, ttd ? &*ttd : nullptr);
         DEBUGLOG(__PRETTY_FUNCTION__<<": alter from read to write FD "<<d_fd);
       }
@@ -107,10 +157,9 @@ public:
         DEBUGLOG(__PRETTY_FUNCTION__<<": add write FD "<<d_fd);
       }
 
-      d_currentState = IOState::NeedWrite;
+      d_isWaitingForWrite = true;
     }
     else if (iostate == IOState::Done) {
-      d_currentState = IOState::Done;
       DEBUGLOG(__PRETTY_FUNCTION__<<": done");
     }
   }
@@ -118,7 +167,8 @@ public:
 private:
   FDMultiplexer& d_mplexer;
   int d_fd;
-  IOState d_currentState;
+  bool d_isWaitingForRead{false};
+  bool d_isWaitingForWrite{false};
 };
 
 class IOStateGuard
index c3411a55b51c9f84af2d8e0aa2a436705fbb0839..7ea89e2b87ff5768997b9cc0e3ae3bbcfec35e0a 100644 (file)
@@ -227,6 +227,11 @@ public:
     return "";
   }
 
+  std::vector<uint8_t> getNextProtocol() const override
+  {
+    return {}; /* this connection never reports a negotiated protocol */
+  }
+
   LibsslTLSVersion getTLSVersion() const override
   {
     return LibsslTLSVersion::TLS13;
index b667d27d01cf69eb82eb15727e39046b555b9cca..ccc287ed6dceb68d6891a4184a790dd10b8f488b 100644 (file)
@@ -793,6 +793,40 @@ std::unique_ptr<FILE, int(*)(FILE*)> libssl_set_key_log_file(std::unique_ptr<SSL
 #endif /* HAVE_SSL_CTX_SET_KEYLOG_CALLBACK */
 }
 
+/* Registers 'cb', along with its 'arg', on 'ctx'. Called in a client context, if the client advertised more than one ALPN value and the server returned more than one as well, to select the one to use. */
+void libssl_set_npn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
+{
+#ifdef HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB
+  SSL_CTX_set_next_proto_select_cb(ctx.get(), cb, arg);
+#endif /* HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB, otherwise nothing is registered */
+}
+
+void libssl_set_alpn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg)
+{
+#ifdef HAVE_SSL_CTX_SET_ALPN_SELECT_CB
+  SSL_CTX_set_alpn_select_cb(ctx.get(), cb, arg); /* registers 'cb', along with its 'arg', on the context */
+#endif /* HAVE_SSL_CTX_SET_ALPN_SELECT_CB, otherwise nothing is registered */
+}
+
+bool libssl_set_alpn_protos(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::vector<std::vector<uint8_t>>& protos)
+{
+#ifdef HAVE_SSL_CTX_SET_ALPN_PROTOS
+  /* build the list handed to OpenSSL: every name is prefixed with its length, stored in a single byte */
+  std::vector<uint8_t> encoded;
+  for (const auto& name : protos) {
+    if (name.size() > std::numeric_limits<uint8_t>::max()) {
+      throw std::runtime_error("Invalid ALPN value");
+    }
+    encoded.push_back(static_cast<uint8_t>(name.size()));
+    encoded.insert(encoded.end(), name.cbegin(), name.cend());
+  }
+  return SSL_CTX_set_alpn_protos(ctx.get(), encoded.data(), encoded.size()) == 0; /* 0 is the success value */
+#else
+  return false;
+#endif
+}
+
+
 std::string libssl_get_error_string()
 {
   BIO *mem = BIO_new(BIO_s_mem());
index b090afa7b146d59200d1a79d1ffe881978de7557..2af0f4ef8947ec96523b3818e540a35876ee32db 100644 (file)
@@ -126,6 +126,13 @@ std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> libssl_init_server_context(const TLS
 
 std::unique_ptr<FILE, int(*)(FILE*)> libssl_set_key_log_file(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::string& logFile);
 
+/* called in a client context, if the client advertised more than one ALPN value and the server returned more than one as well, to select the one to use. */
+void libssl_set_npn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg);
+/* called in a server context, to select an ALPN value advertised by the client if any */
+void libssl_set_alpn_select_callback(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, int (*cb)(SSL* s, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg), void* arg);
+/* set the supported ALPN protos in client context */
+bool libssl_set_alpn_protos(std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)>& ctx, const std::vector<std::vector<uint8_t>>& protos);
+
 std::string libssl_get_error_string();
 
 #endif /* HAVE_LIBSSL */
index 4eb80627ab54fa3ea80f1790d6d97bf3549ec484..e957d66bbe888338844717ffcf4e0095767eb2e5 100644 (file)
@@ -151,7 +151,12 @@ public:
       return IOState::NeedWrite;
     }
     else if (error == SSL_ERROR_SYSCALL) {
-      throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
+      if (errno == 0) {
+        throw std::runtime_error("TLS connection closed by remote end");
+      }
+      else {
+        throw std::runtime_error("Syscall error while processing TLS connection: " + std::string(strerror(errno)));
+      }
     }
     else if (error == SSL_ERROR_ZERO_RETURN) {
       throw std::runtime_error("TLS connection closed by remote end");
@@ -401,6 +406,29 @@ public:
     return std::string();
   }
 
+  std::vector<uint8_t> getNextProtocol() const override
+  {
+    /* stays empty when there is no connection or when nothing has been negotiated */
+    std::vector<uint8_t> negotiated;
+    if (d_conn) {
+      const unsigned char* proto = nullptr;
+      unsigned int protoLen = 0;
+#ifdef HAVE_SSL_GET0_NEXT_PROTO_NEGOTIATED
+      SSL_get0_next_proto_negotiated(d_conn.get(), &proto, &protoLen);
+#endif
+#ifdef HAVE_SSL_GET0_ALPN_SELECTED
+      if (proto == nullptr) {
+        /* the next-protocol lookup yielded nothing (or is not available), look at ALPN */
+        SSL_get0_alpn_selected(d_conn.get(), &proto, &protoLen);
+      }
+#endif
+      if (proto != nullptr && protoLen > 0) {
+        negotiated.assign(proto, proto + protoLen);
+      }
+    }
+    return negotiated;
+  }
+
   LibsslTLSVersion getTLSVersion() const override
   {
     auto proto = SSL_version(d_conn.get());
@@ -668,9 +696,74 @@ public:
     return "openssl";
   }
 
+  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
+  {
+    /* server side: keep a copy of the list, alpnServerSelectCallback() picks from it */
+    const bool isServer = d_feContext && d_feContext->d_tlsCtx;
+    if (isServer) {
+      d_alpnProtos = protos;
+      libssl_set_alpn_select_callback(d_feContext->d_tlsCtx, alpnServerSelectCallback, this);
+      return true;
+    }
+    /* client side: hand the list to the client context, if there is one */
+    return d_tlsCtx ? libssl_set_alpn_protos(d_tlsCtx, protos) : false;
+  }
+
+  bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
+  {
+    d_nextProtocolSelectCallback = cb; /* read back by npnSelectCallback(), which gets 'this' as its 'arg' */
+    libssl_set_npn_select_callback(d_tlsCtx, npnSelectCallback, this); /* d_tlsCtx is the client context */
+    return true;
+  }
+
 private:
+  /* called in a client context, if the client advertised more than one ALPN value and the server returned more than one as well, to select the one to use: 'arg' is the OpenSSLTLSIOCtx whose registered callback makes the choice */
+  static int npnSelectCallback(SSL* s, unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
+  {
+    if (arg == nullptr) {
+      return SSL_TLSEXT_ERR_ALERT_WARNING;
+    }
+    const auto* self = static_cast<const OpenSSLTLSIOCtx*>(arg);
+    auto selector = self->d_nextProtocolSelectCallback;
+    if (selector == nullptr) {
+      return SSL_TLSEXT_ERR_OK; /* no callback registered: report success without selecting anything */
+    }
+    return selector(out, outlen, in, inlen) ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_ALERT_WARNING;
+  }
+
+  /* walks the list received in 'in' and selects the first entry that is also present in d_alpnProtos of the OpenSSLTLSIOCtx passed as 'arg' */
+  static int alpnServerSelectCallback(SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen, void* arg)
+  {
+    if (arg == nullptr) {
+      return SSL_TLSEXT_ERR_ALERT_WARNING;
+    }
+    const auto* self = static_cast<const OpenSSLTLSIOCtx*>(arg);
+
+    /* 'in' is parsed as a sequence of (one length byte, name) entries */
+    for (size_t offset = 0; offset < inlen;) {
+      const size_t nameLen = in[offset++];
+      if (nameLen > (inlen - offset)) {
+        /* the announced length runs past the end of the buffer */
+        return SSL_TLSEXT_ERR_ALERT_WARNING;
+      }
+      const unsigned char* name = in + offset;
+      for (const auto& supported : self->d_alpnProtos) {
+        if (supported.size() == nameLen && memcmp(name, supported.data(), nameLen) == 0) {
+          *out = name;
+          *outlen = nameLen;
+          return SSL_TLSEXT_ERR_OK;
+        }
+      }
+      offset += nameLen;
+    }
+
+    return SSL_TLSEXT_ERR_NOACK;
+  }
+
+  std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
   std::shared_ptr<OpenSSLFrontendContext> d_feContext;
-  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx; // client context
+  std::unique_ptr<SSL_CTX, void(*)(SSL_CTX*)> d_tlsCtx; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
+  bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
 };
 
 #endif /* HAVE_LIBSSL */
@@ -1226,6 +1319,20 @@ public:
     return std::string();
   }
 
+  std::vector<uint8_t> getNextProtocol() const override
+  {
+    /* an empty result means: no session, or no selected ALPN protocol could be retrieved from it */
+    if (!d_conn) {
+      return {};
+    }
+    gnutls_datum_t selected;
+    const auto rc = gnutls_alpn_get_selected_protocol(d_conn.get(), &selected);
+    if (rc != GNUTLS_E_SUCCESS) {
+      return {};
+    }
+    return std::vector<uint8_t>(selected.data, selected.data + selected.size);
+  }
+
   LibsslTLSVersion getTLSVersion() const override
   {
     auto proto = gnutls_protocol_get_version(d_conn.get());
@@ -1285,6 +1392,19 @@ public:
     }
   }
 
+  /* Advertises 'protos' as the ALPN protocols of this session, returns true if GnuTLS accepted them. */
+  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos)
+  {
+    std::vector<gnutls_datum_t> values;
+    values.reserve(protos.size());
+    for (const auto& proto : protos) {
+      gnutls_datum_t value{const_cast<uint8_t*>(proto.data()), static_cast<unsigned int>(proto.size())}; /* points into 'protos', nothing is copied here */
+      values.push_back(value);
+    }
+    /* gnutls_alpn_set_protocols() returns GNUTLS_E_SUCCESS (0) on success and a negative error code otherwise: converting that value directly to bool would report a success as false and a failure as true */
+    return gnutls_alpn_set_protocols(d_conn.get(), values.data(), values.size(), GNUTLS_ALPN_MANDATORY) == GNUTLS_E_SUCCESS;
+  }
+
 private:
   std::shared_ptr<GnuTLSTicketsKey> d_ticketsKey;
   std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)> d_conn;
@@ -1406,12 +1526,20 @@ public:
       ticketsKey = *(d_ticketsKey.read_lock());
     }
 
-    return std::make_unique<GnuTLSConnection>(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets);
+    auto connection = std::make_unique<GnuTLSConnection>(socket, timeout, d_creds.get(), d_priorityCache, ticketsKey, d_enableTickets);
+    if (!d_protos.empty()) {
+      connection->setALPNProtos(d_protos);
+    }
+    return connection;
   }
 
   std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, int socket, const struct timeval& timeout) override
   {
-    return std::make_unique<GnuTLSConnection>(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts);
+    auto connection = std::make_unique<GnuTLSConnection>(host, socket, timeout, d_creds.get(), d_priorityCache, d_validateCerts);
+    if (!d_protos.empty()) {
+      connection->setALPNProtos(d_protos);
+    }
+    return connection;
   }
 
   void rotateTicketsKey(time_t now) override
@@ -1457,8 +1585,19 @@ public:
     return "gnutls";
   }
 
+  bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
+  {
+#ifdef HAVE_GNUTLS_ALPN_SET_PROTOCOLS
+    d_protos = protos; /* only stored here, then applied to the connections created by this context */
+    return true;
+#else
+    return false;
+#endif /* HAVE_GNUTLS_ALPN_SET_PROTOCOLS */
+  }
+
 private:
   std::unique_ptr<gnutls_certificate_credentials_st, void(*)(gnutls_certificate_credentials_t)> d_creds;
+  std::vector<std::vector<uint8_t>> d_protos;
   gnutls_priority_t d_priorityCache{nullptr};
   SharedLockGuarded<std::shared_ptr<GnuTLSTicketsKey>> d_ticketsKey{nullptr};
   bool d_enableTickets{true};
@@ -1469,6 +1608,17 @@ private:
 
 #endif /* HAVE_DNS_OVER_TLS */
 
+bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx)
+{
+  if (!ctx) {
+    return false;
+  }
+  /* we want to set the ALPN to dot (RFC7858), if only to mitigate the ALPACA attack */
+  const std::vector<std::vector<uint8_t>> protocols = {{'d', 'o', 't'}};
+  ctx->setALPNProtos(protocols); /* the outcome is deliberately not checked */
+  return true;
+}
+
 bool TLSFrontend::setupTLS()
 {
 #ifdef HAVE_DNS_OVER_TLS
@@ -1478,6 +1628,7 @@ bool TLSFrontend::setupTLS()
 #ifdef HAVE_GNUTLS
     if (d_provider == "gnutls") {
       newCtx = std::make_shared<GnuTLSIOCtx>(*this);
+      setupDoTProtocolNegotiation(newCtx);
       std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
       return true;
     }
@@ -1485,6 +1636,7 @@ bool TLSFrontend::setupTLS()
 #ifdef HAVE_LIBSSL
     if (d_provider == "openssl") {
       newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
+      setupDoTProtocolNegotiation(newCtx);
       std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
       return true;
     }
@@ -1498,6 +1650,7 @@ bool TLSFrontend::setupTLS()
 #endif /* HAVE_GNUTLS */
 #endif /* HAVE_LIBSSL */
 
+  setupDoTProtocolNegotiation(newCtx);
   std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
 #endif /* HAVE_DNS_OVER_TLS */
   return true;
index e948b130aea5f2da27caea7c5ca5dd0c9039043b..fcde2c62ae13ec1692f4d4cc90c353d6df5448a4 100644 (file)
@@ -33,6 +33,7 @@ public:
   virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0;
   virtual bool hasBufferedData() const = 0;
   virtual std::string getServerNameIndication() const = 0;
+  virtual std::vector<uint8_t> getNextProtocol() const = 0;
   virtual LibsslTLSVersion getTLSVersion() const = 0;
   virtual bool hasSessionBeenResumed() const = 0;
   virtual std::unique_ptr<TLSSession> getSession() = 0;
@@ -111,6 +112,18 @@ public:
   virtual size_t getTicketsKeysCount() = 0;
   virtual std::string getName() const = 0;
 
+  /* set the advertised ALPN protocols, in client or server context. This default implementation does nothing and returns false. */
+  virtual bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos)
+  {
+    return false;
+  }
+
+  /* called in a client context, if the client advertised more than one ALPN value and the server returned more than one as well, to select the one to use. This default implementation does nothing and returns false. */
+  virtual bool setNextProtocolSelectCallback(bool(*)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen))
+  {
+    return false;
+  }
+
 protected:
   std::atomic_flag d_rotatingTicketsKey;
   std::atomic<time_t> d_ticketsKeyNextRotation{0};
@@ -465,6 +478,14 @@ public:
     return std::string();
   }
 
+  std::vector<uint8_t> getNextProtocol() const
+  {
+    if (!d_conn) {
+      return {}; /* no underlying connection object, hence nothing to report */
+    }
+    return d_conn->getNextProtocol();
+  }
+
   LibsslTLSVersion getTLSVersion() const
   {
     if (d_conn) {
@@ -528,3 +549,4 @@ struct TLSContextParameters
 };
 
 std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params);
+bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);