From: Remi Gacogne Date: Fri, 6 Oct 2023 14:57:05 +0000 (+0200) Subject: dnsdist: Prevent unaligned access when reading the DNS header in DoQ X-Git-Tag: rec-5.0.0-alpha2~6^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d5d9573cf17982a4403d8938a1fd06c9564e7f97;p=thirdparty%2Fpdns.git dnsdist: Prevent unaligned access when reading the DNS header in DoQ --- diff --git a/pdns/dns.hh b/pdns/dns.hh index 24a02e8a48..066c7c4381 100644 --- a/pdns/dns.hh +++ b/pdns/dns.hh @@ -191,9 +191,14 @@ static_assert(sizeof(dnsheader) == 12, "dnsheader size must be 12"); class dnsheader_aligned { public: + static bool isMemoryAligned(const void* mem) + { + return reinterpret_cast(mem) % sizeof(uint32_t) == 0; // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) + } + dnsheader_aligned(const void* mem) { - if (reinterpret_cast(mem) % sizeof(uint32_t) == 0) { // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) + if (isMemoryAligned(mem)) { d_p = reinterpret_cast(mem); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) } else { @@ -207,14 +212,31 @@ public: return d_p; } + [[nodiscard]] const dnsheader& operator*() const + { + return *d_p; + } + + [[nodiscard]] const dnsheader* operator->() const + { + return d_p; + } + private: dnsheader d_h{}; - const dnsheader *d_p{}; + const dnsheader* d_p{}; }; -inline uint16_t * getFlagsFromDNSHeader(struct dnsheader * dh) +inline uint16_t* getFlagsFromDNSHeader(dnsheader* dh) +{ + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return reinterpret_cast(reinterpret_cast(dh) + sizeof(uint16_t)); +} + +inline const uint16_t * getFlagsFromDNSHeader(const dnsheader* dh) { - return (uint16_t*) (((char *) dh) + sizeof(uint16_t)); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return reinterpret_cast(reinterpret_cast(dh) + sizeof(uint16_t)); } #define DNS_TYPE_SIZE (2) diff --git a/pdns/dnsdistdist/dnsdist-dnsparser.cc b/pdns/dnsdistdist/dnsdist-dnsparser.cc index 90ce075805..49f2942b03 100644 --- a/pdns/dnsdistdist/dnsdist-dnsparser.cc +++ b/pdns/dnsdistdist/dnsdist-dnsparser.cc @@ -186,4 +186,32 @@ bool changeNameInDNSPacket(PacketBuffer& initialPacket, const DNSName& from, con return true; } +namespace PacketMangling +{ + bool editDNSHeaderFromPacket(PacketBuffer& packet, std::function editFunction) + { + if (packet.size() < sizeof(dnsheader)) { + throw std::runtime_error("Trying to edit the DNS header of a too small packet"); + } + + return editDNSHeaderFromRawPacket(packet.data(), editFunction); + } + + bool editDNSHeaderFromRawPacket(void* packet, std::function editFunction) + { + if (dnsheader_aligned::isMemoryAligned(packet)) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto* header = reinterpret_cast(packet); + return editFunction(*header); + } + + dnsheader header; + memcpy(&header, packet, sizeof(header)); + if (!editFunction(header)) { + return false; + } + memcpy(packet, &header, sizeof(header)); + return true; + } +} } diff --git a/pdns/dnsdistdist/dnsdist-dnsparser.hh b/pdns/dnsdistdist/dnsdist-dnsparser.hh index 91de7acf78..839f6cd396 100644 --- a/pdns/dnsdistdist/dnsdist-dnsparser.hh +++ b/pdns/dnsdistdist/dnsdist-dnsparser.hh @@ -54,4 +54,10 @@ public: * because it could contain pointers that would not be rewritten. */ bool changeNameInDNSPacket(PacketBuffer& initialPacket, const DNSName& from, const DNSName& to); + +namespace PacketMangling +{ + bool editDNSHeaderFromPacket(PacketBuffer& packet, std::function editFunction); + bool editDNSHeaderFromRawPacket(void* packet, std::function editFunction); +} } diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index f1206bb9a1..093eebe35e 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -34,6 +34,7 @@ #include "threadname.hh" #include "dnsdist-ecs.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-tcp.hh" #include "dnsdist-random.hh" @@ -624,11 +625,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) if (unit->query.size() < sizeof(dnsheader)) { ++dnsdist::metrics::g_stats.nonCompliantQueries; ++clientState.nonCompliantQueries; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - auto* dnsHeader = reinterpret_cast(unit->query.data()); - dnsHeader->rcode = RCode::ServFail; - dnsHeader->qr = true; - unit->response = std::move(unit->query); + unit->response.clear(); handleImmediateResponse(std::move(unit), "DoQ non-compliant query"); return; @@ -641,11 +638,14 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) { /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - auto* dnsHeader = reinterpret_cast(unit->query.data()); - - if (!checkQueryHeaders(dnsHeader, clientState)) { - dnsHeader->rcode = RCode::ServFail; - dnsHeader->qr = true; + dnsheader_aligned dnsHeader(unit->query.data()); + + if (!checkQueryHeaders(dnsHeader.get(), clientState)) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) { + header.rcode = RCode::ServFail; + header.qr = true; + return true; + }); unit->response = std::move(unit->query); handleImmediateResponse(std::move(unit), "DoQ invalid headers"); @@ -653,8 +653,11 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) } if (dnsHeader->qdcount == 0) { - dnsHeader->rcode = RCode::NotImp; - dnsHeader->qr = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); unit->response = std::move(unit->query); handleImmediateResponse(std::move(unit), "DoQ empty query"); @@ -668,8 +671,11 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) unit->ids.qname = DNSName(reinterpret_cast(unit->query.data()), static_cast(unit->query.size()), sizeof(dnsheader), false, &unit->ids.qtype, &unit->ids.qclass); DNSQuestion dnsQuestion(unit->ids, unit->query); - const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader()); - ids.origFlags = *flags; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) { + const uint16_t* flags = getFlagsFromDNSHeader(&header); + ids.origFlags = *flags; + return true; + }); unit->ids.cs = &clientState; auto result = processQuery(dnsQuestion, holders, downstream); @@ -685,8 +691,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit) unit->response = std::move(unit->query); } if (unit->response.size() >= sizeof(dnsheader)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - const auto* dnsHeader = reinterpret_cast(unit->response.data()); + const dnsheader_aligned dnsHeader(unit->response.data()); handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dnsHeader, dnsdist::Protocol::DoQ, dnsdist::Protocol::DoQ, false); }