]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Prevent unaligned access when reading the DNS header in DoQ
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 6 Oct 2023 14:57:05 +0000 (16:57 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Oct 2023 11:38:12 +0000 (13:38 +0200)
pdns/dns.hh
pdns/dnsdistdist/dnsdist-dnsparser.cc
pdns/dnsdistdist/dnsdist-dnsparser.hh
pdns/dnsdistdist/doq.cc

index 24a02e8a4866a769cd1675165d191e7ff551fe4d..066c7c4381b468ae9943b38368f867f5604aad14 100644 (file)
@@ -191,9 +191,14 @@ static_assert(sizeof(dnsheader) == 12, "dnsheader size must be 12");
 class dnsheader_aligned
 {
 public:
+  static bool isMemoryAligned(const void* mem)
+  {
+    return reinterpret_cast<uintptr_t>(mem) % sizeof(uint32_t) == 0; // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
+  }
+
   dnsheader_aligned(const void* mem)
   {
-    if (reinterpret_cast<uintptr_t>(mem) % sizeof(uint32_t) == 0) {  // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
+    if (isMemoryAligned(mem)) {
       d_p = reinterpret_cast<const dnsheader*>(mem);  // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
     }
     else {
@@ -207,14 +212,31 @@ public:
     return d_p;
   }
 
+  [[nodiscard]] const dnsheader& operator*() const
+  {
+    return *d_p;
+  }
+
+  [[nodiscard]] const dnsheader* operator->() const
+  {
+    return d_p;
+  }
+
 private:
   dnsheader d_h{};
-  const dnsheader *d_p{};
+  const dnsheaderd_p{};
 };
 
-inline uint16_t * getFlagsFromDNSHeader(struct dnsheader * dh)
+inline uint16_t* getFlagsFromDNSHeader(dnsheader* dh)
+{
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+  return reinterpret_cast<uint16_t*>(reinterpret_cast<char*>(dh) + sizeof(uint16_t));
+}
+
+inline const uint16_t * getFlagsFromDNSHeader(const dnsheader* dh)
 {
-  return (uint16_t*) (((char *) dh) + sizeof(uint16_t));
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+  return reinterpret_cast<const uint16_t*>(reinterpret_cast<const char*>(dh) + sizeof(uint16_t));
 }
 
 #define DNS_TYPE_SIZE (2)
index 90ce0758051698cb96ed7f8e5a25a01ca0de38b0..49f2942b03d44302c8116dddba2ffa2a131b63f9 100644 (file)
@@ -186,4 +186,32 @@ bool changeNameInDNSPacket(PacketBuffer& initialPacket, const DNSName& from, con
   return true;
 }
 
+namespace PacketMangling
+{
+  bool editDNSHeaderFromPacket(PacketBuffer& packet, std::function<bool(dnsheader& header)> editFunction)
+  {
+    if (packet.size() < sizeof(dnsheader)) {
+      throw std::runtime_error("Trying to edit the DNS header of a too small packet");
+    }
+
+    return editDNSHeaderFromRawPacket(packet.data(), editFunction);
+  }
+
+  bool editDNSHeaderFromRawPacket(void* packet, std::function<bool(dnsheader& header)> editFunction)
+  {
+    if (dnsheader_aligned::isMemoryAligned(packet)) {
+      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+      auto* header = reinterpret_cast<dnsheader*>(packet);
+      return editFunction(*header);
+    }
+
+    dnsheader header;
+    memcpy(&header, packet, sizeof(header));
+    if (!editFunction(header)) {
+      return false;
+    }
+    memcpy(packet, &header, sizeof(header));
+    return true;
+  }
+}
 }
index 91de7acf782f642bf4a0702cfeaa8ab788a5a9c5..839f6cd396fe06f6cf52ff8fd4a5894c889fad81 100644 (file)
@@ -54,4 +54,10 @@ public:
  * because it could contain pointers that would not be rewritten.
  */
 bool changeNameInDNSPacket(PacketBuffer& initialPacket, const DNSName& from, const DNSName& to);
+
+namespace PacketMangling
+{
+  bool editDNSHeaderFromPacket(PacketBuffer& packet, std::function<bool(dnsheader& header)> editFunction);
+  bool editDNSHeaderFromRawPacket(void* packet, std::function<bool(dnsheader& header)> editFunction);
+}
 }
index f1206bb9a115569320a7ede05cd371986e77118e..093eebe35e32a7f35017d6230bc4209bc889aaff 100644 (file)
@@ -34,6 +34,7 @@
 #include "threadname.hh"
 
 #include "dnsdist-ecs.hh"
+#include "dnsdist-dnsparser.hh"
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-tcp.hh"
 #include "dnsdist-random.hh"
@@ -624,11 +625,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
     if (unit->query.size() < sizeof(dnsheader)) {
       ++dnsdist::metrics::g_stats.nonCompliantQueries;
       ++clientState.nonCompliantQueries;
-      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-      auto* dnsHeader = reinterpret_cast<struct dnsheader*>(unit->query.data());
-      dnsHeader->rcode = RCode::ServFail;
-      dnsHeader->qr = true;
-      unit->response = std::move(unit->query);
+      unit->response.clear();
 
       handleImmediateResponse(std::move(unit), "DoQ non-compliant query");
       return;
@@ -641,11 +638,14 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
     {
       /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-      auto* dnsHeader = reinterpret_cast<struct dnsheader*>(unit->query.data());
-
-      if (!checkQueryHeaders(dnsHeader, clientState)) {
-        dnsHeader->rcode = RCode::ServFail;
-        dnsHeader->qr = true;
+      dnsheader_aligned dnsHeader(unit->query.data());
+
+      if (!checkQueryHeaders(dnsHeader.get(), clientState)) {
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
+          header.rcode = RCode::ServFail;
+          header.qr = true;
+          return true;
+        });
         unit->response = std::move(unit->query);
 
         handleImmediateResponse(std::move(unit), "DoQ invalid headers");
@@ -653,8 +653,11 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
       }
 
       if (dnsHeader->qdcount == 0) {
-        dnsHeader->rcode = RCode::NotImp;
-        dnsHeader->qr = true;
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
+          header.rcode = RCode::NotImp;
+          header.qr = true;
+          return true;
+        });
         unit->response = std::move(unit->query);
 
         handleImmediateResponse(std::move(unit), "DoQ empty query");
@@ -668,8 +671,11 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     unit->ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), sizeof(dnsheader), false, &unit->ids.qtype, &unit->ids.qclass);
     DNSQuestion dnsQuestion(unit->ids, unit->query);
-    const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader());
-    ids.origFlags = *flags;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) {
+      const uint16_t* flags = getFlagsFromDNSHeader(&header);
+      ids.origFlags = *flags;
+      return true;
+    });
     unit->ids.cs = &clientState;
 
     auto result = processQuery(dnsQuestion, holders, downstream);
@@ -685,8 +691,7 @@ static void processDOQQuery(DOQUnitUniquePtr&& doqUnit)
         unit->response = std::move(unit->query);
       }
       if (unit->response.size() >= sizeof(dnsheader)) {
-        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-        const auto* dnsHeader = reinterpret_cast<const struct dnsheader*>(unit->response.data());
+        const dnsheader_aligned dnsHeader(unit->response.data());
 
         handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dnsHeader, dnsdist::Protocol::DoQ, dnsdist::Protocol::DoQ, false);
       }