]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Loop on `quiche_conn_stream_recv()` until done
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Dec 2023 15:47:24 +0000 (16:47 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Dec 2023 15:47:24 +0000 (16:47 +0100)
We might get more than one stream event in a single packet.

pdns/dnsdistdist/doq.cc

index 29d0dc696291b21c21a004bc99373fe625d384e0..4a43f6d108f43362f21a8419dbc210beb8bf3458 100644 (file)
@@ -582,32 +582,48 @@ static void flushStalledResponses(Connection& conn)
 static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState, Connection& conn, uint64_t streamID, const ComboAddress& client, const PacketBuffer& serverConnID)
 {
   auto& streamBuffer = conn.d_streamBuffers[streamID];
-  auto existingLength = streamBuffer.size();
-  bool fin = false;
-  streamBuffer.resize(existingLength + 512);
-  auto received = quiche_conn_stream_recv(conn.d_conn.get(), streamID,
-                                          &streamBuffer.at(existingLength), 512,
-                                          &fin);
-  streamBuffer.resize(existingLength + received);
-  if (fin) {
-    if (streamBuffer.size() < (sizeof(uint16_t) + sizeof(dnsheader))) {
-      ++dnsdist::metrics::g_stats.nonCompliantQueries;
-      ++clientState.nonCompliantQueries;
-      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+  while (true) {
+    bool fin = false;
+    auto existingLength = streamBuffer.size();
+    streamBuffer.resize(existingLength + 512);
+    auto received = quiche_conn_stream_recv(conn.d_conn.get(), streamID,
+                                            &streamBuffer.at(existingLength), 512,
+                                            &fin);
+    if (received == 0 || received == QUICHE_ERR_DONE) {
+      streamBuffer.resize(existingLength);
       return;
     }
-    uint16_t payloadLength = streamBuffer.at(0) * 256 + streamBuffer.at(1);
-    streamBuffer.erase(streamBuffer.begin(), streamBuffer.begin() + 2);
-    if (payloadLength != streamBuffer.size()) {
+    else if (received < 0) {
       ++dnsdist::metrics::g_stats.nonCompliantQueries;
       ++clientState.nonCompliantQueries;
       quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
       return;
     }
-    DEBUGLOG("Dispatching query");
-    doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
-    conn.d_streamBuffers.erase(streamID);
+
+    streamBuffer.resize(existingLength + received);
+    if (fin) {
+      break;
+    }
+  }
+
+  if (streamBuffer.size() < (sizeof(uint16_t) + sizeof(dnsheader))) {
+    ++dnsdist::metrics::g_stats.nonCompliantQueries;
+    ++clientState.nonCompliantQueries;
+    quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+    return;
+  }
+
+  uint16_t payloadLength = streamBuffer.at(0) * 256 + streamBuffer.at(1);
+  streamBuffer.erase(streamBuffer.begin(), streamBuffer.begin() + 2);
+  if (payloadLength != streamBuffer.size()) {
+    ++dnsdist::metrics::g_stats.nonCompliantQueries;
+    ++clientState.nonCompliantQueries;
+    quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+    return;
   }
+  DEBUGLOG("Dispatching query");
+  doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
+  conn.d_streamBuffers.erase(streamID);
 }
 
 // this is the entrypoint from dnsdist.cc