From: Remi Gacogne Date: Tue, 5 Dec 2023 16:03:30 +0000 (+0100) Subject: dnsdist: Split the DoH3 event handling loop off the main one X-Git-Tag: dnsdist-1.9.0-alpha4~15^2~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=559942b3400bff5e16567bb42d190b29714b6b98;p=thirdparty%2Fpdns.git dnsdist: Split the DoH3 event handling loop off the main one --- diff --git a/pdns/dnsdist-doh-common.hh b/pdns/dnsdist-doh-common.hh index 6e0cc86e03..a5c8e968c0 100644 --- a/pdns/dnsdist-doh-common.hh +++ b/pdns/dnsdist-doh-common.hh @@ -21,8 +21,10 @@ */ #pragma once +#include #include #include +#include #include "config.h" #include "iputils.hh" @@ -31,6 +33,10 @@ #include "stat_t.hh" #include "tcpiohandler.hh" +namespace dnsdist::doh { +std::optional getPayloadFromPath(const std::string_view& path); +} + struct DOHServerConfig; class DOHResponseMapEntry diff --git a/pdns/dnsdistdist/dnsdist-doh-common.cc b/pdns/dnsdistdist/dnsdist-doh-common.cc index 7d994365a2..ef1c2780d4 100644 --- a/pdns/dnsdistdist/dnsdist-doh-common.cc +++ b/pdns/dnsdistdist/dnsdist-doh-common.cc @@ -19,6 +19,7 @@ * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ +#include "base64.hh" #include "dnsdist-doh-common.hh" #include "dnsdist-rules.hh" @@ -127,3 +128,65 @@ void DOHFrontend::setup() } #endif /* HAVE_DNS_OVER_HTTPS */ + +namespace dnsdist::doh { +std::optional getPayloadFromPath(const std::string_view& path) +{ + std::optional result{std::nullopt}; + + if (path.size() <= 5) { + return result; + } + + auto pos = path.find("?dns="); + if (pos == string::npos) { + pos = path.find("&dns="); + } + + if (pos == string::npos) { + return result; + } + + // need to base64url decode this + string sdns; + const size_t payloadSize = path.size() - pos - 5; + size_t neededPadding = 0; + switch (payloadSize % 4) { + case 2: + neededPadding = 2; + break; + case 3: + neededPadding = 1; + break; + } + sdns.reserve(payloadSize + neededPadding); + sdns = path.substr(pos + 5); + for (auto& entry : sdns) { + switch (entry) { + case '-': + entry = '+'; + break; + case '_': + entry = '/'; + break; + } + } + + if (neededPadding != 0) { + // re-add padding that may have been missing + sdns.append(neededPadding, '='); + } + + PacketBuffer decoded; + /* rough estimate so we hopefully don't need a new allocation later */ + /* We reserve at few additional bytes to be able to add EDNS later */ + const size_t estimate = ((sdns.size() * 3) / 4); + decoded.reserve(estimate); + if (B64Decode(sdns, decoded) < 0) { + return result; + } + + result = std::move(decoded); + return result; +} +} diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc index cacf367374..80e55aea05 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc @@ -20,8 +20,8 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ -#include "base64.hh" #include "dnsdist-dnsparser.hh" +#include "dnsdist-doh-common.hh" #include "dnsdist-nghttp2-in.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsparser.hh" @@ -740,66 +740,6 @@ static void processForwardedForHeader(const std::unique_ptr& headers } } -static std::optional getPayloadFromPath(const std::string_view& path) -{ - std::optional result{std::nullopt}; - - if (path.size() <= 5) { - return result; - } - - auto pos = path.find("?dns="); - if (pos == string::npos) { - pos = path.find("&dns="); - } - - if (pos == string::npos) { - return result; - } - - // need to base64url decode this - string sdns; - const size_t payloadSize = path.size() - pos - 5; - size_t neededPadding = 0; - switch (payloadSize % 4) { - case 2: - neededPadding = 2; - break; - case 3: - neededPadding = 1; - break; - } - sdns.reserve(payloadSize + neededPadding); - sdns = path.substr(pos + 5); - for (auto& entry : sdns) { - switch (entry) { - case '-': - entry = '+'; - break; - case '_': - entry = '/'; - break; - } - } - - if (neededPadding != 0) { - // re-add padding that may have been missing - sdns.append(neededPadding, '='); - } - - PacketBuffer decoded; - /* rough estimate so we hopefully don't need a new allocation later */ - /* We reserve at few additional bytes to be able to add EDNS later */ - const size_t estimate = ((sdns.size() * 3) / 4); - decoded.reserve(estimate); - if (B64Decode(sdns, decoded) < 0) { - return result; - } - - result = std::move(decoded); - return result; -} - void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::PendingQuery&& query, IncomingHTTP2Connection::StreamID streamID) { const auto handleImmediateResponse = [this, &query, streamID](uint16_t code, const std::string& reason, PacketBuffer&& response = PacketBuffer()) { @@ -878,7 +818,7 @@ void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::Pendi } if (query.d_buffer.empty() && query.d_method == PendingQuery::Method::Get && !query.d_queryString.empty()) { - auto payload = getPayloadFromPath(query.d_queryString); + auto payload = dnsdist::doh::getPayloadFromPath(query.d_queryString); if (payload) { query.d_buffer = std::move(*payload); } diff --git a/pdns/dnsdistdist/doh3.cc b/pdns/dnsdistdist/doh3.cc index 921db830e9..bcaf454438 100644 --- a/pdns/dnsdistdist/doh3.cc +++ b/pdns/dnsdistdist/doh3.cc @@ -25,7 +25,6 @@ #ifdef HAVE_DNS_OVER_HTTP3 #include -#include "dnsparser.hh" #include "dolog.hh" #include "iputils.hh" #include "misc.hh" @@ -34,8 +33,8 @@ #include "threadname.hh" #include "base64.hh" -#include "dnsdist-ecs.hh" #include "dnsdist-dnsparser.hh" +#include "dnsdist-ecs.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-tcp.hh" #include "dnsdist-random.hh" @@ -417,9 +416,6 @@ std::unique_ptr getDOH3CrossProtocolQueryFromDQ(DNSQuestion& return std::make_unique(std::move(unit), isResponse); } -/* - We are not in the main DoH3 thread but in the DoH3 'client' thread. -*/ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit) { const auto handleImmediateResponse = [](DOH3UnitUniquePtr&& unit, [[maybe_unused]] const char* reason) { @@ -575,12 +571,6 @@ static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit) static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID) { try { - /* we only parse it there as a sanity check, we will parse it again later */ - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - DNSPacketMangler mangler(reinterpret_cast(query.data()), query.size()); - mangler.skipDomainName(); - mangler.skipBytes(4); - auto unit = std::make_unique(std::move(query)); unit->dsc = &dsc; unit->ids.origDest = local; @@ -592,7 +582,7 @@ static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, con processDOH3Query(std::move(unit)); } catch (const std::exception& exp) { - vinfolog("Had error parsing DoH3 DNS packet from %s: %s", remote.toStringWithPort(), exp.what()); + vinfolog("Had error handling DoH3 DNS packet from %s: %s", remote.toStringWithPort(), exp.what()); } } @@ -620,6 +610,143 @@ static void flushResponses(pdns::channel::Receiver& receiver) } } +static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID) +{ + std::map headers; + while (true) { + quiche_h3_event* event{nullptr}; + // Processes HTTP/3 data received from the peer + int64_t streamID = quiche_h3_conn_poll(conn.d_http3.get(), + conn.d_conn.get(), + &event); + + if (streamID < 0) { + break; + } + + switch (quiche_h3_event_type(event)) { + case QUICHE_H3_EVENT_HEADERS: { + // Callback result. Any value other than 0 will interrupt further header processing. + int cbresult = quiche_h3_event_for_each_header( + event, + [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API + std::string_view key(reinterpret_cast(name), name_len); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API + std::string_view content(reinterpret_cast(value), value_len); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API + auto* headersptr = reinterpret_cast*>(argp); + headersptr->emplace(key, content); + return 0; + }, + &headers); + + if (cbresult != 0 || headers.count(":method") == 0) { + DEBUGLOG("Failed to process headers"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unable to process query headers"); + break; + } + + if (headers.at(":method") == "GET") { + if (headers.count(":path") == 0 || headers.at(":path").empty()) { + DEBUGLOG("Path not found"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Path not found"); + break; + } + const auto& path = headers.at(":path"); + auto payload = dnsdist::doh::getPayloadFromPath(path); + if (!payload) { + DEBUGLOG("User error, unable to find the DNS parameter"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unable to find the DNS parameter"); + break; + } + if (payload->size() < sizeof(dnsheader)) { + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "DoH3 non-compliant query"); + break; + } + DEBUGLOG("Dispatching GET query"); + doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID); + conn.d_streamBuffers.erase(streamID); + } + else if (headers.at(":method") == "POST") { + if (!quiche_h3_event_headers_has_body(event)) { + DEBUGLOG("Empty POST query"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Empty POST query"); + break; + } + } + else { + DEBUGLOG("Unsupported HTTP method"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unsupported HTTP method"); + break; + } + break; + } + case QUICHE_H3_EVENT_DATA: { + if (headers.at(":method") == "POST") { + if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") { + DEBUGLOG("Unsupported content-type"); + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "Unsupported content-type"); + break; + } + PacketBuffer buffer(std::numeric_limits::max()); + PacketBuffer decoded; + + while (true) { + ssize_t len = quiche_h3_recv_body(conn.d_http3.get(), + conn.d_conn.get(), streamID, + buffer.data(), buffer.capacity()); + + if (len <= 0) { + break; + } + decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len); + } + if (decoded.size() < sizeof(dnsheader)) { + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + ++frontend.d_errorResponses; + h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, "DoH3 non-compliant query"); + break; + } + DEBUGLOG("Dispatching POST query"); + doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID); + conn.d_streamBuffers.erase(streamID); + } + } + case QUICHE_H3_EVENT_FINISHED: + case QUICHE_H3_EVENT_RESET: + case QUICHE_H3_EVENT_PRIORITY_UPDATE: + case QUICHE_H3_EVENT_GOAWAY: + break; + } + + quiche_h3_event_free(event); + } +} + + // this is the entrypoint from dnsdist.cc void doh3Thread(ClientState* clientState) { @@ -732,164 +859,7 @@ void doh3Thread(ClientState* clientState) DEBUGLOG("Successfully created HTTP/3 connection"); } - std::map headers; - while (true) { - quiche_h3_event* event{nullptr}; - // Processes HTTP/3 data received from the peer - int64_t streamID = quiche_h3_conn_poll(conn->get().d_http3.get(), - conn->get().d_conn.get(), - &event); - - if (streamID < 0) { - break; - } - - switch (quiche_h3_event_type(event)) { - case QUICHE_H3_EVENT_HEADERS: { - // Callback result. Any value other than 0 will interrupt further header processing. - int cbresult = quiche_h3_event_for_each_header( - event, - [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int { - std::string_view key(reinterpret_cast(name), name_len); - std::string_view content(reinterpret_cast(value), value_len); - auto* headersptr = reinterpret_cast*>(argp); - headersptr->emplace(key, content); - return 0; - }, - &headers); - if (cbresult != 0 || headers.count(":method") == 0) { - DEBUGLOG("Failed to process headers"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to process query headers"); - break; - } - if (headers.at(":method") == "GET") { - if (headers.count(":path") == 0 || headers.at(":path").empty()) { - DEBUGLOG("Path not found"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Path not found"); - break; - } - auto pos = headers.at(":path").find("?dns="); - if (pos == string::npos) { - pos = headers.at(":path").find("&dns="); - } - if (pos == string::npos) { - DEBUGLOG("User error, unable to find the DNS parameter"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to find the DNS parameter"); - break; - } - // need to base64url decode this - string sdns(headers.at(":path").substr(pos + 5)); - boost::replace_all(sdns, "-", "+"); - boost::replace_all(sdns, "_", "/"); - // re-add padding that may have been missing - switch (sdns.size() % 4) { - case 2: - sdns.append(2, '='); - break; - case 3: - sdns.append(1, '='); - break; - } - - PacketBuffer decoded; - /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */ - const size_t maxAdditionalSizeForEDNS = 35U; - /* rough estimate so we hopefully don't need a new allocation later */ - /* We reserve at few additional bytes to be able to add EDNS later */ - const size_t estimate = ((sdns.size() * 3) / 4); - decoded.reserve(estimate + maxAdditionalSizeForEDNS); - if (B64Decode(sdns, decoded) < 0) { - DEBUGLOG("Unable to base64 decode()"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unable to decode BASE64-URL"); - break; - } - if (decoded.size() < sizeof(dnsheader)) { - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query"); - break; - } - DEBUGLOG("Dispatching GET query"); - doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID); - conn->get().d_streamBuffers.erase(streamID); - } - else if (headers.at(":method") == "POST") { - if (!quiche_h3_event_headers_has_body(event)) { - DEBUGLOG("Empty POST query"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Empty POST query"); - break; - } - } - else { - DEBUGLOG("Unsupported HTTP method"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unsupported HTTP method"); - break; - } - break; - } - case QUICHE_H3_EVENT_DATA: { - if (headers.at(":method") == "POST") { - if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") { - DEBUGLOG("Unsupported content-type"); - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "Unsupported content-type"); - break; - } - PacketBuffer buffer(std::numeric_limits::max()); - PacketBuffer decoded; - - while (true) { - ssize_t len = quiche_h3_recv_body(conn->get().d_http3.get(), - conn->get().d_conn.get(), streamID, - buffer.data(), buffer.capacity()); - - if (len <= 0) { - break; - } - decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len); - } - if (decoded.size() < sizeof(dnsheader)) { - ++dnsdist::metrics::g_stats.nonCompliantQueries; - ++clientState->nonCompliantQueries; - ++frontend->d_errorResponses; - h3_send_response(conn->get().d_conn.get(), conn->get().d_http3.get(), streamID, 400, "DoH3 non-compliant query"); - break; - } - DEBUGLOG("Dispatching POST query"); - doh3_dispatch_query(*(frontend->d_server_config), std::move(decoded), clientState->local, client, serverConnID, streamID); - conn->get().d_streamBuffers.erase(streamID); - } - } - case QUICHE_H3_EVENT_FINISHED: - case QUICHE_H3_EVENT_RESET: - case QUICHE_H3_EVENT_PRIORITY_UPDATE: - case QUICHE_H3_EVENT_GOAWAY: - break; - } - - quiche_h3_event_free(event); - } + processH3Events(*clientState, *frontend, conn->get(), client, serverConnID); } else { DEBUGLOG("Connection not established"); diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 66a71b8504..03ba0f2571 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -25,7 +25,6 @@ #ifdef HAVE_DNS_OVER_QUIC #include -#include "dnsparser.hh" #include "dolog.hh" #include "iputils.hh" #include "misc.hh" @@ -33,8 +32,8 @@ #include "sstuff.hh" #include "threadname.hh" -#include "dnsdist-ecs.hh" #include "dnsdist-dnsparser.hh" +#include "dnsdist-ecs.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-tcp.hh" #include "dnsdist-random.hh" @@ -375,9 +374,6 @@ std::unique_ptr getDOQCrossProtocolQueryFromDQ(DNSQuestion& return std::make_unique(std::move(unit), isResponse); } -/* - We are not in the main DoQ thread but in the DoQ 'client' thread. -*/ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) { const auto handleImmediateResponse = [](DOQUnitUniquePtr&& unit, [[maybe_unused]] const char* reason) { @@ -525,12 +521,6 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID) { try { - /* we only parse it there as a sanity check, we will parse it again later */ - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - DNSPacketMangler mangler(reinterpret_cast(query.data()), query.size()); - mangler.skipDomainName(); - mangler.skipBytes(4); - auto unit = std::make_unique(std::move(query)); unit->dsc = &dsc; unit->ids.origDest = local; @@ -542,7 +532,7 @@ static void doq_dispatch_query(DOQServerConfig& dsc, PacketBuffer&& query, const processDOQQuery(std::move(unit)); } catch (const std::exception& exp) { - vinfolog("Had error parsing DoQ DNS packet from %s: %s", remote.toStringWithPort(), exp.what()); + vinfolog("Had error handling DoQ DNS packet from %s: %s", remote.toStringWithPort(), exp.what()); } }