]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Better handling of short reads/writes in DoQ
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 25 Sep 2023 13:37:39 +0000 (15:37 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Oct 2023 11:38:01 +0000 (13:38 +0200)
pdns/dnsdistdist/doq.cc

index 2a16a052a29830935026c6f9e8c6d7c511aaf964..cdb12919c0259cb83477273cac87bc77369e3cda 100644 (file)
@@ -64,6 +64,7 @@ public:
 
   ComboAddress d_peer;
   QuicheConnection d_conn;
+  std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
 };
 
 static void sendBackDOQUnit(DOQUnitUniquePtr&& du, const char* description);
@@ -261,19 +262,46 @@ private:
 
 std::shared_ptr<DOQTCPCrossQuerySender> DOQCrossProtocolQuery::s_sender = std::make_shared<DOQTCPCrossQuerySender>();
 
+/* from rfc9250 section-4.3 */
+enum class DOQ_Error_Codes : uint64_t {
+  DOQ_NO_ERROR = 0,
+  DOQ_INTERNAL_ERROR = 1,
+  DOQ_PROTOCOL_ERROR = 2,
+  DOQ_REQUEST_CANCELLED = 3,
+  DOQ_EXCESSIVE_LOAD = 4,
+  DOQ_UNSPECIFIED_ERROR = 5
+};
+
 static void handleResponse(DOQFrontend& df, Connection& conn, const uint64_t streamID, const PacketBuffer& response)
 {
   if (response.size() == 0) {
-    quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, 0x5);
+    quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR));
+    return;
   }
-  else {
-    uint16_t responseSize = static_cast<uint16_t>(response.size());
-    const uint8_t sizeBytes[] = {static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
-    auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, sizeBytes, sizeof(sizeBytes), false);
-    if (res == sizeof(sizeBytes)) {
-      res = quiche_conn_stream_send(conn.d_conn.get(), streamID, response.data(), response.size(), true);
+
+  uint16_t responseSize = static_cast<uint16_t>(response.size());
+  const std::array<uint8_t, 2> sizeBytes = {static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
+  size_t pos = 0;
+  do {
+    auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, sizeBytes.data() + pos, sizeBytes.size() - pos, false);
+    if (res < 0) {
+      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+      return;
     }
+    pos += res;
   }
+  while (pos < sizeBytes.size());
+
+  pos = 0;
+  do {
+    auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, response.data() + pos, response.size() - pos, true);
+    if (res < 0) {
+      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+      return;
+    }
+    pos += res;
+  }
+  while (pos < response.size());
 }
 
 static void fillRandom(PacketBuffer& buffer, size_t size)
@@ -755,7 +783,7 @@ void doqThread(ClientState* cs)
 
     Socket sock(cs->udpFD);
 
-    PacketBuffer buffer(std::numeric_limits<unsigned short>::max());
+    PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
     auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
 
     auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor();
@@ -840,29 +868,33 @@ void doqThread(ClientState* cs)
 
           uint64_t streamID = 0;
           while (quiche_stream_iter_next(readable.get(), &streamID)) {
+            auto& streamBuffer = conn->get().d_streamBuffers[streamID];
+            auto existingLength = streamBuffer.size();
             bool fin = false;
-            buffer.resize(std::numeric_limits<unsigned short>::max());
+            streamBuffer.resize(existingLength + 512);
             auto received = quiche_conn_stream_recv(conn->get().d_conn.get(), streamID,
-                                                    buffer.data(), buffer.size(),
+                                                    &streamBuffer.at(existingLength), 512,
                                                     &fin);
-            if (received < 2) {
-              break;
-            }
-            buffer.resize(received);
-
+            streamBuffer.resize(existingLength + received);
             if (fin) {
-              // we skip message length, should we verify ?
-              buffer.erase(buffer.begin(), buffer.begin() + 2);
-              if (buffer.size() >= sizeof(dnsheader)) {
-                doq_dispatch_query(*(frontend->d_server_config), std::move(buffer), cs->local, client, serverConnID, streamID);
+              if (streamBuffer.size() < (sizeof(dnsheader) + sizeof(uint16_t))) {
+                quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+                break;
+              }
+              uint16_t payloadLength = streamBuffer.at(0) * 256 + streamBuffer.at(1);
+              streamBuffer.erase(streamBuffer.begin(), streamBuffer.begin() + 2);
+              if (payloadLength != streamBuffer.size()) {
+                quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR));
+                break;
               }
+              doq_dispatch_query(*(frontend->d_server_config), std::move(streamBuffer), cs->local, client, serverConnID, streamID);
+              conn->get().d_streamBuffers.erase(streamID);
             }
           }
         }
         else {
           DEBUGLOG("Connection not established");
         }
-        // }
       }
 
       if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {