From: Remi Gacogne Date: Mon, 25 Sep 2023 13:37:39 +0000 (+0200) Subject: dnsdist: Better handling of short reads/writes in DoQ X-Git-Tag: rec-5.0.0-alpha2~6^2~27 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6c66428a6e2f517f8a769d8acb0bf6da69579df7;p=thirdparty%2Fpdns.git dnsdist: Better handling of short reads/writes in DoQ --- diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 2a16a052a2..cdb12919c0 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -64,6 +64,7 @@ public: ComboAddress d_peer; QuicheConnection d_conn; + std::unordered_map d_streamBuffers; }; static void sendBackDOQUnit(DOQUnitUniquePtr&& du, const char* description); @@ -261,19 +262,46 @@ private: std::shared_ptr DOQCrossProtocolQuery::s_sender = std::make_shared(); +/* from rfc9250 section-4.3 */ +enum class DOQ_Error_Codes : uint64_t { + DOQ_NO_ERROR = 0, + DOQ_INTERNAL_ERROR = 1, + DOQ_PROTOCOL_ERROR = 2, + DOQ_REQUEST_CANCELLED = 3, + DOQ_EXCESSIVE_LOAD = 4, + DOQ_UNSPECIFIED_ERROR = 5 +}; + static void handleResponse(DOQFrontend& df, Connection& conn, const uint64_t streamID, const PacketBuffer& response) { if (response.size() == 0) { - quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, 0x5); + quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR)); + return; } - else { - uint16_t responseSize = static_cast(response.size()); - const uint8_t sizeBytes[] = {static_cast(responseSize / 256), static_cast(responseSize % 256)}; - auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, sizeBytes, sizeof(sizeBytes), false); - if (res == sizeof(sizeBytes)) { - res = quiche_conn_stream_send(conn.d_conn.get(), streamID, response.data(), response.size(), true); + + uint16_t responseSize = static_cast(response.size()); + const std::array sizeBytes = {static_cast(responseSize / 256), static_cast(responseSize % 256)}; + size_t pos = 0; + do { + auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, sizeBytes.data() + pos, sizeBytes.size() - pos, false); + if (res < 0) { + quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_INTERNAL_ERROR)); + return; } + pos += res; } + while (pos < sizeBytes.size()); + + pos = 0; + do { + auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, response.data() + pos, response.size() - pos, true); + if (res < 0) { + quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_INTERNAL_ERROR)); + return; + } + pos += res; + } + while (pos < response.size()); } static void fillRandom(PacketBuffer& buffer, size_t size) @@ -755,7 +783,7 @@ void doqThread(ClientState* cs) Socket sock(cs->udpFD); - PacketBuffer buffer(std::numeric_limits::max()); + PacketBuffer buffer(std::numeric_limits::max()); auto mplexer = std::unique_ptr(FDMultiplexer::getMultiplexerSilent()); auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor(); @@ -840,29 +868,33 @@ void doqThread(ClientState* cs) uint64_t streamID = 0; while (quiche_stream_iter_next(readable.get(), &streamID)) { + auto& streamBuffer = conn->get().d_streamBuffers[streamID]; + auto existingLength = streamBuffer.size(); bool fin = false; - buffer.resize(std::numeric_limits::max()); + streamBuffer.resize(existingLength + 512); auto received = quiche_conn_stream_recv(conn->get().d_conn.get(), streamID, - buffer.data(), buffer.size(), + &streamBuffer.at(existingLength), 512, &fin); - if (received < 2) { - break; - } - buffer.resize(received); - + streamBuffer.resize(existingLength + received); if (fin) { - // we skip message length, should we verify ? - buffer.erase(buffer.begin(), buffer.begin() + 2); - if (buffer.size() >= sizeof(dnsheader)) { - doq_dispatch_query(*(frontend->d_server_config), std::move(buffer), cs->local, client, serverConnID, streamID); + if (streamBuffer.size() < (sizeof(dnsheader) + sizeof(uint16_t))) { + quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR)); + break; + } + uint16_t payloadLength = streamBuffer.at(0) * 256 + streamBuffer.at(1); + streamBuffer.erase(streamBuffer.begin(), streamBuffer.begin() + 2); + if (payloadLength != streamBuffer.size()) { + quiche_conn_stream_shutdown(conn->get().d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast(DOQ_Error_Codes::DOQ_PROTOCOL_ERROR)); + break; } + doq_dispatch_query(*(frontend->d_server_config), std::move(streamBuffer), cs->local, client, serverConnID, streamID); + conn->get().d_streamBuffers.erase(streamID); } } } else { DEBUGLOG("Connection not established"); } - // } } if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {