]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Split the DoQ 'readable stream' handling code to a function
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Dec 2023 15:45:37 +0000 (16:45 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Dec 2023 15:46:07 +0000 (16:46 +0100)
pdns/dnsdistdist/doq.cc

index 9b626aaf8822ec164a23884b53deeb5c27ee5f60..29d0dc696291b21c21a004bc99373fe625d384e0 100644 (file)
@@ -579,6 +579,37 @@ static void flushStalledResponses(Connection& conn)
   }
 }
 
+static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState, Connection& conn, uint64_t streamID, const ComboAddress& client, const PacketBuffer& serverConnID)
+{
+  auto& streamBuffer = conn.d_streamBuffers[streamID];
+  auto existingLength = streamBuffer.size();
+  bool fin = false;
+  streamBuffer.resize(existingLength + 512);
+  auto received = quiche_conn_stream_recv(conn.d_conn.get(), streamID,
+                                          &streamBuffer.at(existingLength), 512,
+                                          &fin);
+  streamBuffer.resize(existingLength + received);
+  if (fin) {
+    if (streamBuffer.size() < (sizeof(uint16_t) + sizeof(dnsheader))) {
+      ++dnsdist::metrics::g_stats.nonCompliantQueries;
+      ++clientState.nonCompliantQueries;
+      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+      return;
+    }
+    uint16_t payloadLength = streamBuffer.at(0) * 256 + streamBuffer.at(1);
+    streamBuffer.erase(streamBuffer.begin(), streamBuffer.begin() + 2);
+    if (payloadLength != streamBuffer.size()) {
+      ++dnsdist::metrics::g_stats.nonCompliantQueries;
+      ++clientState.nonCompliantQueries;
+      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+      return;
+    }
+    DEBUGLOG("Dispatching query");
+    doq_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
+    conn.d_streamBuffers.erase(streamID);
+  }
+}
+
 // this is the entrypoint from dnsdist.cc
 void doqThread(ClientState* clientState)
 {
@@ -684,33 +715,7 @@ void doqThread(ClientState* clientState)
 
           uint64_t streamID = 0;
           while (quiche_stream_iter_next(readable.get(), &streamID)) {
-            auto& streamBuffer = conn->get().d_streamBuffers[streamID];
-            auto existingLength = streamBuffer.size();
-            bool fin = false;
-            streamBuffer.resize(existingLength + 512);
-            auto received = quiche_conn_stream_recv(conn->get().d_conn.get(), streamID,
-                                                    &streamBuffer.at(existingLength), 512,
-                                                    &fin);
-            streamBuffer.resize(existingLength + received);
-            if (fin) {
-              if (streamBuffer.size() < (sizeof(uint16_t) + sizeof(dnsheader))) {
-                ++dnsdist::metrics::g_stats.nonCompliantQueries;
-                ++clientState->nonCompliantQueries;
-                quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
-                break;
-              }
-              uint16_t payloadLength = streamBuffer.at(0) * 256 + streamBuffer.at(1);
-              streamBuffer.erase(streamBuffer.begin(), streamBuffer.begin() + 2);
-              if (payloadLength != streamBuffer.size()) {
-                ++dnsdist::metrics::g_stats.nonCompliantQueries;
-                ++clientState->nonCompliantQueries;
-                quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
-                break;
-              }
-              DEBUGLOG("Dispatching query");
-              doq_dispatch_query(*(frontend->d_server_config), std::move(streamBuffer), clientState->local, client, serverConnID, streamID);
-              conn->get().d_streamBuffers.erase(streamID);
-            }
+            handleReadableStream(*frontend, *clientState, *conn, streamID, client, serverConnID);
           }
         }
         else {