]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: More delinting of the DoH3 code 13556/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Dec 2023 10:58:06 +0000 (11:58 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 8 Dec 2023 08:24:09 +0000 (09:24 +0100)
pdns/dnsdist-doh-common.hh
pdns/dnsdist.cc
pdns/dnsdistdist/dnsdist-doh-common.cc
pdns/dnsdistdist/doh3.cc

index a5c8e968c02d8a2edc97228526532430c04f1792..0dc714df23a3c0273f3d6f6c748189cd4f3a0c4d 100644 (file)
@@ -33,7 +33,8 @@
 #include "stat_t.hh"
 #include "tcpiohandler.hh"
 
-namespace dnsdist::doh {
+namespace dnsdist::doh
+{
 std::optional<PacketBuffer> getPayloadFromPath(const std::string_view& path);
 }
 
index 3b1fb0735ef48a8bbd1b1a234f85f96130708361..4a7f386ab1cbefe9f1f69ce7ba4fd5626c9a484b 100644 (file)
@@ -2875,57 +2875,57 @@ static void startFrontends()
 {
   std::vector<ClientState*> tcpStates;
   std::vector<ClientState*> udpStates;
-  for (auto& cs : g_frontends) {
-    if (cs->dohFrontend != nullptr && cs->dohFrontend->d_library == "h2o") {
+  for (auto& clientState : g_frontends) {
+    if (clientState->dohFrontend != nullptr && clientState->dohFrontend->d_library == "h2o") {
 #ifdef HAVE_DNS_OVER_HTTPS
 #ifdef HAVE_LIBH2OEVLOOP
-      std::thread dotThreadHandle(dohThread, cs.get());
-      if (!cs->cpus.empty()) {
-        mapThreadToCPUList(dotThreadHandle.native_handle(), cs->cpus);
+      std::thread dotThreadHandle(dohThread, clientState.get());
+      if (!clientState->cpus.empty()) {
+        mapThreadToCPUList(dotThreadHandle.native_handle(), clientState->cpus);
       }
       dotThreadHandle.detach();
 #endif /* HAVE_LIBH2OEVLOOP */
 #endif /* HAVE_DNS_OVER_HTTPS */
         continue;
       }
-      if (cs->doqFrontend != nullptr) {
+      if (clientState->doqFrontend != nullptr) {
 #ifdef HAVE_DNS_OVER_QUIC
-        std::thread doqThreadHandle(doqThread, cs.get());
-        if (!cs->cpus.empty()) {
-          mapThreadToCPUList(doqThreadHandle.native_handle(), cs->cpus);
+        std::thread doqThreadHandle(doqThread, clientState.get());
+        if (!clientState->cpus.empty()) {
+          mapThreadToCPUList(doqThreadHandle.native_handle(), clientState->cpus);
         }
         doqThreadHandle.detach();
 #endif /* HAVE_DNS_OVER_QUIC */
         continue;
       }
-      if (cs->doh3Frontend != nullptr) {
+      if (clientState->doh3Frontend != nullptr) {
 #ifdef HAVE_DNS_OVER_HTTP3
-        std::thread doh3ThreadHandle(doh3Thread, cs.get());
-        if (!cs->cpus.empty()) {
-          mapThreadToCPUList(doh3ThreadHandle.native_handle(), cs->cpus);
+        std::thread doh3ThreadHandle(doh3Thread, clientState.get());
+        if (!clientState->cpus.empty()) {
+          mapThreadToCPUList(doh3ThreadHandle.native_handle(), clientState->cpus);
         }
         doh3ThreadHandle.detach();
 #endif /* HAVE_DNS_OVER_HTTP3 */
         continue;
       }
-      if (cs->udpFD >= 0) {
+      if (clientState->udpFD >= 0) {
 #ifdef USE_SINGLE_ACCEPTOR_THREAD
-        udpStates.push_back(cs.get());
+        udpStates.push_back(clientState.get());
 #else /* USE_SINGLE_ACCEPTOR_THREAD */
-        std::thread udpClientThreadHandle(udpClientThread, std::vector<ClientState*>{ cs.get() });
-        if (!cs->cpus.empty()) {
-          mapThreadToCPUList(udpClientThreadHandle.native_handle(), cs->cpus);
+        std::thread udpClientThreadHandle(udpClientThread, std::vector<ClientState*>{ clientState.get() });
+        if (!clientState->cpus.empty()) {
+          mapThreadToCPUList(udpClientThreadHandle.native_handle(), clientState->cpus);
         }
         udpClientThreadHandle.detach();
 #endif /* USE_SINGLE_ACCEPTOR_THREAD */
       }
-      else if (cs->tcpFD >= 0) {
+      else if (clientState->tcpFD >= 0) {
 #ifdef USE_SINGLE_ACCEPTOR_THREAD
-        tcpStates.push_back(cs.get());
+        tcpStates.push_back(clientState.get());
 #else /* USE_SINGLE_ACCEPTOR_THREAD */
-        std::thread tcpAcceptorThreadHandle(tcpAcceptorThread, std::vector<ClientState*>{cs.get() });
-        if (!cs->cpus.empty()) {
-          mapThreadToCPUList(tcpAcceptorThreadHandle.native_handle(), cs->cpus);
+        std::thread tcpAcceptorThreadHandle(tcpAcceptorThread, std::vector<ClientState*>{clientState.get() });
+        if (!clientState->cpus.empty()) {
+          mapThreadToCPUList(tcpAcceptorThreadHandle.native_handle(), clientState->cpus);
         }
         tcpAcceptorThreadHandle.detach();
 #endif /* USE_SINGLE_ACCEPTOR_THREAD */
index ef1c2780d42fda1805334bb2cdda98fc6f7bca11..71cd87cd0f2f7a9056789702a11fbbe290c6b4ae 100644 (file)
@@ -129,7 +129,8 @@ void DOHFrontend::setup()
 
 #endif /* HAVE_DNS_OVER_HTTPS */
 
-namespace dnsdist::doh {
+namespace dnsdist::doh
+{
 std::optional<PacketBuffer> getPayloadFromPath(const std::string_view& path)
 {
   std::optional<PacketBuffer> result{std::nullopt};
index bcaf454438b9f65566959d2d232e6e4de4a23176..a1e63947905bd7f40515f3f14d5f4fe7176b1b58 100644 (file)
@@ -267,24 +267,26 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const
 {
   std::string status = std::to_string(statusCode);
   std::string lenStr = std::to_string(len);
-  quiche_h3_header headers[] = {
-    {
+  std::array<quiche_h3_header, 2> headers{
+    (quiche_h3_header){
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
       .name = reinterpret_cast<const uint8_t*>(":status"),
       .name_len = sizeof(":status") - 1,
-
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
       .value = reinterpret_cast<const uint8_t*>(status.data()),
       .value_len = status.size(),
     },
-    {
+    (quiche_h3_header){
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
       .name = reinterpret_cast<const uint8_t*>("content-length"),
       .name_len = sizeof("content-length") - 1,
-
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
       .value = reinterpret_cast<const uint8_t*>(lenStr.data()),
       .value_len = lenStr.size(),
     },
   };
   quiche_h3_send_response(conn, quic_conn,
-                          streamID, headers, 2, len == 0);
+                          streamID, headers.data(), headers.size(), len == 0);
 
   if (len == 0) {
     return;
@@ -294,6 +296,7 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const
   while (pos < len) {
     // send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
     auto res = quiche_h3_send_body(conn, quic_conn,
+                                   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
                                    streamID, const_cast<uint8_t*>(body) + pos, len - pos, true);
     if (res < 0) {
       // Shutdown with internal error code
@@ -306,10 +309,13 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const
 
 static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
 {
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
   h3_send_response(quic_conn, conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
 }
+
 static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
 {
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
   h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, statusCode, body, len);
 }
 
@@ -610,6 +616,107 @@ static void flushResponses(pdns::channel::Receiver<DOH3Unit>& receiver)
   }
 }
 
+static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event)
+{
+  auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
+    DEBUGLOG(msg);
+    ++dnsdist::metrics::g_stats.nonCompliantQueries;
+    ++clientState.nonCompliantQueries;
+    ++frontend.d_errorResponses;
+    h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg);
+  };
+
+  // Callback result. Any value other than 0 will interrupt further header processing.
+  int cbresult = quiche_h3_event_for_each_header(
+    event,
+    [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int {
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+      std::string_view key(reinterpret_cast<char*>(name), name_len);
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+      std::string_view content(reinterpret_cast<char*>(value), value_len);
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+      auto* headersptr = reinterpret_cast<std::map<std::string, std::string>*>(argp);
+      headersptr->emplace(key, content);
+      return 0;
+    },
+    &headers);
+
+  if (cbresult != 0 || headers.count(":method") == 0) {
+    handleImmediateError("Unable to process query headers");
+    return;
+  }
+
+  if (headers.at(":method") == "GET") {
+    if (headers.count(":path") == 0 || headers.at(":path").empty()) {
+      handleImmediateError("Path not found");
+      return;
+    }
+    const auto& path = headers.at(":path");
+    auto payload = dnsdist::doh::getPayloadFromPath(path);
+    if (!payload) {
+      handleImmediateError("Unable to find the DNS parameter");
+      return;
+    }
+    if (payload->size() < sizeof(dnsheader)) {
+      handleImmediateError("DoH3 non-compliant query");
+      return;
+    }
+    DEBUGLOG("Dispatching GET query");
+    doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID);
+    conn.d_streamBuffers.erase(streamID);
+    return;
+  }
+
+  if (headers.at(":method") == "POST") {
+    if (!quiche_h3_event_headers_has_body(event)) {
+      handleImmediateError("Empty POST query");
+    }
+    return;
+  }
+
+  handleImmediateError("Unsupported HTTP method");
+}
+
+static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event)
+{
+  auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
+    DEBUGLOG(msg);
+    ++dnsdist::metrics::g_stats.nonCompliantQueries;
+    ++clientState.nonCompliantQueries;
+    ++frontend.d_errorResponses;
+    h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg);
+  };
+
+  if (headers.at(":method") == "POST") {
+    if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
+      handleImmediateError("Unsupported content-type");
+      return;
+    }
+    PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
+    PacketBuffer decoded;
+
+    while (true) {
+      ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
+                                        conn.d_conn.get(), streamID,
+                                        buffer.data(), buffer.capacity());
+
+      if (len <= 0) {
+        break;
+      }
+      decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len);
+    }
+
+    if (decoded.size() < sizeof(dnsheader)) {
+      handleImmediateError("DoH3 non-compliant query");
+      return;
+    }
+
+    DEBUGLOG("Dispatching POST query");
+    doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID);
+    conn.d_streamBuffers.erase(streamID);
+  }
+}
+
 static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID)
 {
   std::map<std::string, std::string> headers;
@@ -626,114 +733,12 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3
 
     switch (quiche_h3_event_type(event)) {
     case QUICHE_H3_EVENT_HEADERS: {
-      // Callback result. Any value other than 0 will interrupt further header processing.
-      int cbresult = quiche_h3_event_for_each_header(
-        event,
-        [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int {
-          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
-          std::string_view key(reinterpret_cast<char*>(name), name_len);
-          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
-          std::string_view content(reinterpret_cast<char*>(value), value_len);
-          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
-          auto* headersptr = reinterpret_cast<std::map<std::string, std::string>*>(argp);
-          headersptr->emplace(key, content);
-          return 0;
-        },
-        &headers);
-
-      if (cbresult != 0 || headers.count(":method") == 0) {
-        DEBUGLOG("Failed to process headers");
-        ++dnsdist::metrics::g_stats.nonCompliantQueries;
-        ++clientState.nonCompliantQueries;
-        ++frontend.d_errorResponses;
-        h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unable to process query headers");
-        break;
-      }
-
-      if (headers.at(":method") == "GET") {
-        if (headers.count(":path") == 0 || headers.at(":path").empty()) {
-          DEBUGLOG("Path not found");
-          ++dnsdist::metrics::g_stats.nonCompliantQueries;
-          ++clientState.nonCompliantQueries;
-          ++frontend.d_errorResponses;
-          h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Path not found");
-          break;
-        }
-        const auto& path = headers.at(":path");
-        auto payload = dnsdist::doh::getPayloadFromPath(path);
-        if (!payload) {
-          DEBUGLOG("User error, unable to find the DNS parameter");
-          ++dnsdist::metrics::g_stats.nonCompliantQueries;
-          ++clientState.nonCompliantQueries;
-          ++frontend.d_errorResponses;
-          h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unable to find the DNS parameter");
-          break;
-        }
-        if (payload->size() < sizeof(dnsheader)) {
-          ++dnsdist::metrics::g_stats.nonCompliantQueries;
-          ++clientState.nonCompliantQueries;
-          ++frontend.d_errorResponses;
-          h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "DoH3 non-compliant query");
-          break;
-        }
-        DEBUGLOG("Dispatching GET query");
-        doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID);
-        conn.d_streamBuffers.erase(streamID);
-      }
-      else if (headers.at(":method") == "POST") {
-        if (!quiche_h3_event_headers_has_body(event)) {
-          DEBUGLOG("Empty POST query");
-          ++dnsdist::metrics::g_stats.nonCompliantQueries;
-          ++clientState.nonCompliantQueries;
-          ++frontend.d_errorResponses;
-          h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Empty POST query");
-          break;
-        }
-      }
-      else {
-        DEBUGLOG("Unsupported HTTP method");
-        ++dnsdist::metrics::g_stats.nonCompliantQueries;
-        ++clientState.nonCompliantQueries;
-        ++frontend.d_errorResponses;
-        h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unsupported HTTP method");
-        break;
-      }
+      processH3HeaderEvent(clientState, frontend, conn, client, serverConnID, headers, streamID, event);
       break;
     }
     case QUICHE_H3_EVENT_DATA: {
-      if (headers.at(":method") == "POST") {
-        if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
-          DEBUGLOG("Unsupported content-type");
-          ++dnsdist::metrics::g_stats.nonCompliantQueries;
-          ++clientState.nonCompliantQueries;
-          ++frontend.d_errorResponses;
-          h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unsupported content-type");
-          break;
-        }
-        PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
-        PacketBuffer decoded;
-
-        while (true) {
-          ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
-                                            conn.d_conn.get(), streamID,
-                                            buffer.data(), buffer.capacity());
-
-          if (len <= 0) {
-            break;
-          }
-          decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len);
-        }
-        if (decoded.size() < sizeof(dnsheader)) {
-          ++dnsdist::metrics::g_stats.nonCompliantQueries;
-          ++clientState.nonCompliantQueries;
-          ++frontend.d_errorResponses;
-          h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "DoH3 non-compliant query");
-          break;
-        }
-        DEBUGLOG("Dispatching POST query");
-        doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID);
-        conn.d_streamBuffers.erase(streamID);
-      }
+      processH3DataEvent(clientState, frontend, conn, client, serverConnID, headers, streamID, event);
+      break;
     }
     case QUICHE_H3_EVENT_FINISHED:
     case QUICHE_H3_EVENT_RESET:
@@ -746,7 +751,6 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3
   }
 }
 
-
 // this is the entrypoint from dnsdist.cc
 void doh3Thread(ClientState* clientState)
 {