]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Split DoQ 'socket readable' to a separate function
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Dec 2023 15:53:23 +0000 (16:53 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Dec 2023 15:53:23 +0000 (16:53 +0100)
pdns/dnsdistdist/doq.cc

index 4a43f6d108f43362f21a8419dbc210beb8bf3458..a89aca24e2d183db2d3c6fbc1363d8415c1ed86a 100644 (file)
@@ -626,6 +626,98 @@ static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState
   conn.d_streamBuffers.erase(streamID);
 }
 
+static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState, Socket& sock)
+{
+  DEBUGLOG("Received datagram");
+  std::string bufferStr;
+  ComboAddress client;
+  sock.recvFrom(bufferStr, client);
+
+  uint32_t version{0};
+  uint8_t type{0};
+  std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> scid{};
+  size_t scid_len = scid.size();
+  std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> dcid{};
+  size_t dcid_len = dcid.size();
+  std::array<uint8_t, MAX_TOKEN_LEN> token{};
+  size_t token_len = token.size();
+
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+  auto res = quiche_header_info(reinterpret_cast<const uint8_t*>(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN,
+                                &version, &type,
+                                scid.data(), &scid_len,
+                                dcid.data(), &dcid_len,
+                                token.data(), &token_len);
+  if (res != 0) {
+    DEBUGLOG("Error in quiche_header_info: " << res);
+    return;
+  }
+
+  // destination connection ID, will have to be sent as original destination connection ID
+  PacketBuffer serverConnID(dcid.begin(), dcid.begin() + dcid_len);
+  // source connection ID, will have to be sent as destination connection ID
+  PacketBuffer clientConnID(scid.begin(), scid.begin() + scid_len);
+  auto conn = getConnection(frontend.d_server_config->d_connections, serverConnID);
+
+  if (!conn) {
+    DEBUGLOG("Connection not found");
+    if (!quiche_version_is_supported(version)) {
+      DEBUGLOG("Unsupported version");
+      ++frontend.d_doqUnsupportedVersionErrors;
+      handleVersionNegociation(sock, clientConnID, serverConnID, client);
+      return;
+    }
+
+    if (token_len == 0) {
+      /* stateless retry */
+      DEBUGLOG("No token received");
+      handleStatelessRetry(sock, clientConnID, serverConnID, client, version);
+      return;
+    }
+
+    PacketBuffer tokenBuf(token.begin(), token.begin() + token_len);
+    auto originalDestinationID = validateToken(tokenBuf, client);
+    if (!originalDestinationID) {
+      ++frontend.d_doqInvalidTokensReceived;
+      DEBUGLOG("Discarding invalid token");
+      return;
+    }
+
+    DEBUGLOG("Creating a new connection");
+    conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, clientState.local, client);
+    if (!conn) {
+      return;
+    }
+  }
+  DEBUGLOG("Connection found");
+  quiche_recv_info recv_info = {
+    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+    reinterpret_cast<struct sockaddr*>(&client),
+    client.getSocklen(),
+    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+    reinterpret_cast<struct sockaddr*>(&clientState.local),
+    clientState.local.getSocklen(),
+  };
+
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+  auto done = quiche_conn_recv(conn->get().d_conn.get(), reinterpret_cast<uint8_t*>(bufferStr.data()), bufferStr.size(), &recv_info);
+  if (done < 0) {
+    return;
+  }
+
+  if (quiche_conn_is_established(conn->get().d_conn.get())) {
+    auto readable = std::unique_ptr<quiche_stream_iter, decltype(&quiche_stream_iter_free)>(quiche_conn_readable(conn->get().d_conn.get()), quiche_stream_iter_free);
+
+    uint64_t streamID = 0;
+    while (quiche_stream_iter_next(readable.get(), &streamID)) {
+      handleReadableStream(frontend, clientState, *conn, streamID, client, serverConnID);
+    }
+  }
+  else {
+    DEBUGLOG("Connection not established");
+  }
+}
+
 // this is the entrypoint from dnsdist.cc
 void doqThread(ClientState* clientState)
 {
@@ -649,94 +741,7 @@ void doqThread(ClientState* clientState)
       mplexer->getAvailableFDs(readyFDs, 500);
 
       if (std::find(readyFDs.begin(), readyFDs.end(), sock.getHandle()) != readyFDs.end()) {
-        DEBUGLOG("Received datagram");
-        std::string bufferStr;
-        ComboAddress client;
-        sock.recvFrom(bufferStr, client);
-
-        uint32_t version{0};
-        uint8_t type{0};
-        std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> scid{};
-        size_t scid_len = scid.size();
-        std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> dcid{};
-        size_t dcid_len = dcid.size();
-        std::array<uint8_t, MAX_TOKEN_LEN> token{};
-        size_t token_len = token.size();
-
-        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-        auto res = quiche_header_info(reinterpret_cast<const uint8_t*>(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN,
-                                      &version, &type,
-                                      scid.data(), &scid_len,
-                                      dcid.data(), &dcid_len,
-                                      token.data(), &token_len);
-        if (res != 0) {
-          DEBUGLOG("Error in quiche_header_info: " << res);
-          continue;
-        }
-
-        // destination connection ID, will have to be sent as original destination connection ID
-        PacketBuffer serverConnID(dcid.begin(), dcid.begin() + dcid_len);
-        // source connection ID, will have to be sent as destination connection ID
-        PacketBuffer clientConnID(scid.begin(), scid.begin() + scid_len);
-        auto conn = getConnection(frontend->d_server_config->d_connections, serverConnID);
-
-        if (!conn) {
-          DEBUGLOG("Connection not found");
-          if (!quiche_version_is_supported(version)) {
-            DEBUGLOG("Unsupported version");
-            ++frontend->d_doqUnsupportedVersionErrors;
-            handleVersionNegociation(sock, clientConnID, serverConnID, client);
-            continue;
-          }
-
-          if (token_len == 0) {
-            /* stateless retry */
-            DEBUGLOG("No token received");
-            handleStatelessRetry(sock, clientConnID, serverConnID, client, version);
-            continue;
-          }
-
-          PacketBuffer tokenBuf(token.begin(), token.begin() + token_len);
-          auto originalDestinationID = validateToken(tokenBuf, client);
-          if (!originalDestinationID) {
-            ++frontend->d_doqInvalidTokensReceived;
-            DEBUGLOG("Discarding invalid token");
-            continue;
-          }
-
-          DEBUGLOG("Creating a new connection");
-          conn = createConnection(*frontend->d_server_config, serverConnID, *originalDestinationID, clientState->local, client);
-          if (!conn) {
-            continue;
-          }
-        }
-        DEBUGLOG("Connection found");
-        quiche_recv_info recv_info = {
-          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-          reinterpret_cast<struct sockaddr*>(&client),
-          client.getSocklen(),
-          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-          reinterpret_cast<struct sockaddr*>(&clientState->local),
-          clientState->local.getSocklen(),
-        };
-
-        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-        auto done = quiche_conn_recv(conn->get().d_conn.get(), reinterpret_cast<uint8_t*>(bufferStr.data()), bufferStr.size(), &recv_info);
-        if (done < 0) {
-          continue;
-        }
-
-        if (quiche_conn_is_established(conn->get().d_conn.get())) {
-          auto readable = std::unique_ptr<quiche_stream_iter, decltype(&quiche_stream_iter_free)>(quiche_conn_readable(conn->get().d_conn.get()), quiche_stream_iter_free);
-
-          uint64_t streamID = 0;
-          while (quiche_stream_iter_next(readable.get(), &streamID)) {
-            handleReadableStream(*frontend, *clientState, *conn, streamID, client, serverConnID);
-          }
-        }
-        else {
-          DEBUGLOG("Connection not established");
-        }
+        handleSocketReadable(*frontend, *clientState, sock);
       }
 
       if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {