]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Small optimizations for incoming DoH
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 11 May 2023 15:49:39 +0000 (17:49 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Sep 2023 07:19:18 +0000 (09:19 +0200)
pdns/dnsdistdist/dnsdist-nghttp2-in.cc
pdns/dnsdistdist/dnsdist-nghttp2-in.hh

index 578cdc6d731dc51cf233b4b149fe84f4ec626b3d..95b5705fd1422d3173a2c43fe8cdde53b92e9937 100644 (file)
@@ -96,7 +96,7 @@ class IncomingDoHCrossProtocolContext : public DOHUnitInterface
 {
 public:
   IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, std::shared_ptr<IncomingHTTP2Connection> connection, IncomingHTTP2Connection::StreamID streamID) :
-    d_connection(connection), d_query(std::move(query)), d_streamID(streamID)
+    d_connection(std::move(connection)), d_query(std::move(query)), d_streamID(streamID)
   {
   }
 
@@ -458,7 +458,7 @@ IOState IncomingHTTP2Connection::sendResponse(const struct timeval& now, TCPResp
     responseBuffer = std::move(response.d_buffer);
   }
 
-  sendResponse(response.d_idstate.d_streamID, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType);
+  sendResponse(response.d_idstate.d_streamID, context, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType);
   handleResponseSent(response);
 
   return IOState::Done;
@@ -474,11 +474,12 @@ void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPRespon
   }
 
   assert(response.d_idstate.d_streamID != -1);
-  d_currentStreams.at(response.d_idstate.d_streamID).d_buffer = std::move(response.d_buffer);
-  sendResponse(response.d_idstate.d_streamID, 502, d_ci.cs->dohFrontend->d_customResponseHeaders);
+  auto& context = d_currentStreams.at(response.d_idstate.d_streamID);
+  context.d_buffer = std::move(response.d_buffer);
+  sendResponse(response.d_idstate.d_streamID, context, 502, d_ci.cs->dohFrontend->d_customResponseHeaders);
 }
 
-bool IncomingHTTP2Connection::sendResponse(IncomingHTTP2Connection::StreamID streamID, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType, bool addContentType)
+bool IncomingHTTP2Connection::sendResponse(IncomingHTTP2Connection::StreamID streamID, IncomingHTTP2Connection::PendingQuery& context, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType, bool addContentType)
 {
   /* if data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set.
    */
@@ -498,12 +499,13 @@ bool IncomingHTTP2Connection::sendResponse(IncomingHTTP2Connection::StreamID str
 
     if (obj.d_queryPos >= obj.d_buffer.size()) {
       *data_flags |= NGHTTP2_DATA_FLAG_EOF;
+      obj.d_buffer.clear();
     }
     return toCopy;
   };
 
   const auto& df = d_ci.cs->dohFrontend;
-  auto& responseBody = d_currentStreams.at(streamID).d_buffer;
+  auto& responseBody = context.d_buffer;
 
   std::vector<nghttp2_nv> headers;
   std::string responseCodeStr;
@@ -676,19 +678,34 @@ static std::optional<PacketBuffer> getPayloadFromPath(const std::string_view& pa
   }
 
   // need to base64url decode this
-  string sdns(path.substr(pos + 5));
-  boost::replace_all(sdns, "-", "+");
-  boost::replace_all(sdns, "_", "/");
-
-  // re-add padding that may have been missing
-  switch (sdns.size() % 4) {
+  string sdns;
+  const size_t payloadSize = path.size() - pos - 5;
+  size_t neededPadding = 0;
+  switch (payloadSize % 4) {
   case 2:
-    sdns.append(2, '=');
+    neededPadding = 2;
     break;
   case 3:
-    sdns.append(1, '=');
+    neededPadding = 1;
     break;
   }
+  sdns.reserve(payloadSize + neededPadding);
+  sdns = path.substr(pos + 5);
+  for (auto &entry : sdns) {
+    switch (entry) {
+    case '-':
+      entry = '+';
+      break;
+    case '_':
+      entry = '/';
+      break;
+    }
+  }
+
+  if (neededPadding) {
+    // re-add padding that may have been missing
+    sdns.append(neededPadding, '=');
+  }
 
   PacketBuffer decoded;
   /* rough estimate so we hopefully don't need a new allocation later */
@@ -714,7 +731,7 @@ void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::Pendi
       query.d_buffer = std::move(response);
     }
     vinfolog("Sending an immediate %d response to incoming DoH query: %s", code, reason);
-    sendResponse(streamID, code, d_ci.cs->dohFrontend->d_customResponseHeaders);
+    sendResponse(streamID, query, code, d_ci.cs->dohFrontend->d_customResponseHeaders);
   };
 
   ++d_ci.cs->dohFrontend->d_http2Stats.d_nbQueries;
@@ -769,7 +786,7 @@ void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::Pendi
           query.d_buffer.pop_back();
         }
 
-        sendResponse(streamID, entry->getStatusCode(), customHeaders ? *customHeaders : d_ci.cs->dohFrontend->d_customResponseHeaders, std::string(), false);
+        sendResponse(streamID, query, entry->getStatusCode(), customHeaders ? *customHeaders : d_ci.cs->dohFrontend->d_customResponseHeaders, std::string(), false);
         return;
       }
     }
index a8e68777c1efa6cd099fe3941f45df6ea7a129ba..1fa2b83a2ee5d91c200e6c9dcd8654cf4249708f 100644 (file)
@@ -88,7 +88,7 @@ private:
   void updateIO(IOState newState, FDMultiplexer::callbackfunc_t callback);
   void watchForRemoteHostClosingConnection();
   void handleIOError();
-  bool sendResponse(StreamID streamID, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType = "", bool addContentType = true);
+  bool sendResponse(StreamID streamID, PendingQuery& context, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType = "", bool addContentType = true);
   void handleIncomingQuery(PendingQuery&& query, StreamID streamID);
   bool checkALPN();
   void readHTTPData();