]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Socket: Prevent alloc+copy in Socket::recvFromAsync()
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Dec 2023 16:28:53 +0000 (17:28 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Dec 2023 16:28:53 +0000 (17:28 +0100)
pdns/dnsdistdist/doh3.cc
pdns/dnsdistdist/doq.cc
pdns/dnsreplay.cc
pdns/sstuff.hh

index b6449cde37a71065f78fa1bcdeed7feaaad0d12f..2f588e97e6100045e70f6c830cf094c631019f22 100644 (file)
@@ -805,15 +805,15 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3
   }
 }
 
-static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientState, Socket& sock)
+static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientState, Socket& sock, PacketBuffer& buffer)
 {
   while (true) {
-    DEBUGLOG("Received datagram");
-    std::string bufferStr;
     ComboAddress client;
-    if (!sock.recvFromAsync(bufferStr, client) || bufferStr.size() == 0) {
+    buffer.resize(4096);
+    if (!sock.recvFromAsync(buffer, client) || buffer.size() == 0) {
       return;
     }
+    DEBUGLOG("Received DoH3 datagram of size "<<buffer.size()<<" from "<<client.toStringWithPort());
 
     uint32_t version{0};
     uint8_t type{0};
@@ -824,8 +824,7 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
     std::array<uint8_t, MAX_TOKEN_LEN> token{};
     size_t token_len = token.size();
 
-    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-    auto res = quiche_header_info(reinterpret_cast<const uint8_t*>(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN,
+    auto res = quiche_header_info(buffer.data(), buffer.size(), LOCAL_CONN_ID_LEN,
                                   &version, &type,
                                   scid.data(), &scid_len,
                                   dcid.data(), &dcid_len,
@@ -881,8 +880,7 @@ static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientStat
       clientState.local.getSocklen(),
     };
 
-    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-    auto done = quiche_conn_recv(conn->get().d_conn.get(), reinterpret_cast<uint8_t*>(bufferStr.data()), bufferStr.size(), &recv_info);
+    auto done = quiche_conn_recv(conn->get().d_conn.get(), buffer.data(), buffer.size(), &recv_info);
     if (done < 0) {
       continue;
     }
@@ -927,13 +925,14 @@ void doh3Thread(ClientState* clientState)
     mplexer->addReadFD(sock.getHandle(), [](int, FDMultiplexer::funcparam_t&) {});
     mplexer->addReadFD(responseReceiverFD, [](int, FDMultiplexer::funcparam_t&) {});
     std::vector<int> readyFDs;
+    PacketBuffer buffer(4096);
     while (true) {
       readyFDs.clear();
       mplexer->getAvailableFDs(readyFDs, 500);
 
       try {
         if (std::find(readyFDs.begin(), readyFDs.end(), sock.getHandle()) != readyFDs.end()) {
-          handleSocketReadable(*frontend, *clientState, sock);
+          handleSocketReadable(*frontend, *clientState, sock, buffer);
         }
 
         if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {
@@ -953,7 +952,7 @@ void doh3Thread(ClientState* clientState)
             quiche_conn_stats(conn->second.d_conn.get(), &stats);
             quiche_conn_path_stats(conn->second.d_conn.get(), 0, &path_stats);
 
-            DEBUGLOG("Connection closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd);
+            DEBUGLOG("Connection (DoH3) closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd);
 #endif
             conn = frontend->d_server_config->d_connections.erase(conn);
           }
index 6ac7e24d12e30782d8d8aef9217f5552f0278dc0..1eb94520d8b62da8c33efe0baa6713ca0ff85c67 100644 (file)
@@ -626,15 +626,15 @@ static void handleReadableStream(DOQFrontend& frontend, ClientState& clientState
   conn.d_streamBuffers.erase(streamID);
 }
 
-static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState, Socket& sock)
+static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState, Socket& sock, PacketBuffer& buffer)
 {
   while (true) {
-    DEBUGLOG("Received datagram");
-    std::string bufferStr;
     ComboAddress client;
-    if (!sock.recvFromAsync(bufferStr, client) || bufferStr.size() == 0) {
+    buffer.resize(4096);
+    if (!sock.recvFromAsync(buffer, client) || buffer.size() == 0) {
       return;
     }
+    DEBUGLOG("Received DoQ datagram of size "<<buffer.size()<<" from "<<client.toStringWithPort());
 
     uint32_t version{0};
     uint8_t type{0};
@@ -645,8 +645,7 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState
     std::array<uint8_t, MAX_TOKEN_LEN> token{};
     size_t token_len = token.size();
 
-    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-    auto res = quiche_header_info(reinterpret_cast<const uint8_t*>(bufferStr.data()), bufferStr.size(), LOCAL_CONN_ID_LEN,
+    auto res = quiche_header_info(buffer.data(), buffer.size(), LOCAL_CONN_ID_LEN,
                                   &version, &type,
                                   scid.data(), &scid_len,
                                   dcid.data(), &dcid_len,
@@ -702,8 +701,7 @@ static void handleSocketReadable(DOQFrontend& frontend, ClientState& clientState
       clientState.local.getSocklen(),
     };
 
-    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-    auto done = quiche_conn_recv(conn->get().d_conn.get(), reinterpret_cast<uint8_t*>(bufferStr.data()), bufferStr.size(), &recv_info);
+    auto done = quiche_conn_recv(conn->get().d_conn.get(), buffer.data(), buffer.size(), &recv_info);
     if (done < 0) {
       continue;
     }
@@ -742,13 +740,14 @@ void doqThread(ClientState* clientState)
     mplexer->addReadFD(sock.getHandle(), [](int, FDMultiplexer::funcparam_t&) {});
     mplexer->addReadFD(responseReceiverFD, [](int, FDMultiplexer::funcparam_t&) {});
     std::vector<int> readyFDs;
+    PacketBuffer buffer(4096);
     while (true) {
       readyFDs.clear();
       mplexer->getAvailableFDs(readyFDs, 500);
 
       try {
         if (std::find(readyFDs.begin(), readyFDs.end(), sock.getHandle()) != readyFDs.end()) {
-          handleSocketReadable(*frontend, *clientState, sock);
+          handleSocketReadable(*frontend, *clientState, sock, buffer);
         }
 
         if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {
@@ -768,7 +767,7 @@ void doqThread(ClientState* clientState)
             quiche_conn_stats(conn->second.d_conn.get(), &stats);
             quiche_conn_path_stats(conn->second.d_conn.get(), 0, &path_stats);
 
-            DEBUGLOG("Connection closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd);
+            DEBUGLOG("Connection (DoQ) closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd);
 #endif
             conn = frontend->d_server_config->d_connections.erase(conn);
           }
index 66d2fcb95885a5d5ebb8f168bb18aa59470a567b..36f0db835b4f25df50c4ca13992cdd2c9509ffef 100644 (file)
@@ -387,7 +387,7 @@ std::unique_ptr<Socket> s_socket = nullptr;
 static void receiveFromReference()
 try
 {
-  string packet;
+  PacketBuffer packet;
   ComboAddress remote;
   int res=waitForData(s_socket->getHandle(), g_timeoutMsec/1000, 1000*(g_timeoutMsec%1000));
 
@@ -397,7 +397,8 @@ try
   while (s_socket->recvFromAsync(packet, remote)) {
     try {
       s_weanswers++;
-      MOADNSParser mdp(false, packet.c_str(), packet.length());
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+      MOADNSParser mdp(false, reinterpret_cast<const char*>(packet.data()), packet.size());
       if(!mdp.d_header.qr) {
         cout<<"Received a question from our reference nameserver!"<<endl;
         continue;
index f0b186626216298c3506c7a78e2f57b83066655b..d1f2a5bc217a49c4e3fba67cf005db5c84504043 100644 (file)
@@ -38,9 +38,9 @@
 #include <boost/utility.hpp>
 #include <csignal>
 #include "namespaces.hh"
+#include "noinitvector.hh"
 
-
-typedef int ProtocolType; //!< Supported protocol types
+using ProtocolType = int; //!< Supported protocol types
 
 //! Representation of a Socket and many of the Berkeley functions available
 class Socket : public boost::noncopyable
@@ -173,54 +173,64 @@ public:
   /** For datagram sockets, receive a datagram and learn where it came from
       \param dgram Will be filled with the datagram
       \param ep Will be filled with the origin of the datagram */
-  void recvFrom(string &dgram, ComboAddress &ep)
+  void recvFrom(string &dgram, ComboAddress& remote)
   {
-    socklen_t remlen = sizeof(ep);
-    ssize_t bytes;
-    d_buffer.resize(s_buflen);
-    if((bytes=recvfrom(d_socket, &d_buffer[0], s_buflen, 0, reinterpret_cast<sockaddr *>(&ep) , &remlen)) <0)
-      throw NetworkError("After recvfrom: "+stringerror());
-
-    dgram.assign(d_buffer, 0, static_cast<size_t>(bytes));
+    socklen_t remlen = sizeof(remote);
+    if (dgram.size() < s_buflen) {
+      dgram.resize(s_buflen);
+    }
+    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+    auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr *>(&remote) , &remlen);
+    if (bytes < 0) {
+      throw NetworkError("After recvfrom: " + stringerror());
+    }
+    dgram.resize(static_cast<size_t>(bytes));
   }
 
-  bool recvFromAsync(string& dgram, ComboAddress& remote)
+  bool recvFromAsync(PacketBuffer& dgram, ComboAddress& remote)
   {
     socklen_t remlen = sizeof(remote);
-    d_buffer.resize(s_buflen);
-    const auto bytes = recvfrom(d_socket, d_buffer.data(), s_buflen, 0, reinterpret_cast<sockaddr *>(&remote), &remlen);
+    if (dgram.size() < s_buflen) {
+      dgram.resize(s_buflen);
+    }
+    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+    auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr *>(&remote), &remlen);
     if (bytes < 0) {
       if (errno != EAGAIN) {
         throw NetworkError("After async recvfrom: " + stringerror());
       }
-      return false;
+      else {
+        return false;
+      }
     }
-    dgram.assign(d_buffer, 0, static_cast<size_t>(bytes));
+    dgram.resize(static_cast<size_t>(bytes));
     return true;
   }
 
-
   //! For datagram sockets, send a datagram to a destination
-  void sendTo(const char* msg, size_t len, const ComboAddress &ep)
+  void sendTo(const char* msg, size_t len, const ComboAddress& remote)
   {
-    if(sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr *>(&ep), ep.getSocklen())<0)
-      throw NetworkError("After sendto: "+stringerror());
+    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+    if (sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr *>(&remote), remote.getSocklen()) < 0) {
+      throw NetworkError("After sendto: " + stringerror());
+    }
   }
 
   //! For connected datagram sockets, send a datagram
   void send(const std::string& msg)
   {
-    if(::send(d_socket, msg.c_str(), msg.size(), 0)<0)
+    if (::send(d_socket, msg.data(), msg.size(), 0) < 0) {
       throw NetworkError("After send: "+stringerror());
+    }
   }
 
 
   /** For datagram sockets, send a datagram to a destination
       \param dgram The datagram
-      \param ep The intended destination of the datagram */
-  void sendTo(const string &dgram, const ComboAddress &ep)
+      \param remote The intended destination of the datagram */
+  void sendTo(const string& dgram, const ComboAddress& remote)
   {
-    sendTo(dgram.c_str(), dgram.length(), ep);
+    sendTo(dgram.data(), dgram.length(), remote);
   }