]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Enforce concurrent streams count for pending queries
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 19 Jun 2025 09:18:27 +0000 (11:18 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 28 Aug 2025 08:36:42 +0000 (10:36 +0200)
The gist of the `MadeYouRest` attack is that streams can be reset
by the client, and thus no longer count towards the maximum number
of a streams as far as the library is concerned, while the server
is still processing the query and doing actual work.
This pull request introduces a counter for "killed but still being
processed streams" to prevent it.

Signed-off-by: Remi Gacogne <remi.gacogne@powerdns.com>
(cherry picked from commit 0214032c5f09fcfb440e5c5120f1491cb4f0fda4)

pdns/dnsdistdist/dnsdist-nghttp2-in.cc
pdns/dnsdistdist/dnsdist-nghttp2-in.hh

index 6c36f6bf7405eb0a18c985079cf83c10627ba9e5..2c96e067ebfeda8ddb3e547b6e151a2dbdeea4b7 100644 (file)
@@ -93,6 +93,8 @@ private:
 };
 #endif
 
+static constexpr uint32_t MAX_CONCURRENT_STREAMS{100U};
+
 class IncomingDoHCrossProtocolContext : public DOHUnitInterface
 {
 public:
@@ -288,7 +290,7 @@ bool IncomingHTTP2Connection::checkALPN()
 
 void IncomingHTTP2Connection::handleConnectionReady()
 {
-  constexpr std::array<nghttp2_settings_entry, 1> settings{{{NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100U}}};
+  constexpr std::array<nghttp2_settings_entry, 1> settings{{{NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, MAX_CONCURRENT_STREAMS}}};
   auto ret = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, settings.data(), settings.size());
   if (ret != 0) {
     throw std::runtime_error("Fatal error: " + std::string(nghttp2_strerror(ret)));
@@ -547,12 +549,22 @@ void NGHTTP2Headers::addDynamicHeader(std::vector<nghttp2_nv>& headers, NGHTTP2H
   NGHTTP2Headers::addCustomDynamicHeader(headers, name, value);
 }
 
+std::unordered_map<IncomingHTTP2Connection::StreamID, IncomingHTTP2Connection::PendingQuery>::iterator IncomingHTTP2Connection::getStreamContext(StreamID streamID)
+{
+  auto streamIt = d_currentStreams.find(streamID);
+  if (streamIt == d_currentStreams.end()) {
+    /* it might have been closed by the remote end in the meantime */
+    d_killedStreams.erase(streamID);
+  }
+  return streamIt;
+}
+
 IOState IncomingHTTP2Connection::sendResponse(const struct timeval& now, TCPResponse&& response)
 {
   if (response.d_idstate.d_streamID == -1) {
     throw std::runtime_error("Invalid DoH stream ID while sending response");
   }
-  auto streamIt = d_currentStreams.find(response.d_idstate.d_streamID);
+  auto streamIt = getStreamContext(response.d_idstate.d_streamID);
   if (streamIt == d_currentStreams.end()) {
     /* it might have been closed by the remote end in the meantime */
     return hasPendingWrite() ? IOState::NeedWrite : IOState::Done;
@@ -592,7 +604,7 @@ void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPRespon
     throw std::runtime_error("Invalid DoH stream ID while handling I/O error notification");
   }
 
-  auto streamIt = d_currentStreams.find(response.d_idstate.d_streamID);
+  auto streamIt = getStreamContext(response.d_idstate.d_streamID);
   if (streamIt == d_currentStreams.end()) {
     /* it might have been closed by the remote end in the meantime */
     return;
@@ -921,7 +933,7 @@ int IncomingHTTP2Connection::on_frame_recv_callback(nghttp2_session* session, co
   /* is this the last frame for this stream? */
   if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) != 0) {
     auto streamID = frame->hd.stream_id;
-    auto stream = conn->d_currentStreams.find(streamID);
+    auto stream = conn->getStreamContext(streamID);
     if (stream != conn->d_currentStreams.end()) {
       conn->handleIncomingQuery(std::move(stream->second), streamID);
     }
@@ -941,7 +953,9 @@ int IncomingHTTP2Connection::on_stream_close_callback(nghttp2_session* session,
 {
   auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
 
-  conn->d_currentStreams.erase(stream_id);
+  if (conn->d_currentStreams.erase(stream_id) > 0) {
+    conn->d_killedStreams.emplace(stream_id);
+  }
   return 0;
 }
 
@@ -952,20 +966,29 @@ int IncomingHTTP2Connection::on_begin_headers_callback(nghttp2_session* session,
   }
 
   auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
-  auto insertPair = conn->d_currentStreams.emplace(frame->hd.stream_id, PendingQuery());
-  if (!insertPair.second) {
-    /* there is a stream ID collision, something is very wrong! */
-    vinfolog("Stream ID collision (%d) on connection from %d", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort());
-    conn->d_connectionClosing = true;
-    conn->d_needFlush = true;
-    nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
-    auto ret = nghttp2_session_send(conn->d_session.get());
+  auto close_connection = [](IncomingHTTP2Connection* connection, int32_t streamID, const ComboAddress& remote) -> int {
+    connection->d_connectionClosing = true;
+    connection->d_needFlush = true;
+    nghttp2_session_terminate_session(connection->d_session.get(), NGHTTP2_REFUSED_STREAM);
+    auto ret = nghttp2_session_send(connection->d_session.get());
     if (ret != 0) {
-      vinfolog("Error flushing HTTP response for stream %d from %s: %s", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret));
+      vinfolog("Error flushing HTTP response for stream %d from %s: %s", streamID, remote.toStringWithPort(), nghttp2_strerror(ret));
       return NGHTTP2_ERR_CALLBACK_FAILURE;
     }
 
     return 0;
+  };
+
+  if (conn->getConcurrentStreamsCount() >= MAX_CONCURRENT_STREAMS) {
+    vinfolog("Too many concurrent streams on connection from %d", conn->d_ci.remote.toStringWithPort());
+    return close_connection(conn, frame->hd.stream_id, conn->d_ci.remote);
+  }
+
+  auto insertPair = conn->d_currentStreams.emplace(frame->hd.stream_id, PendingQuery());
+  if (!insertPair.second) {
+    /* there is a stream ID collision, something is very wrong! */
+    vinfolog("Stream ID collision (%d) on connection from %d", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort());
+    return close_connection(conn, frame->hd.stream_id, conn->d_ci.remote);
   }
 
   return 0;
@@ -1002,7 +1025,7 @@ int IncomingHTTP2Connection::on_header_callback(nghttp2_session* session, const
       return nameLen == expected.size() && memcmp(name, expected.data(), expected.size()) == 0;
     };
 
-    auto stream = conn->d_currentStreams.find(frame->hd.stream_id);
+    auto stream = conn->getStreamContext(frame->hd.stream_id);
     if (stream == conn->d_currentStreams.end()) {
       vinfolog("Unable to match the stream ID %d to a known one!", frame->hd.stream_id);
       return NGHTTP2_ERR_CALLBACK_FAILURE;
@@ -1065,7 +1088,7 @@ int IncomingHTTP2Connection::on_header_callback(nghttp2_session* session, const
 int IncomingHTTP2Connection::on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, IncomingHTTP2Connection::StreamID stream_id, const uint8_t* data, size_t len, void* user_data)
 {
   auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
-  auto stream = conn->d_currentStreams.find(stream_id);
+  auto stream = conn->getStreamContext(stream_id);
   if (stream == conn->d_currentStreams.end()) {
     vinfolog("Unable to match the stream ID %d to a known one!", stream_id);
     return NGHTTP2_ERR_CALLBACK_FAILURE;
@@ -1155,7 +1178,7 @@ void IncomingHTTP2Connection::stopIO()
 
 uint32_t IncomingHTTP2Connection::getConcurrentStreamsCount() const
 {
-  return d_currentStreams.size();
+  return d_currentStreams.size() + d_killedStreams.size();
 }
 
 boost::optional<struct timeval> IncomingHTTP2Connection::getIdleClientReadTTD(struct timeval now) const
@@ -1217,6 +1240,7 @@ void IncomingHTTP2Connection::handleIOError()
   d_outPos = 0;
   nghttp2_session_terminate_session(d_session.get(), NGHTTP2_PROTOCOL_ERROR);
   d_currentStreams.clear();
+  d_killedStreams.clear();
   stopIO();
 }
 
index e63077882c39288d856aea4c02c65025e1f45aeb..16da4deb5420b788e7262011bbfda4b302f10b6b 100644 (file)
@@ -86,6 +86,7 @@ private:
   std::unique_ptr<DOHUnitInterface> getDOHUnit(uint32_t streamID) override;
 
   void stopIO();
+  std::unordered_map<StreamID, PendingQuery>::iterator getStreamContext(StreamID streamID);
   uint32_t getConcurrentStreamsCount() const;
   void updateIO(IOState newState, const FDMultiplexer::callbackfunc_t& callback);
   void handleIOError();
@@ -101,6 +102,7 @@ private:
 
   std::unique_ptr<nghttp2_session, decltype(&nghttp2_session_del)> d_session{nullptr, nghttp2_session_del};
   std::unordered_map<StreamID, PendingQuery> d_currentStreams;
+  std::unordered_set<StreamID> d_killedStreams;
   PacketBuffer d_out;
   PacketBuffer d_in;
   size_t d_outPos{0};