git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Handle congested DoQ streams 13638/head
author Remi Gacogne <remi.gacogne@powerdns.com>
Fri, 15 Dec 2023 15:56:23 +0000 (16:56 +0100)
committer Remi Gacogne <remi.gacogne@powerdns.com>
Fri, 15 Dec 2023 16:03:55 +0000 (17:03 +0100)
If the stream has no capacity left, Quiche will refuse to queue
more data and return `QUICHE_ERR_DONE`. We then have to wait until
the stream becomes writable again to retry sending our response.

pdns/dnsdistdist/doh3.cc
pdns/dnsdistdist/doq.cc

index a1e63947905bd7f40515f3f14d5f4fe7176b1b58..e3c14ccee8cb74b34fa675d8e655736b245ffab7 100644 (file)
@@ -67,6 +67,7 @@ public:
   QuicheConnection d_conn;
   QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free};
   std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
+  std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
 };
 
 static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
@@ -263,7 +264,29 @@ private:
 
 std::shared_ptr<DOH3TCPCrossQuerySender> DOH3CrossProtocolQuery::s_sender = std::make_shared<DOH3TCPCrossQuerySender>();
 
-static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
+static bool tryWriteResponse(H3Connection& conn, const uint64_t streamID, PacketBuffer& response) // Tries to send the bytes of 'response' as the HTTP/3 body of stream 'streamID'.
+{ // Returns false if the stream could not take everything yet (retry later with the trimmed 'response'), true if there is nothing left to retry (all sent, or stream shut down on error).
+  size_t pos = 0; // number of bytes of 'response' accepted by quiche so far
+  while (pos < response.size()) {
+    // send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
+    auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
+                                   streamID, &response.at(pos), response.size() - pos, true);
+    if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) { // quiche refused to queue more data for now
+      response.erase(response.begin(), response.begin() + static_cast<ssize_t>(pos)); // drop what was already sent so only the unsent tail remains
+      return false; // caller is expected to keep 'response' and try again later
+    }
+    if (res < 0) { // any other negative value is treated as a fatal error for this stream
+      // Shutdown with internal error code
+      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+      return true; // nothing more to send on this stream, so no retry is needed
+    }
+    pos += res; // partial write: continue with the remaining bytes
+  }
+
+  return true; // the whole response has been handed over to quiche
+}
+
+static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
 {
   std::string status = std::to_string(statusCode);
   std::string lenStr = std::to_string(len);
@@ -285,8 +308,13 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const
       .value_len = lenStr.size(),
     },
   };
-  quiche_h3_send_response(conn, quic_conn,
-                          streamID, headers.data(), headers.size(), len == 0);
+  auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(),
+                                             streamID, headers.data(), headers.size(), len == 0);
+  if (returnValue != 0) {
+    /* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */
+    quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+    return;
+  }
 
   if (len == 0) {
     return;
@@ -295,28 +323,27 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const
   size_t pos = 0;
   while (pos < len) {
     // send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
-    auto res = quiche_h3_send_body(conn, quic_conn,
+    auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
                                    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
                                    streamID, const_cast<uint8_t*>(body) + pos, len - pos, true);
+    if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) {
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
+      conn.d_streamOutBuffers[streamID] = PacketBuffer(body + pos, body + len);
+      return;
+    }
     if (res < 0) {
       // Shutdown with internal error code
-      quiche_conn_stream_shutdown(quic_conn, streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(1));
+      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(1));
       return;
     }
     pos += res;
   }
 }
 
-static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
+static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
 {
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
-  h3_send_response(quic_conn, conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
-}
-
-static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
-{
-  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
-  h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, statusCode, body, len);
+  h3_send_response(conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
 }
 
 static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response)
@@ -616,6 +643,21 @@ static void flushResponses(pdns::channel::Receiver<DOH3Unit>& receiver)
   }
 }
 
+static void flushStalledResponses(H3Connection& conn) // Retries sending every response still pending in conn.d_streamOutBuffers.
+{
+  for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) { // no increment here: the iterator advances via erase() or the ++ at the bottom
+    const auto streamID = streamIt->first;
+    auto& response = streamIt->second; // remaining unsent bytes for this stream
+    if (quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size()) == 1) { // only retry when quiche reports the stream as writable for the full remaining size; any other value leaves the entry queued
+      if (tryWriteResponse(conn, streamID, response)) { // true means nothing is left to retry for this stream
+        streamIt = conn.d_streamOutBuffers.erase(streamIt); // erase() returns the next valid iterator
+        continue;
+      }
+    }
+    ++streamIt; // still pending: keep the buffer for a later attempt
+  }
+}
+
 static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event)
 {
   auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
@@ -623,7 +665,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten
     ++dnsdist::metrics::g_stats.nonCompliantQueries;
     ++clientState.nonCompliantQueries;
     ++frontend.d_errorResponses;
-    h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg);
+    h3_send_response(conn, streamID, 400, msg);
   };
 
   // Callback result. Any value other than 0 will interrupt further header processing.
@@ -684,37 +726,49 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
     ++dnsdist::metrics::g_stats.nonCompliantQueries;
     ++clientState.nonCompliantQueries;
     ++frontend.d_errorResponses;
-    h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg);
+    h3_send_response(conn, streamID, 400, msg);
   };
 
-  if (headers.at(":method") == "POST") {
-    if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
-      handleImmediateError("Unsupported content-type");
-      return;
-    }
-    PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
-    PacketBuffer decoded;
+  if (headers.at(":method") != "POST") {
+    handleImmediateError("DATA frame for non-POST method");
+    return;
+  }
 
-    while (true) {
-      ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
-                                        conn.d_conn.get(), streamID,
-                                        buffer.data(), buffer.capacity());
+  if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
+    handleImmediateError("Unsupported content-type");
+    return;
+  }
 
-      if (len <= 0) {
-        break;
-      }
-      decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len);
-    }
+  PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
+  auto& streamBuffer = conn.d_streamBuffers[streamID];
 
-    if (decoded.size() < sizeof(dnsheader)) {
-      handleImmediateError("DoH3 non-compliant query");
-      return;
+  while (true) {
+    buffer.resize(std::numeric_limits<uint16_t>::max());
+    ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
+                                      conn.d_conn.get(), streamID,
+                                      buffer.data(), buffer.capacity());
+
+    if (len <= 0) {
+      break;
     }
 
-    DEBUGLOG("Dispatching POST query");
-    doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID);
+    buffer.resize(static_cast<size_t>(len));
+    streamBuffer.insert(streamBuffer.end(), buffer.begin(), buffer.end());
+  }
+
+  if (!quiche_conn_stream_finished(conn.d_conn.get(), streamID)) {
+    return;
+  }
+
+  if (streamBuffer.size() < sizeof(dnsheader)) {
     conn.d_streamBuffers.erase(streamID);
+    handleImmediateError("DoH3 non-compliant query");
+    return;
   }
+
+  DEBUGLOG("Dispatching POST query");
+  doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
+  conn.d_streamBuffers.erase(streamID);
 }
 
 static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID)
@@ -892,6 +946,7 @@ void doh3Thread(ClientState* clientState)
           conn = frontend->d_server_config->d_connections.erase(conn);
         }
         else {
+          flushStalledResponses(conn->second);
           ++conn;
         }
       }
index 03ba0f25715845cd55565aa9cf3a930b1feb2087..9b626aaf8822ec164a23884b53deeb5c27ee5f60 100644 (file)
@@ -65,6 +65,7 @@ public:
   ComboAddress d_peer;
   QuicheConnection d_conn;
   std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
+  std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
 };
 
 static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description);
@@ -260,7 +261,26 @@ private:
 
 std::shared_ptr<DOQTCPCrossQuerySender> DOQCrossProtocolQuery::s_sender = std::make_shared<DOQTCPCrossQuerySender>();
 
-static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, const PacketBuffer& response)
+static bool tryWriteResponse(Connection& conn, const uint64_t streamID, PacketBuffer& response) // Tries to send the bytes of 'response' on QUIC stream 'streamID'.
+{ // Returns false if the stream could not take everything yet (retry later with the trimmed 'response'), true if there is nothing left to retry (all sent, or stream shut down on error).
+  size_t pos = 0; // number of bytes of 'response' accepted by quiche so far
+  while (pos < response.size()) {
+    auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(pos), response.size() - pos, true); // fin is always requested; presumably quiche only applies it once the whole buffer fits - confirm against the quiche API
+    if (res == QUICHE_ERR_DONE) { // quiche refused to queue more data for now
+      response.erase(response.begin(), response.begin() + static_cast<ssize_t>(pos)); // drop what was already sent so only the unsent tail remains
+      return false; // caller is expected to keep 'response' and try again later
+    }
+    if (res < 0) { // any other negative value is treated as a fatal error for this stream
+      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+      return true; // nothing more to send on this stream, so no retry is needed
+    }
+    pos += res; // partial write: continue with the remaining bytes
+  }
+
+  return true; // the whole response has been handed over to quiche
+}
+
+static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, PacketBuffer& response)
 {
   if (response.empty()) {
     ++frontend.d_errorResponses;
@@ -270,25 +290,9 @@ static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64
   ++frontend.d_validResponses;
   auto responseSize = static_cast<uint16_t>(response.size());
   const std::array<uint8_t, 2> sizeBytes = {static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
-  size_t pos = 0;
-  while (pos < sizeBytes.size()) {
-    auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &sizeBytes.at(pos), sizeBytes.size() - pos, false);
-    if (res < 0) {
-      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
-      return;
-    }
-    pos += res;
-  }
-
-  pos = 0;
-  while (pos < response.size()) {
-    // stream_send sets fin to false itself when the capacity of the stream is less than the desired writing length
-    auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(pos), response.size() - pos, true);
-    if (res < 0) {
-      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
-      return;
-    }
-    pos += res;
+  response.insert(response.begin(), sizeBytes.begin(), sizeBytes.end());
+  if (!tryWriteResponse(conn, streamID, response)) {
+    conn.d_streamOutBuffers[streamID] = std::move(response);
   }
 }
 
@@ -560,6 +564,21 @@ static void flushResponses(pdns::channel::Receiver<DOQUnit>& receiver)
   }
 }
 
+static void flushStalledResponses(Connection& conn) // Retries sending every response still pending in conn.d_streamOutBuffers.
+{
+  for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) { // no increment here: the iterator advances via erase() or the ++ at the bottom
+    const auto& streamID = streamIt->first;
+    auto& response = streamIt->second; // remaining unsent bytes for this stream
+    if (quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size()) == 1) { // only retry when quiche reports the stream as writable for the full remaining size; any other value leaves the entry queued
+      if (tryWriteResponse(conn, streamID, response)) { // true means nothing is left to retry for this stream
+        streamIt = conn.d_streamOutBuffers.erase(streamIt); // erase() returns the next valid iterator
+        continue;
+      }
+    }
+    ++streamIt; // still pending: keep the buffer for a later attempt
+  }
+}
+
 // this is the entrypoint from dnsdist.cc
 void doqThread(ClientState* clientState)
 {
@@ -721,6 +740,7 @@ void doqThread(ClientState* clientState)
           conn = frontend->d_server_config->d_connections.erase(conn);
         }
         else {
+          flushStalledResponses(conn->second);
           ++conn;
         }
       }