From: Remi Gacogne Date: Thu, 7 Dec 2023 10:58:06 +0000 (+0100) Subject: dnsdist: More delinting of the DoH3 code X-Git-Tag: dnsdist-1.9.0-alpha4~15^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F13556%2Fhead;p=thirdparty%2Fpdns.git dnsdist: More delinting of the DoH3 code --- diff --git a/pdns/dnsdist-doh-common.hh b/pdns/dnsdist-doh-common.hh index a5c8e968c0..0dc714df23 100644 --- a/pdns/dnsdist-doh-common.hh +++ b/pdns/dnsdist-doh-common.hh @@ -33,7 +33,8 @@ #include "stat_t.hh" #include "tcpiohandler.hh" -namespace dnsdist::doh { +namespace dnsdist::doh +{ std::optional getPayloadFromPath(const std::string_view& path); } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 3b1fb0735e..4a7f386ab1 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -2875,57 +2875,57 @@ static void startFrontends() { std::vector tcpStates; std::vector udpStates; - for (auto& cs : g_frontends) { - if (cs->dohFrontend != nullptr && cs->dohFrontend->d_library == "h2o") { + for (auto& clientState : g_frontends) { + if (clientState->dohFrontend != nullptr && clientState->dohFrontend->d_library == "h2o") { #ifdef HAVE_DNS_OVER_HTTPS #ifdef HAVE_LIBH2OEVLOOP - std::thread dotThreadHandle(dohThread, cs.get()); - if (!cs->cpus.empty()) { - mapThreadToCPUList(dotThreadHandle.native_handle(), cs->cpus); + std::thread dotThreadHandle(dohThread, clientState.get()); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(dotThreadHandle.native_handle(), clientState->cpus); } dotThreadHandle.detach(); #endif /* HAVE_LIBH2OEVLOOP */ #endif /* HAVE_DNS_OVER_HTTPS */ continue; } - if (cs->doqFrontend != nullptr) { + if (clientState->doqFrontend != nullptr) { #ifdef HAVE_DNS_OVER_QUIC - std::thread doqThreadHandle(doqThread, cs.get()); - if (!cs->cpus.empty()) { - mapThreadToCPUList(doqThreadHandle.native_handle(), cs->cpus); + std::thread doqThreadHandle(doqThread, clientState.get()); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(doqThreadHandle.native_handle(), clientState->cpus); } doqThreadHandle.detach(); #endif /* HAVE_DNS_OVER_QUIC */ continue; } - if (cs->doh3Frontend != nullptr) { + if (clientState->doh3Frontend != nullptr) { #ifdef HAVE_DNS_OVER_HTTP3 - std::thread doh3ThreadHandle(doh3Thread, cs.get()); - if (!cs->cpus.empty()) { - mapThreadToCPUList(doh3ThreadHandle.native_handle(), cs->cpus); + std::thread doh3ThreadHandle(doh3Thread, clientState.get()); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(doh3ThreadHandle.native_handle(), clientState->cpus); } doh3ThreadHandle.detach(); #endif /* HAVE_DNS_OVER_HTTP3 */ continue; } - if (cs->udpFD >= 0) { + if (clientState->udpFD >= 0) { #ifdef USE_SINGLE_ACCEPTOR_THREAD - udpStates.push_back(cs.get()); + udpStates.push_back(clientState.get()); #else /* USE_SINGLE_ACCEPTOR_THREAD */ - std::thread udpClientThreadHandle(udpClientThread, std::vector{ cs.get() }); - if (!cs->cpus.empty()) { - mapThreadToCPUList(udpClientThreadHandle.native_handle(), cs->cpus); + std::thread udpClientThreadHandle(udpClientThread, std::vector{ clientState.get() }); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(udpClientThreadHandle.native_handle(), clientState->cpus); } udpClientThreadHandle.detach(); #endif /* USE_SINGLE_ACCEPTOR_THREAD */ } - else if (cs->tcpFD >= 0) { + else if (clientState->tcpFD >= 0) { #ifdef USE_SINGLE_ACCEPTOR_THREAD - tcpStates.push_back(cs.get()); + tcpStates.push_back(clientState.get()); #else /* USE_SINGLE_ACCEPTOR_THREAD */ - std::thread tcpAcceptorThreadHandle(tcpAcceptorThread, std::vector{cs.get() }); - if (!cs->cpus.empty()) { - mapThreadToCPUList(tcpAcceptorThreadHandle.native_handle(), cs->cpus); + std::thread tcpAcceptorThreadHandle(tcpAcceptorThread, std::vector{clientState.get() }); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(tcpAcceptorThreadHandle.native_handle(), clientState->cpus); } tcpAcceptorThreadHandle.detach(); #endif /* USE_SINGLE_ACCEPTOR_THREAD */ diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc index ef1c2780d4..71cd87cd0f 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.cc +++ b/pdns/dnsdistdist/dnsdist-doh-common.cc @@ -129,7 +129,8 @@ void DOHFrontend::setup() #endif /* HAVE_DNS_OVER_HTTPS */ -namespace dnsdist::doh { +namespace dnsdist::doh +{ std::optional getPayloadFromPath(const std::string_view& path) { std::optional result{std::nullopt}; diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index bcaf454438..a1e6394790 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -267,24 +267,26 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const { std::string status = std::to_string(statusCode); std::string lenStr = std::to_string(len); - quiche_h3_header headers[] = { - { + std::array headers{ + (quiche_h3_header){ + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API .name = reinterpret_cast(":status"), .name_len = sizeof(":status") - 1, - + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API .value = reinterpret_cast(status.data()), .value_len = status.size(), }, - { + (quiche_h3_header){ + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API .name = reinterpret_cast("content-length"), .name_len = sizeof("content-length") - 1, - + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API .value = reinterpret_cast(lenStr.data()), .value_len = lenStr.size(), }, }; quiche_h3_send_response(conn, quic_conn, - streamID, headers, 2, len == 0); + streamID, headers.data(), headers.size(), len == 0); if (len == 0) { return; @@ -294,6 +296,7 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const while (pos < len) { // send_body takes care of setting fin to false if it cannot send the entire content so we can try again. auto res = quiche_h3_send_body(conn, quic_conn, + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API streamID, const_cast(body) + pos, len - pos, true); if (res < 0) { // Shutdown with internal error code @@ -306,10 +309,13 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const std::string& content) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API h3_send_response(quic_conn, conn, streamID, statusCode, reinterpret_cast(content.data()), content.size()); } + static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, statusCode, body, len); } @@ -610,6 +616,107 @@ static void flushResponses(pdns::channel::Receiver& receiver) } } +static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map& headers, int64_t streamID, quiche_h3_event* event) +{ + auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) { + DEBUGLOG(msg); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg); + }; + + // Callback result. Any value other than 0 will interrupt further header processing. + int cbresult = quiche_h3_event_for_each_header( + event, + [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API + std::string_view key(reinterpret_cast(name), name_len); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API + std::string_view content(reinterpret_cast(value), value_len); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API + auto* headersptr = reinterpret_cast*>(argp); + headersptr->emplace(key, content); + return 0; + }, + &headers); + + if (cbresult != 0 || headers.count(":method") == 0) { + handleImmediateError("Unable to process query headers"); + return; + } + + if (headers.at(":method") == "GET") { + if (headers.count(":path") == 0 || headers.at(":path").empty()) { + handleImmediateError("Path not found"); + return; + } + const auto& path = headers.at(":path"); + auto payload = dnsdist::doh::getPayloadFromPath(path); + if (!payload) { + handleImmediateError("Unable to find the DNS parameter"); + return; + } + if (payload->size() < sizeof(dnsheader)) { + handleImmediateError("DoH3 non-compliant query"); + return; + } + DEBUGLOG("Dispatching GET query"); + doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID); + conn.d_streamBuffers.erase(streamID); + return; + } + + if (headers.at(":method") == "POST") { + if (!quiche_h3_event_headers_has_body(event)) { + handleImmediateError("Empty POST query"); + } + return; + } + + handleImmediateError("Unsupported HTTP method"); +} + +static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map& headers, int64_t streamID, quiche_h3_event* event) +{ + auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) { + DEBUGLOG(msg); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg); + }; + + if (headers.at(":method") == "POST") { + if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") { + handleImmediateError("Unsupported content-type"); + return; + } + PacketBuffer buffer(std::numeric_limits::max()); + PacketBuffer decoded; + + while (true) { + ssize_t len = quiche_h3_recv_body(conn.d_http3.get(), + conn.d_conn.get(), streamID, + buffer.data(), buffer.capacity()); + + if (len <= 0) { + break; + } + decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len); + } + + if (decoded.size() < sizeof(dnsheader)) { + handleImmediateError("DoH3 non-compliant query"); + return; + } + + DEBUGLOG("Dispatching POST query"); + doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID); + conn.d_streamBuffers.erase(streamID); + } +} + static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID) { std::map headers; @@ -626,114 +733,12 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3 switch (quiche_h3_event_type(event)) { case QUICHE_H3_EVENT_HEADERS: { - // Callback result. Any value other than 0 will interrupt further header processing. - int cbresult = quiche_h3_event_for_each_header( - event, - [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API - std::string_view key(reinterpret_cast(name), name_len); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API - std::string_view content(reinterpret_cast(value), value_len); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API - auto* headersptr = reinterpret_cast*>(argp); - headersptr->emplace(key, content); - return 0; - }, - &headers); - - if (cbresult != 0 || headers.count(":method") == 0) { - DEBUGLOG("Failed to process headers"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState.nonCompliantQueries; - ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unable to process query headers"); - break; - } - - if (headers.at(":method") == "GET") { - if (headers.count(":path") == 0 || headers.at(":path").empty()) { - DEBUGLOG("Path not found"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState.nonCompliantQueries; - ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Path not found"); - break; - } - const auto& path = headers.at(":path"); - auto payload = dnsdist::doh::getPayloadFromPath(path); - if (!payload) { - DEBUGLOG("User error, unable to find the DNS parameter"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState.nonCompliantQueries; - ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unable to find the DNS parameter"); - break; - } - if (payload->size() < sizeof(dnsheader)) { - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState.nonCompliantQueries; - ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "DoH3 non-compliant query"); - break; - } - DEBUGLOG("Dispatching GET query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID); - conn.d_streamBuffers.erase(streamID); - } - else if (headers.at(":method") == "POST") { - if (!quiche_h3_event_headers_has_body(event)) { - DEBUGLOG("Empty POST query"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState.nonCompliantQueries; - ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Empty POST query"); - break; - } - } - else { - DEBUGLOG("Unsupported HTTP method"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState.nonCompliantQueries; - ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unsupported HTTP method"); - break; - } + processH3HeaderEvent(clientState, frontend, conn, client, serverConnID, headers, streamID, event); break; } case QUICHE_H3_EVENT_DATA: { - if (headers.at(":method") == "POST") { - if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") { - DEBUGLOG("Unsupported content-type"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState.nonCompliantQueries; - ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unsupported content-type"); - break; - } - PacketBuffer buffer(std::numeric_limits::max()); - PacketBuffer decoded; - - while (true) { - ssize_t len = quiche_h3_recv_body(conn.d_http3.get(), - conn.d_conn.get(), streamID, - buffer.data(), buffer.capacity()); - - if (len <= 0) { - break; - } - decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len); - } - if (decoded.size() < sizeof(dnsheader)) { - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState.nonCompliantQueries; - ++frontend.d_errorResponses; - h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "DoH3 non-compliant query"); - break; - } - DEBUGLOG("Dispatching POST query"); - doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID); - conn.d_streamBuffers.erase(streamID); - } + processH3DataEvent(clientState, frontend, conn, client, serverConnID, headers, streamID, event); + break; } case QUICHE_H3_EVENT_FINISHED: case QUICHE_H3_EVENT_RESET: @@ -746,7 +751,6 @@ static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3 } } - // this is the entrypoint from dnsdist.cc void doh3Thread(ClientState* clientState) {