From: Remi Gacogne Date: Fri, 15 Dec 2023 15:56:23 +0000 (+0100) Subject: dnsdist: Handle congested DoQ streams X-Git-Tag: auth-4.9.0-alpha1~36^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F13638%2Fhead;p=thirdparty%2Fpdns.git dnsdist: Handle congested DoQ streams If the stream has no capacity left, Quiche will refuse to queue more data and return `QUICHE_ERR_DONE`. We then have to wait until the stream becomes writable again to retry sending our response. --- diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index a1e6394790..e3c14ccee8 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -67,6 +67,7 @@ public: QuicheConnection d_conn; QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free}; std::unordered_map d_streamBuffers; + std::unordered_map d_streamOutBuffers; }; static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description); @@ -263,7 +264,29 @@ private: std::shared_ptr DOH3CrossProtocolQuery::s_sender = std::make_shared(); -static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len) +static bool tryWriteResponse(H3Connection& conn, const uint64_t streamID, PacketBuffer& response) +{ + size_t pos = 0; + while (pos < response.size()) { + // send_body takes care of setting fin to false if it cannot send the entire content so we can try again. + auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(), + streamID, &response.at(pos), response.size() - pos, true); + if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) { + response.erase(response.begin(), response.begin() + static_cast(pos)); + return false; + } + if (res < 0) { + // Shutdown with internal error code + quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR)); + return true; + } + pos += res; + } + + return true; +} + +static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len) { std::string status = std::to_string(statusCode); std::string lenStr = std::to_string(len); @@ -285,8 +308,13 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const .value_len = lenStr.size(), }, }; - quiche_h3_send_response(conn, quic_conn, - streamID, headers.data(), headers.size(), len == 0); + auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(), + streamID, headers.data(), headers.size(), len == 0); + if (returnValue != 0) { + /* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */ + quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR)); + return; + } if (len == 0) { return; @@ -295,28 +323,27 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const size_t pos = 0; while (pos < len) { // send_body takes care of setting fin to false if it cannot send the entire content so we can try again. - auto res = quiche_h3_send_body(conn, quic_conn, + auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(), // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API streamID, const_cast(body) + pos, len - pos, true); + if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API + conn.d_streamOutBuffers[streamID] = PacketBuffer(body + pos, body + len); + return; + } if (res < 0) { // Shutdown with internal error code - quiche_conn_stream_shutdown(quic_conn, streamID, QUICHE_SHUTDOWN_WRITE, static_cast(1)); + quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(1)); return; } pos += res; } } -static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const std::string& content) +static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API - h3_send_response(quic_conn, conn, streamID, statusCode, reinterpret_cast(content.data()), content.size()); -} - -static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len) -{ - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, statusCode, body, len); + h3_send_response(conn, streamID, statusCode, reinterpret_cast(content.data()), content.size()); } static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response) @@ -616,6 +643,21 @@ static void flushResponses(pdns::channel::Receiver& receiver) } } +static void flushStalledResponses(H3Connection& conn) +{ + for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) { + const auto streamID = streamIt->first; + auto& response = streamIt->second; + if (quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size()) == 1) { + if (tryWriteResponse(conn, streamID, response)) { + streamIt = conn.d_streamOutBuffers.erase(streamIt); + continue; + } + } + ++streamIt; + } +} + static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map& headers, int64_t streamID, quiche_h3_event* event) { auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) { @@ -623,7 +665,7 @@ static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& fronten ++dnsdist::metrics::g_stats.nonCompliantQueries; ++clientState.nonCompliantQueries; ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg); + h3_send_response(conn, streamID, 400, msg); }; // Callback result. Any value other than 0 will interrupt further header processing. @@ -684,37 +726,49 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, ++dnsdist::metrics::g_stats.nonCompliantQueries; ++clientState.nonCompliantQueries; ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg); + h3_send_response(conn, streamID, 400, msg); }; - if (headers.at(":method") == "POST") { - if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") { - handleImmediateError("Unsupported content-type"); - return; - } - PacketBuffer buffer(std::numeric_limits::max()); - PacketBuffer decoded; + if (headers.at(":method") != "POST") { + handleImmediateError("DATA frame for non-POST method"); + return; + } - while (true) { - ssize_t len = quiche_h3_recv_body(conn.d_http3.get(), - conn.d_conn.get(), streamID, - buffer.data(), buffer.capacity()); + if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") { + handleImmediateError("Unsupported content-type"); + return; + } - if (len <= 0) { - break; - } - decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len); - } + PacketBuffer buffer(std::numeric_limits::max()); + auto& streamBuffer = conn.d_streamBuffers[streamID]; - if (decoded.size() < sizeof(dnsheader)) { - handleImmediateError("DoH3 non-compliant query"); - return; + while (true) { + buffer.resize(std::numeric_limits::max()); + ssize_t len = quiche_h3_recv_body(conn.d_http3.get(), + conn.d_conn.get(), streamID, + buffer.data(), buffer.capacity()); + + if (len <= 0) { + break; } - DEBUGLOG("Dispatching POST query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID); + buffer.resize(static_cast(len)); + streamBuffer.insert(streamBuffer.end(), buffer.begin(), buffer.end()); + } + + if (!quiche_conn_stream_finished(conn.d_conn.get(), streamID)) { + return; + } + + if (streamBuffer.size() < sizeof(dnsheader)) { conn.d_streamBuffers.erase(streamID); + handleImmediateError("DoH3 non-compliant query"); + return; } + + DEBUGLOG("Dispatching POST query"); + doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID); + conn.d_streamBuffers.erase(streamID); } static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID) @@ -892,6 +946,7 @@ void doh3Thread(ClientState* clientState) conn = frontend->d_server_config->d_connections.erase(conn); } else { + flushStalledResponses(conn->second); ++conn; } } diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 03ba0f2571..9b626aaf88 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -65,6 +65,7 @@ public: ComboAddress d_peer; QuicheConnection d_conn; std::unordered_map d_streamBuffers; + std::unordered_map d_streamOutBuffers; }; static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description); @@ -260,7 +261,26 @@ private: std::shared_ptr DOQCrossProtocolQuery::s_sender = std::make_shared(); -static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, const PacketBuffer& response) +static bool tryWriteResponse(Connection& conn, const uint64_t streamID, PacketBuffer& response) +{ + size_t pos = 0; + while (pos < response.size()) { + auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(pos), response.size() - pos, true); + if (res == QUICHE_ERR_DONE) { + response.erase(response.begin(), response.begin() + static_cast(pos)); + return false; + } + if (res < 0) { + quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_INTERNAL_ERROR)); + return true; + } + pos += res; + } + + return true; +} + +static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, PacketBuffer& response) { if (response.empty()) { ++frontend.d_errorResponses; @@ -270,25 +290,9 @@ static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64 ++frontend.d_validResponses; auto responseSize = static_cast(response.size()); const std::array sizeBytes = {static_cast(responseSize / 256), static_cast(responseSize % 256)}; - size_t pos = 0; - while (pos < sizeBytes.size()) { - auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &sizeBytes.at(pos), sizeBytes.size() - pos, false); - if (res < 0) { - quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_INTERNAL_ERROR)); - return; - } - pos += res; - } - - pos = 0; - while (pos < response.size()) { - // stream_send sets fin to false itself when the capacity of the stream is less than the desired writing length - auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(pos), response.size() - pos, true); - if (res < 0) { - quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_INTERNAL_ERROR)); - return; - } - pos += res; + response.insert(response.begin(), sizeBytes.begin(), sizeBytes.end()); + if (!tryWriteResponse(conn, streamID, response)) { + conn.d_streamOutBuffers[streamID] = std::move(response); } } @@ -560,6 +564,21 @@ static void flushResponses(pdns::channel::Receiver& receiver) } } +static void flushStalledResponses(Connection& conn) +{ + for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) { + const auto& streamID = streamIt->first; + auto& response = streamIt->second; + if (quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size()) == 1) { + if (tryWriteResponse(conn, streamID, response)) { + streamIt = conn.d_streamOutBuffers.erase(streamIt); + continue; + } + } + ++streamIt; + } +} + // this is the entrypoint from dnsdist.cc void doqThread(ClientState* clientState) { @@ -721,6 +740,7 @@ void doqThread(ClientState* clientState) conn = frontend->d_server_config->d_connections.erase(conn); } else { + flushStalledResponses(conn->second); ++conn; } }