]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Prevent dnsheader alignment issues
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 5 Oct 2023 14:38:56 +0000 (16:38 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 17 Oct 2023 13:43:28 +0000 (15:43 +0200)
23 files changed:
pdns/dnscrypt.cc
pdns/dnsdist-cache.cc
pdns/dnsdist-ecs.cc
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist-xpf.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-discovery.cc
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist-nghttp2-in.cc
pdns/dnsdistdist/dnsdist-secpoll.cc
pdns/dnsdistdist/dnsdist-tcp.hh
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/doq.cc
pdns/dnsdistdist/test-dnsdist-lua-ffi.cc
pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc
pdns/dnsparser.cc
pdns/dnstap.cc
pdns/protozero.cc
pdns/test-dnsdistpacketcache_cc.cc

index 6db8613a3d01899c1d3844d6d647bfe44f8055f5..192dcfad1802d652476402fe30ee382a692b3b4b 100644 (file)
@@ -399,9 +399,10 @@ bool DNSCryptQuery::parsePlaintextQuery(const PacketBuffer& packet)
     return false;
   }
 
-  const struct dnsheader * dh = reinterpret_cast<const struct dnsheader *>(packet.data());
-  if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query)
+  const dnsheader_aligned dh(packet.data());
+  if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query) {
     return false;
+  }
 
   unsigned int qnameWireLength;
   uint16_t qtype, qclass;
index 7ca9be2f6e7ee0291b2e60c9410f883471969c42..67de6226a0432d7362b5470249ead2e0a507f9c6 100644 (file)
@@ -252,7 +252,7 @@ bool DNSDistPacketCache::get(DNSQuestion& dq, uint16_t queryId, uint32_t* keyOut
     }
 
     /* check for collision */
-    if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader())), dq.ids.qname, dq.ids.qtype, dq.ids.qclass, receivedOverUDP, dnssecOK, subnet)) {
+    if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader().get())), dq.ids.qname, dq.ids.qtype, dq.ids.qclass, receivedOverUDP, dnssecOK, subnet)) {
       ++d_lookupCollisions;
       return false;
     }
index 9e9d9c329e76bd8513799abe2082f384ca583080..2cad1945bca8de18beb3eb745743f62dad02b21d 100644 (file)
@@ -21,6 +21,7 @@
  */
 #include "dolog.hh"
 #include "dnsdist.hh"
+#include "dnsdist-dnsparser.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsparser.hh"
 #include "dnswriter.hh"
@@ -44,13 +45,15 @@ bool g_addEDNSToSelfGeneratedResponses{true};
 int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent)
 {
   assert(initialPacket.size() >= sizeof(dnsheader));
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+  const dnsheader_aligned dh(initialPacket.data());
 
-  if (ntohs(dh->arcount) == 0)
+  if (ntohs(dh->arcount) == 0) {
     return ENOENT;
+  }
 
-  if (ntohs(dh->qdcount) == 0)
+  if (ntohs(dh->qdcount) == 0) {
     return ENOENT;
+  }
 
   PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
 
@@ -152,7 +155,7 @@ static bool addOrReplaceEDNSOption(std::vector<std::pair<uint16_t, std::string>>
 bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
 {
   assert(initialPacket.size() >= sizeof(dnsheader));
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+  const dnsheader_aligned dh(initialPacket.data());
 
   if (ntohs(dh->qdcount) == 0) {
     return false;
@@ -269,7 +272,7 @@ static bool slowParseEDNSOptions(const PacketBuffer& packet, EDNSOptionViewMap&
     return false;
   }
 
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+  const dnsheader_aligned dh(packet.data());
 
   if (ntohs(dh->qdcount) == 0) {
     return false;
@@ -324,10 +327,11 @@ int locateEDNSOptRR(const PacketBuffer& packet, uint16_t * optStart, size_t * op
   assert(optStart != NULL);
   assert(optLen != NULL);
   assert(last != NULL);
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+  const dnsheader_aligned dh(packet.data());
 
-  if (ntohs(dh->arcount) == 0)
+  if (ntohs(dh->arcount) == 0) {
     return ENOENT;
+  }
 
   PacketReader pr(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()));
 
@@ -390,14 +394,15 @@ int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_
 {
   assert(optRDPosition != nullptr);
   assert(remaining != nullptr);
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+  const dnsheader_aligned dh(packet.data());
 
   if (offset >= packet.size()) {
     return ENOENT;
   }
 
-  if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0)
+  if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0) {
     return ENOENT;
+  }
 
   size_t pos = sizeof(dnsheader) + offset;
   pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
@@ -571,10 +576,12 @@ static bool addEDNSWithECS(PacketBuffer& packet, size_t maximumSize, const strin
     return false;
   }
 
-  struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data());
-  uint16_t arcount = ntohs(dh->arcount);
-  arcount++;
-  dh->arcount = htons(arcount);
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+    uint16_t arcount = ntohs(header.arcount);
+    arcount++;
+    header.arcount = htons(arcount);
+    return true;
+  });
   ednsAdded = true;
   ecsAdded = true;
 
@@ -585,7 +592,7 @@ bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, cons
 {
   assert(qnameWireLength <= packet.size());
 
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+  const dnsheader_aligned dh(packet.data());
 
   if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) {
     PacketBuffer newContent;
@@ -752,7 +759,7 @@ bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const
 int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent)
 {
   assert(initialPacket.size() >= sizeof(dnsheader));
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+  const dnsheader_aligned dh(initialPacket.data());
 
   if (ntohs(dh->arcount) == 0)
     return ENOENT;
@@ -852,8 +859,10 @@ bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t p
     return false;
   }
 
-  auto dh = reinterpret_cast<dnsheader*>(packet.data());
-  dh->arcount = htons(ntohs(dh->arcount) + 1);
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+    header.arcount = htons(ntohs(header.arcount) + 1);
+    return true;
+  });
 
   return true;
 }
@@ -894,17 +903,19 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone,
 
   /* chop off everything after the question */
   packet.resize(queryPartSize);
-  dh = dq.getHeader();
-  if (nxd) {
-    dh->rcode = RCode::NXDomain;
-  }
-  else {
-    dh->rcode = RCode::NoError;
-  }
-  dh->qr = true;
-  dh->ancount = 0;
-  dh->nscount = 0;
-  dh->arcount = 0;
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [nxd](dnsheader& header) {
+    if (nxd) {
+      header.rcode = RCode::NXDomain;
+    }
+    else {
+      header.rcode = RCode::NoError;
+    }
+    header.qr = true;
+    header.ancount = 0;
+    header.nscount = 0;
+    header.arcount = 0;
+    return true;
+  });
 
   rdLength = htons(rdLength);
   ttl = htonl(ttl);
@@ -934,16 +945,18 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone,
   }
 
   packet.insert(packet.end(), soa.begin(), soa.end());
-  dh = dq.getHeader();
 
   /* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
      NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
      and have EDNS added to AR afterwards */
-  if (soaInAuthoritySection) {
-    dh->nscount = htons(1);
-  } else {
-    dh->arcount = htons(1);
-  }
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [soaInAuthoritySection](dnsheader& header) {
+    if (soaInAuthoritySection) {
+      header.nscount = htons(1);
+    } else {
+      header.arcount = htons(1);
+    }
+    return true;
+  });
 
   if (hadEDNS) {
     /* now we need to add a new OPT record */
@@ -982,7 +995,10 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
 
   /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
   packet.resize(packet.size() - existingOptLen);
-  dq.getHeader()->arcount = 0;
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+    header.arcount = 0;
+    return true;
+  });
 
   if (g_addEDNSToSelfGeneratedResponses) {
     /* now we need to add a new OPT record */
@@ -1107,7 +1123,10 @@ bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsDa
 
   auto& data = dq.getMutableData();
   if (generateOptRR(optRData, data, dq.getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
-    dq.getHeader()->arcount = htons(1);
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+      header.arcount = htons(1);
+      return true;
+    });
     // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
     dq.ids.ednsAdded = true;
   }
@@ -1129,17 +1148,22 @@ bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer,  uin
     hadEDNS = getEDNS0Record(buffer, edns0);
   }
 
-  auto dh = reinterpret_cast<dnsheader*>(buffer.data());
-  dh->rcode = rcode;
-  dh->ad = false;
-  dh->aa = false;
-  dh->ra = dh->rd;
-  dh->qr = true;
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [rcode,clearAnswers](dnsheader& header) {
+    header.rcode = rcode;
+    header.ad = false;
+    header.aa = false;
+    header.ra = header.rd;
+    header.qr = true;
+
+    if (clearAnswers) {
+      header.ancount = 0;
+      header.nscount = 0;
+      header.arcount = 0;
+    }
+    return true;
+  });
 
   if (clearAnswers) {
-    dh->ancount = 0;
-    dh->nscount = 0;
-    dh->arcount = 0;
     buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
     if (hadEDNS) {
       DNSQuestion dq(state, buffer);
index b81600fe8c92aa2e7668efa78f28fc71860797a2..b9d1b5b3af8a0286e3dd55dc7fd6dc6a888ddd8d 100644 (file)
@@ -23,6 +23,7 @@
 #include "threadname.hh"
 #include "dnsdist.hh"
 #include "dnsdist-async.hh"
+#include "dnsdist-dnsparser.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-edns.hh"
 #include "dnsdist-lua.hh"
@@ -242,37 +243,48 @@ std::map<std::string,double> TeeAction::getStats() const
 void TeeAction::worker()
 {
   setThreadName("dnsdist/TeeWork");
-  char packet[1500];
-  int res=0;
-  struct dnsheader* dh=(struct dnsheader*)packet;
-  for(;;) {
-    res=waitForData(d_fd, 0, 250000);
-    if(d_pleaseQuit)
+  std::array<char, 1500> packet;
+  ssize_t res = 0;
+  const dnsheader_aligned dh(packet.data());
+  for (;;) {
+    res = waitForData(d_fd, 0, 250000);
+    if (d_pleaseQuit) {
       break;
-    if(res < 0) {
+    }
+
+    if (res < 0) {
       usleep(250000);
       continue;
     }
-    if(res==0)
+    if (res == 0) {
       continue;
-    res=recv(d_fd, packet, sizeof(packet), 0);
-    if(res <= (int)sizeof(struct dnsheader))
+    }
+    res = recv(d_fd, packet.data(), packet.size(), 0);
+    if (static_cast<size_t>(res) <= sizeof(struct dnsheader)) {
       d_recverrors++;
-    else
+    }
+    else {
       d_responses++;
+    }
 
-    if(dh->rcode == RCode::NoError)
+    if (dh->rcode == RCode::NoError) {
       d_noerrors++;
-    else if(dh->rcode == RCode::ServFail)
+    }
+    else if (dh->rcode == RCode::ServFail) {
       d_servfails++;
-    else if(dh->rcode == RCode::NXDomain)
+    }
+    else if (dh->rcode == RCode::NXDomain) {
       d_nxdomains++;
-    else if(dh->rcode == RCode::Refused)
+    }
+    else if (dh->rcode == RCode::Refused) {
       d_refuseds++;
-    else if(dh->rcode == RCode::FormErr)
+    }
+    else if (dh->rcode == RCode::FormErr) {
       d_formerrs++;
-    else if(dh->rcode == RCode::NotImp)
+    }
+    else if (dh->rcode == RCode::NotImp) {
       d_notimps++;
+    }
   }
 }
 
@@ -343,9 +355,12 @@ public:
   RCodeAction(uint8_t rcode) : d_rcode(rcode) {}
   DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
   {
-    dq->getHeader()->rcode = d_rcode;
-    dq->getHeader()->qr = true; // for good measure
-    setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+      header.rcode = d_rcode;
+      header.qr = true; // for good measure
+      setResponseHeadersFromConfig(header, d_responseConfig);
+      return true;
+    });
     return Action::HeaderModify;
   }
   std::string toString() const override
@@ -364,10 +379,13 @@ public:
   ERCodeAction(uint8_t rcode) : d_rcode(rcode) {}
   DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
   {
-    dq->getHeader()->rcode = (d_rcode & 0xF);
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+      header.rcode = (d_rcode & 0xF);
+      header.qr = true; // for good measure
+      setResponseHeadersFromConfig(header, d_responseConfig);
+      return true;
+    });
     dq->ednsRCode = ((d_rcode & 0xFFF0) >> 4);
-    dq->getHeader()->qr = true; // for good measure
-    setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
     return Action::HeaderModify;
   }
   std::string toString() const override
@@ -819,7 +837,10 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu
   if (d_raw.size() >= sizeof(dnsheader)) {
     auto id = dq->getHeader()->id;
     dq->getMutableData() = d_raw;
-    dq->getHeader()->id = id;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [id](dnsheader& header) {
+      header.id = id;
+      return true;
+    });
     return Action::HeaderModify;
   }
   vector<ComboAddress> addrs;
@@ -875,10 +896,13 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu
   data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + totrdatalen); // there goes your EDNS
   uint8_t* dest = &(data.at(sizeof(dnsheader) + qnameWireLength + 4));
 
-  dq->getHeader()->qr = true; // for good measure
-  setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
-  dq->getHeader()->ancount = 0;
-  dq->getHeader()->arcount = 0; // for now, forget about your EDNS, we're marching over it
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+    header.qr = true; // for good measure
+    setResponseHeadersFromConfig(header, d_responseConfig);
+    header.ancount = 0;
+    header.arcount = 0; // for now, forget about your EDNS, we're marching over it
+    return true;
+  });
 
   uint32_t ttl = htonl(d_responseConfig.ttl);
   uint16_t qclass = htons(dq->ids.qclass);
@@ -902,7 +926,10 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu
     memcpy(dest, recordstart, sizeof(recordstart));
     dest += sizeof(recordstart);
     memcpy(dest, wireData.c_str(), wireData.length());
-    dq->getHeader()->ancount++;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+      header.ancount++;
+      return true;
+    });
   }
   else if (!rawResponses.empty()) {
     qtype = htons(qtype);
@@ -917,7 +944,10 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu
       memcpy(dest, rawResponse.c_str(), rawResponse.size());
       dest += rawResponse.size();
 
-      dq->getHeader()->ancount++;
+      dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+        header.ancount++;
+        return true;
+      });
     }
     raw = true;
   }
@@ -935,11 +965,18 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu
              addr.sin4.sin_family == AF_INET ? reinterpret_cast<const void*>(&addr.sin4.sin_addr.s_addr) : reinterpret_cast<const void*>(&addr.sin6.sin6_addr.s6_addr),
              addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr));
       dest += (addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr));
-      dq->getHeader()->ancount++;
+      dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+        header.ancount++;
+        return true;
+      });
     }
   }
 
-  dq->getHeader()->ancount = htons(dq->getHeader()->ancount);
+  auto finalANCount = dq->getHeader()->ancount;
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [finalANCount](dnsheader& header) {
+    header.ancount = htons(finalANCount);
+    return true;
+  });
 
   if (hadEDNS && raw == false) {
     addEDNS(dq->getMutableData(), dq->getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0);
@@ -991,7 +1028,10 @@ public:
 
     auto& data = dq->getMutableData();
     if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
-      dq->getHeader()->arcount = htons(1);
+      dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+        header.arcount = htons(1);
+        return true;
+      });
       // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
       dq->ids.ednsAdded = true;
     }
@@ -1036,7 +1076,10 @@ public:
   // this action does not stop the processing
   DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
   {
-    dq->getHeader()->rd = false;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+      header.rd = false;
+      return true;
+    });
     return Action::None;
   }
   std::string toString() const override
@@ -1252,7 +1295,10 @@ public:
   // this action does not stop the processing
   DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
   {
-    dq->getHeader()->cd = true;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+      header.cd = true;
+      return true;
+    });
     return Action::None;
   }
   std::string toString() const override
@@ -1922,8 +1968,11 @@ public:
     }
 
     dq->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
-    dq->getHeader()->qr = true; // for good measure
-    setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+      header.qr = true; // for good measure
+      setResponseHeadersFromConfig(header, d_responseConfig);
+      return true;
+    });
     return Action::HeaderModify;
   }
 
@@ -2067,7 +2116,10 @@ public:
       return Action::None;
     }
 
-    setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+      setResponseHeadersFromConfig(header, d_responseConfig);
+      return true;
+    });
 
     return Action::Allow;
   }
index 057f71c9073a795ef5a3b3de0b5ec8dab738a8a0..bf456075c79ded8386075e14f3128ab1d1409bdc 100644 (file)
@@ -36,10 +36,20 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
   luaCtx.registerMember<const DNSName (DNSQuestion::*)>("qname", [](const DNSQuestion& dq) -> const DNSName { return dq.ids.qname; }, [](DNSQuestion& dq, const DNSName& newName) { (void) newName; });
   luaCtx.registerMember<uint16_t (DNSQuestion::*)>("qtype", [](const DNSQuestion& dq) -> uint16_t { return dq.ids.qtype; }, [](DNSQuestion& dq, uint16_t newType) { (void) newType; });
   luaCtx.registerMember<uint16_t (DNSQuestion::*)>("qclass", [](const DNSQuestion& dq) -> uint16_t { return dq.ids.qclass; }, [](DNSQuestion& dq, uint16_t newClass) { (void) newClass; });
-  luaCtx.registerMember<int (DNSQuestion::*)>("rcode", [](const DNSQuestion& dq) -> int { return dq.getHeader()->rcode; }, [](DNSQuestion& dq, int newRCode) { dq.getHeader()->rcode = newRCode; });
+  luaCtx.registerMember<int (DNSQuestion::*)>("rcode", [](const DNSQuestion& dq) -> int { return dq.getHeader()->rcode; }, [](DNSQuestion& dq, int newRCode) {
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [newRCode](dnsheader& header) {
+      header.rcode = newRCode;
+      return true;
+    });
+  });
   luaCtx.registerMember<const ComboAddress (DNSQuestion::*)>("remoteaddr", [](const DNSQuestion& dq) -> const ComboAddress { return dq.ids.origRemote; }, [](DNSQuestion& dq, const ComboAddress newRemote) { (void) newRemote; });
   /* DNSDist DNSQuestion */
-  luaCtx.registerMember<dnsheader* (DNSQuestion::*)>("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast<DNSQuestion&>(dq).getHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) { *(dq.getHeader()) = *dh; });
+  luaCtx.registerMember<dnsheader* (DNSQuestion::*)>("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast<DNSQuestion&>(dq).getMutableHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) {
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [&dh](dnsheader& header) {
+      header = *dh;
+      return true;
+    });
+  });
   luaCtx.registerMember<uint16_t (DNSQuestion::*)>("len", [](const DNSQuestion& dq) -> uint16_t { return dq.getData().size(); }, [](DNSQuestion& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); });
   luaCtx.registerMember<uint8_t (DNSQuestion::*)>("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; });
   luaCtx.registerMember<bool (DNSQuestion::*)>("tcp", [](const DNSQuestion& dq) -> bool { return dq.overTCP(); }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; });
@@ -100,7 +110,12 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     auto& buffer = dq.getMutableData();
     buffer.clear();
     buffer.insert(buffer.begin(), raw.begin(), raw.end());
-    reinterpret_cast<dnsheader*>(buffer.data())->id = oldID;
+
+  reinterpret_cast<dnsheader*>(buffer.data())->id = oldID;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [oldID](dnsheader& header) {
+      header.id = oldID;
+      return true;
+    });
   });
   luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSQuestion::*)()const>("getEDNSOptions", [](const DNSQuestion& dq) {
       if (dq.ednsOptions == nullptr) {
@@ -333,9 +348,19 @@ private:
   luaCtx.registerMember<const DNSName (DNSResponse::*)>("qname", [](const DNSResponse& dq) -> const DNSName { return dq.ids.qname; }, [](DNSResponse& dq, const DNSName& newName) { (void) newName; });
   luaCtx.registerMember<uint16_t (DNSResponse::*)>("qtype", [](const DNSResponse& dq) -> uint16_t { return dq.ids.qtype; }, [](DNSResponse& dq, uint16_t newType) { (void) newType; });
   luaCtx.registerMember<uint16_t (DNSResponse::*)>("qclass", [](const DNSResponse& dq) -> uint16_t { return dq.ids.qclass; }, [](DNSResponse& dq, uint16_t newClass) { (void) newClass; });
-  luaCtx.registerMember<int (DNSResponse::*)>("rcode", [](const DNSResponse& dq) -> int { return dq.getHeader()->rcode; }, [](DNSResponse& dq, int newRCode) { dq.getHeader()->rcode = newRCode; });
+  luaCtx.registerMember<int (DNSResponse::*)>("rcode", [](const DNSResponse& dq) -> int { return dq.getHeader()->rcode; }, [](DNSResponse& dq, int newRCode) {
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [newRCode](dnsheader& header) {
+      header.rcode = newRCode;
+      return true;
+    });
+  });
   luaCtx.registerMember<const ComboAddress (DNSResponse::*)>("remoteaddr", [](const DNSResponse& dq) -> const ComboAddress { return dq.ids.origRemote; }, [](DNSResponse& dq, const ComboAddress newRemote) { (void) newRemote; });
-  luaCtx.registerMember<dnsheader* (DNSResponse::*)>("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast<DNSResponse&>(dr).getHeader(); }, [](DNSResponse& dr, const dnsheader* dh) { *(dr.getHeader()) = *dh; });
+  luaCtx.registerMember<dnsheader* (DNSResponse::*)>("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast<DNSResponse&>(dr).getMutableHeader(); }, [](DNSResponse& dr, const dnsheader* dh) {
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [&dh](dnsheader& header) {
+      header = *dh;
+      return true;
+    });
+  });
   luaCtx.registerMember<uint16_t (DNSResponse::*)>("len", [](const DNSResponse& dq) -> uint16_t { return dq.getData().size(); }, [](DNSResponse& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); });
   luaCtx.registerMember<uint8_t (DNSResponse::*)>("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; });
   luaCtx.registerMember<bool (DNSResponse::*)>("tcp", [](const DNSResponse& dq) -> bool { return dq.overTCP(); }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; });
@@ -355,7 +380,10 @@ private:
     auto& buffer = dr.getMutableData();
     buffer.clear();
     buffer.insert(buffer.begin(), raw.begin(), raw.end());
-    reinterpret_cast<dnsheader*>(buffer.data())->id = oldID;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [oldID](dnsheader& header) {
+      header.id = oldID;
+      return true;
+    });
   });
 
   luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSResponse::*)()const>("getEDNSOptions", [](const DNSResponse& dq) {
index 3d118508cd3e26d92b6aebe35f6dcb1a2f719e5c..7b9ebc7e5fee185aa51bcdeec09e22c269f1d7f5 100644 (file)
@@ -26,6 +26,7 @@
 
 #include "dnsdist.hh"
 #include "dnsdist-concurrent-connections.hh"
+#include "dnsdist-dnsparser.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-nghttp2-in.hh"
 #include "dnsdist-proxy-protocol.hh"
@@ -511,7 +512,7 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe
       DNSResponse dr(ids, response.d_buffer, ds);
       dr.d_incomingTCPState = state;
 
-      memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH));
+      memcpy(&response.d_cleartextDH, dr.getHeader().get(), sizeof(response.d_cleartextDH));
 
       if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) {
         state->terminateClientConnection();
@@ -668,16 +669,19 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha
 
   {
     /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
-    auto* dh = reinterpret_cast<dnsheader*>(query.data());
-    if (!checkQueryHeaders(dh, *d_ci.cs)) {
+    const dnsheader_aligned dh(query.data());
+    if (!checkQueryHeaders(dh.get(), *d_ci.cs)) {
       return QueryProcessingResult::InvalidHeaders;
     }
 
     if (dh->qdcount == 0) {
       TCPResponse response;
-      dh->rcode = RCode::NotImp;
-      dh->qr = true;
       auto queryID = dh->id;
+      dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
+        header.rcode = RCode::NotImp;
+        header.qr = true;
+        return true;
+      });
       response.d_idstate = std::move(ids);
       response.d_idstate.origID = queryID;
       response.d_idstate.selfGenerated = true;
@@ -696,8 +700,11 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha
   }
 
   DNSQuestion dq(ids, query);
-  const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
-  ids.origFlags = *flags;
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [&ids](dnsheader& header) {
+    const uint16_t* flags = getFlagsFromDNSHeader(&header);
+    ids.origFlags = *flags;
+    return true;
+  });
   dq.d_incomingTCPState = state;
   dq.sni = d_handler.getServerNameIndication();
 
@@ -714,7 +721,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha
   if (forwardViaUDPFirst()) {
     // if there was no EDNS, we add it with a large buffer size
     // so we can use UDP to talk to the backend.
-    auto dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(query.data()));
+    const dnsheader_aligned dh(query.data());
     if (!dh->arcount) {
       if (addEDNS(query, 4096, false, 4096, 0)) {
         dq.ids.ednsAdded = true;
@@ -747,15 +754,15 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha
   // the buffer might have been invalidated by now
   uint16_t queryID;
   {
-    const dnsheader* dh = dq.getHeader();
+    const auto dh = dq.getHeader();
     queryID = dh->id;
   }
 
   if (result == ProcessQueryResult::SendAnswer) {
     TCPResponse response;
     {
-      const dnsheader* dh = dq.getHeader();
-      memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH));
+      const auto dh = dq.getHeader();
+      memcpy(&response.d_cleartextDH, dh.get(), sizeof(response.d_cleartextDH));
     }
     response.d_idstate = std::move(ids);
     response.d_idstate.origID = queryID;
index 6f4cba5315815ebc07c1c9fa53ccdc0ab0aaffcc..eb2ba5785571ccee55431742b320a6e468c24e38 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "dnsdist-xpf.hh"
 
+#include "dnsdist-dnsparser.hh"
 #include "dnsparser.hh"
 #include "xpf.hh"
 
@@ -54,7 +55,9 @@ bool addXPF(DNSQuestion& dq, uint16_t optionCode)
   pos += payload.size();
   (void) pos;
 
-  dq.getHeader()->arcount = htons(ntohs(dq.getHeader()->arcount) + 1);
-
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+    header.arcount = htons(ntohs(header.arcount) + 1);
+    return true;
+  });
   return true;
 }
index ec4ca88ac07ea594feaff27b110a255a77cf040f..d0c0dc629220caef2461d822b46446a52ea1a843 100644 (file)
@@ -53,6 +53,7 @@
 #include "dnsdist-carbon.hh"
 #include "dnsdist-console.hh"
 #include "dnsdist-discovery.hh"
+#include "dnsdist-dnsparser.hh"
 #include "dnsdist-dynblocks.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-healthchecks.hh"
@@ -197,8 +198,12 @@ static void truncateTC(PacketBuffer& packet, size_t maximumSize, unsigned int qn
     }
 
     packet.resize(static_cast<uint16_t>(sizeof(dnsheader)+qnameWireLength+DNS_TYPE_SIZE+DNS_CLASS_SIZE));
-    struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data());
-    dh->ancount = dh->arcount = dh->nscount = 0;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+      header.ancount = 0;
+      header.arcount = 0;
+      header.nscount = 0;
+      return true;
+    });
 
     if (hadEDNS) {
       addEDNS(packet, maximumSize, z & EDNS_HEADER_FLAG_DO, payloadSize, 0);
@@ -232,8 +237,8 @@ static std::unique_ptr<DelayPipe<DelayedPacket>> g_delay{nullptr};
 
 std::string DNSQuestion::getTrailingData() const
 {
-  const char* message = reinterpret_cast<const char*>(this->getHeader());
-  const uint16_t messageLen = getDNSPacketLength(message, this->data.size());
+  const char* message = reinterpret_cast<const char*>(this->getData().data());
+  const uint16_t messageLen = getDNSPacketLength(message, this->getData().size());
   return std::string(message + messageLen, this->getData().size() - messageLen);
 }
 
@@ -251,6 +256,14 @@ bool DNSQuestion::setTrailingData(const std::string& tail)
   return true;
 }
 
+bool DNSQuestion::editHeader(std::function<bool(dnsheader&)> editFunction)
+{
+  if (data.size() < sizeof(dnsheader)) {
+    throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer");
+  }
+  return dnsdist::PacketMangling::editDNSHeaderFromPacket(data, editFunction);
+}
+
 static void doLatencyStats(dnsdist::Protocol protocol, double udiff)
 {
   constexpr auto doAvg = [](double& var, double n, double weight) {
@@ -311,7 +324,7 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname,
     return false;
   }
 
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response.data());
+  const dnsheader_aligned dh(response.data());
   if (dh->qr == 0) {
     ++dnsdist::metrics::g_stats.nonCompliantResponses;
     if (remote) {
@@ -370,11 +383,14 @@ static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
   *flags |= origFlags;
 }
 
-static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
+static bool fixUpQueryTurnedResponse(DNSQuestion& dnsQuestion, const uint16_t origFlags)
 {
-  restoreFlags(dq.getHeader(), origFlags);
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [origFlags](dnsheader& header) {
+    restoreFlags(&header, origFlags);
+    return true;
+  });
 
-  return addEDNSToQueryTurnedResponse(dq);
+  return addEDNSToQueryTurnedResponse(dnsQuestion);
 }
 
 static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, bool* zeroScope)
@@ -383,8 +399,10 @@ static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t
     return false;
   }
 
-  struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
-  restoreFlags(dh, origFlags);
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [origFlags](dnsheader& header) {
+    restoreFlags(&header, origFlags);
+    return true;
+  });
 
   if (response.size() == sizeof(dnsheader)) {
     return true;
@@ -422,10 +440,12 @@ static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t
         if (last) {
           /* simply remove the last AR */
           response.resize(response.size() - optLen);
-          dh = reinterpret_cast<struct dnsheader*>(response.data());
-          uint16_t arcount = ntohs(dh->arcount);
-          arcount--;
-          dh->arcount = htons(arcount);
+          dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [](dnsheader& header) {
+            uint16_t arcount = ntohs(header.arcount);
+            arcount--;
+            header.arcount = htons(arcount);
+            return true;
+          });
         }
         else {
           /* Removing an intermediary RR could lead to compression error */
@@ -499,7 +519,10 @@ static bool applyRulesToResponse(const std::vector<DNSDistResponseRuleAction>& r
         return true;
         break;
       case DNSResponseAction::Action::ServFail:
-        dr.getHeader()->rcode = RCode::ServFail;
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) {
+          header.rcode = RCode::ServFail;
+          return true;
+        });
         return true;
         break;
         /* non-terminal actions follow */
@@ -660,7 +683,10 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
   if (ids.udpPayloadSize > 0 && response.size() > ids.udpPayloadSize) {
     vinfolog("Got a response of size %d while the initial UDP payload size was %d, truncating", response.size(), ids.udpPayloadSize);
     truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.ids.qname.wirelength());
-    dr.getHeader()->tc = true;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) {
+      header.tc = true;
+      return true;
+    });
   }
   else if (dr.getHeader()->tc && g_truncateTC) {
     truncateTC(response, dr.getMaximumSize(), dr.ids.qname.wirelength());
@@ -669,7 +695,7 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
   /* when the answer is encrypted in place, we need to get a copy
      of the original header before encryption to fill the ring buffer */
   dnsheader cleartextDH;
-  memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
+  memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH));
 
   if (!isAsync) {
     if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) {
@@ -759,7 +785,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
         }
 
         response.resize(static_cast<size_t>(got));
-        dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
+        const dnsheader_aligned dh(response.data());
         queryId = dh->id;
 
         auto ids = dss->getState(queryId);
@@ -775,7 +801,10 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
 
         auto du = std::move(ids->du);
 
-        dh->id = ids->origID;
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) {
+          header.id = ids->origID;
+          return true;
+        });
         ++dss->responses;
 
         double udiff = ids->queryRealTime.udiff();
@@ -869,7 +898,15 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s
     return false;
   }
 
-  switch(action) {
+  auto setRCode = [&dq](uint8_t rcode) {
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [rcode](dnsheader& header) {
+      header.rcode = rcode;
+      header.qr = true;
+      return true;
+    });
+  };
+
+  switch (action) {
   case DNSAction::Action::Allow:
     return true;
     break;
@@ -879,18 +916,15 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s
     return true;
     break;
   case DNSAction::Action::Nxdomain:
-    dq.getHeader()->rcode = RCode::NXDomain;
-    dq.getHeader()->qr = true;
+    setRCode(RCode::NXDomain);
     return true;
     break;
   case DNSAction::Action::Refused:
-    dq.getHeader()->rcode = RCode::Refused;
-    dq.getHeader()->qr = true;
+    setRCode(RCode::Refused);
     return true;
     break;
   case DNSAction::Action::ServFail:
-    dq.getHeader()->rcode = RCode::ServFail;
-    dq.getHeader()->qr = true;
+    setRCode(RCode::ServFail);
     return true;
     break;
   case DNSAction::Action::Spoof:
@@ -907,11 +941,14 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s
     break;
   case DNSAction::Action::Truncate:
     if (!dq.overTCP()) {
-      dq.getHeader()->tc = true;
-      dq.getHeader()->qr = true;
-      dq.getHeader()->ra = dq.getHeader()->rd;
-      dq.getHeader()->aa = false;
-      dq.getHeader()->ad = false;
+      dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+        header.tc = true;
+        header.qr = true;
+        header.ra = header.rd;
+        header.aa = false;
+        header.ad = false;
+        return true;
+      });
       ++dnsdist::metrics::g_stats.ruleTruncated;
       return true;
     }
@@ -926,7 +963,10 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s
     return true;
     break;
   case DNSAction::Action::NoRecurse:
-    dq.getHeader()->rd = false;
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+      header.rd = false;
+      return true;
+    });
     return true;
     break;
     /* non-terminal actions follow */
@@ -946,6 +986,14 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s
 
 static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const struct timespec& now)
 {
+  auto setRCode = [&dq](uint8_t rcode) {
+    dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [rcode](dnsheader& header) {
+      header.rcode = rcode;
+      header.qr = true;
+      return true;
+    });
+  };
+
   if (g_rings.shouldRecordQueries()) {
     g_rings.insertQuery(now, dq.ids.origRemote, dq.ids.qname, dq.ids.qtype, dq.getData().size(), *dq.getHeader(), dq.getProtocol());
   }
@@ -980,6 +1028,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
       if (action == DNSAction::Action::None) {
         action = g_dynBlockAction;
       }
+
       switch (action) {
       case DNSAction::Action::NoOp:
         /* do nothing */
@@ -989,27 +1038,28 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
         vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort());
         updateBlockStats();
 
-        dq.getHeader()->rcode = RCode::NXDomain;
-        dq.getHeader()->qr=true;
+        setRCode(RCode::NXDomain);
         return true;
 
       case DNSAction::Action::Refused:
         vinfolog("Query from %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort());
         updateBlockStats();
 
-        dq.getHeader()->rcode = RCode::Refused;
-        dq.getHeader()->qr = true;
+        setRCode(RCode::Refused);
         return true;
 
       case DNSAction::Action::Truncate:
         if (!dq.overTCP()) {
           updateBlockStats();
           vinfolog("Query from %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort());
-          dq.getHeader()->tc = true;
-          dq.getHeader()->qr = true;
-          dq.getHeader()->ra = dq.getHeader()->rd;
-          dq.getHeader()->aa = false;
-          dq.getHeader()->ad = false;
+          dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+            header.tc = true;
+            header.qr = true;
+            header.ra = header.rd;
+            header.aa = false;
+            header.ad = false;
+            return true;
+          });
           return true;
         }
         else {
@@ -1019,7 +1069,10 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
       case DNSAction::Action::NoRecurse:
         updateBlockStats();
         vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort());
-        dq.getHeader()->rd = false;
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+          header.rd = false;
+          return true;
+        });
         return true;
       default:
         updateBlockStats();
@@ -1048,26 +1101,27 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
         vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString());
         updateBlockStats();
 
-        dq.getHeader()->rcode = RCode::NXDomain;
-        dq.getHeader()->qr = true;
+        setRCode(RCode::NXDomain);
         return true;
       case DNSAction::Action::Refused:
         vinfolog("Query from %s for %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString());
         updateBlockStats();
 
-        dq.getHeader()->rcode = RCode::Refused;
-        dq.getHeader()->qr = true;
+        setRCode(RCode::Refused);
         return true;
       case DNSAction::Action::Truncate:
         if (!dq.overTCP()) {
           updateBlockStats();
 
           vinfolog("Query from %s for %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString());
-          dq.getHeader()->tc = true;
-          dq.getHeader()->qr = true;
-          dq.getHeader()->ra = dq.getHeader()->rd;
-          dq.getHeader()->aa = false;
-          dq.getHeader()->ad = false;
+          dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+            header.tc = true;
+            header.qr = true;
+            header.ra = header.rd;
+            header.aa = false;
+            header.ad = false;
+            return true;
+          });
           return true;
         }
         else {
@@ -1077,7 +1131,10 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
       case DNSAction::Action::NoRecurse:
         updateBlockStats();
         vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort());
-        dq.getHeader()->rd = false;
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+          header.rd = false;
+          return true;
+        });
         return true;
       default:
         updateBlockStats();
@@ -1368,7 +1425,10 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders
          yet, as we will do a second-lookup */
       if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKey, dq.ids.subnet, dq.ids.dnssecOK, forwardedOverUDP, allowExpired, false, true, dq.ids.protocol != dnsdist::Protocol::DoH || forwardedOverUDP)) {
 
-        restoreFlags(dq.getHeader(), dq.ids.origFlags);
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [flags=dq.ids.origFlags](dnsheader& header) {
+          restoreFlags(&header, flags);
+          return true;
+        });
 
         vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size());
 
@@ -1403,8 +1463,11 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders
 
       vinfolog("%s query for %s|%s from %s, no downstream server available", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort());
       if (g_servFailOnNoPolicy) {
-        dq.getHeader()->rcode = RCode::ServFail;
-        dq.getHeader()->qr = true;
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+          header.rcode = RCode::ServFail;
+          header.qr = true;
+          return true;
+        });
 
         fixUpQueryTurnedResponse(dq, dq.ids.origFlags);
 
@@ -1421,7 +1484,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders
     }
 
     /* save the DNS flags as sent to the backend so we can cache the answer with the right flags later */
-    dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader());
+    dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader().get());
 
     if (dq.addXPF && selectedBackend->d_config.xpfRRCode != 0) {
       addXPF(dq, selectedBackend->d_config.xpfRRCode);
@@ -1647,16 +1710,20 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
 
     {
       /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
-      struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query.data());
+      const dnsheader_aligned dh(query.data());
       queryId = ntohs(dh->id);
 
-      if (!checkQueryHeaders(dh, cs)) {
+      if (!checkQueryHeaders(dh.get(), cs)) {
         return;
       }
 
       if (dh->qdcount == 0) {
-        dh->rcode = RCode::NotImp;
-        dh->qr = true;
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
+          header.rcode = RCode::NotImp;
+          header.qr = true;
+          return true;
+        });
+
         sendUDPResponse(cs.udpFD, query, 0, dest, remote);
         return;
       }
@@ -1667,7 +1734,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ids.protocol = dnsdist::Protocol::DNSCryptUDP;
     }
     DNSQuestion dq(ids, query);
-    const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
+    const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader().get());
     ids.origFlags = *flags;
 
     if (!proxyProtocolValues.empty()) {
@@ -1682,7 +1749,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     }
 
     // the buffer might have been invalidated by now (resized)
-    struct dnsheader* dh = dq.getHeader();
+    const auto dh = dq.getHeader();
     if (result == ProcessQueryResult::SendAnswer) {
 #ifndef DISABLE_RECVMMSG
 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
index dfaa3d1e58db540a3130f81d10129bac1877cf64..e0e8da324e89948f35a104bd02c1a9b98ee80998 100644 (file)
@@ -89,20 +89,26 @@ struct DNSQuestion
     return data;
   }
 
-  dnsheader* getHeader()
+  bool editHeader(std::function<bool(dnsheader&)> editFunction);
+
+  const dnsheader_aligned getHeader() const
   {
     if (data.size() < sizeof(dnsheader)) {
       throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer");
     }
-    return reinterpret_cast<dnsheader*>(&data.at(0));
+    dnsheader_aligned dh(data.data());
+    return dh;
   }
 
-  const dnsheader* getHeader() const
+  /* this function is not safe against unaligned access, you should
+     use editHeader() instead, but we need it for the Lua bindings */
+  dnsheader* getMutableHeader()
   {
     if (data.size() < sizeof(dnsheader)) {
       throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer");
     }
-    return reinterpret_cast<const dnsheader*>(&data.at(0));
+    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+    return reinterpret_cast<dnsheader*>(data.data());
   }
 
   bool hasRoomFor(size_t more) const
index f9a7bc297bd10a5b96898f8f937fac995987cee2..b0904b295c4792a9164d8d0a1047ce61c77eea04 100644 (file)
@@ -513,6 +513,7 @@ fuzz_target_dnsdistcache_SOURCES = \
        channel.hh channel.cc \
        dns.cc dns.hh \
        dnsdist-cache.cc dnsdist-cache.hh \
+       dnsdist-dnsparser.cc dnsdist-dnsparser.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
        dnsdist-idstate.hh \
        dnsdist-protocols.cc dnsdist-protocols.hh \
index 5bb61ba450918995cdc3526706351e17099ec80e..889249d0c0fb03a0a31c76dec797336c47db99a7 100644 (file)
@@ -52,7 +52,7 @@ struct DesignatedResolvers
 static bool parseSVCParams(const PacketBuffer& answer, std::map<uint16_t, DesignatedResolvers>& resolvers)
 {
   std::map<DNSName, std::vector<ComboAddress>> hints;
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(answer.data());
+  const dnsheader_aligned dh(answer.data());
   PacketReader pr(std::string_view(reinterpret_cast<const char*>(answer.data()), answer.size()));
   uint16_t qdcount = ntohs(dh->qdcount);
   uint16_t ancount = ntohs(dh->ancount);
index 48ce507da8e7331fe56e48aab3fb0a90ff7bd97d..70f0ff2ab015339a35734956f901be5fc6556757 100644 (file)
@@ -129,7 +129,7 @@ int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq)
 
 void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq)
 {
-  return dq->dq->getHeader();
+  return dq->dq->getMutableHeader();
 }
 
 uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq)
@@ -458,14 +458,20 @@ void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, ui
 #ifdef HAVE_DNS_OVER_HTTPS
   PacketBuffer bodyVect(body, body + bodyLen);
   dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
-  dq->dq->getHeader()->qr = true;
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
+    header.qr = true;
+    return true;
+  });
 #endif
 }
 
 void dnsdist_ffi_dnsquestion_set_rcode(dnsdist_ffi_dnsquestion_t* dq, int rcode)
 {
-  dq->dq->getHeader()->rcode = rcode;
-  dq->dq->getHeader()->qr = true;
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [rcode](dnsheader& header) {
+    header.rcode = rcode;
+    header.qr = true;
+    return true;
+  });
 }
 
 void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len)
@@ -950,11 +956,15 @@ bool dnsdist_ffi_set_answer_from_async(uint16_t asyncID, uint16_t queryID, const
     return false;
   }
 
-  auto oldId = reinterpret_cast<const dnsheader*>(query->query.d_buffer.data())->id;
+  dnsheader_aligned alignedHeader(query->query.d_buffer.data());
+  auto oldID = alignedHeader->id;
   query->query.d_buffer.clear();
   query->query.d_buffer.insert(query->query.d_buffer.begin(), raw, raw + rawSize);
-  reinterpret_cast<dnsheader*>(query->query.d_buffer.data())->id = oldId;
 
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(query->query.d_buffer, [oldID](dnsheader& header) {
+    header.id = oldID;
+    return true;
+  });
   query->query.d_idstate.skipCache = true;
 
   return dnsdist::queueQueryResumptionEvent(std::move(query));
index 35591ffd3e64eb65caccc19fba233916d46c13ec..ac1fe458ccb0473a3fb7f35b7da737b07ab8e69b 100644 (file)
@@ -21,6 +21,7 @@
  */
 
 #include "base64.hh"
+#include "dnsdist-dnsparser.hh"
 #include "dnsdist-nghttp2-in.hh"
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsparser.hh"
@@ -197,10 +198,11 @@ void IncomingHTTP2Connection::handleResponse(const struct timeval& now, TCPRespo
     if (responseDH.get()->tc && state.d_packet && state.d_packet->size() > state.d_proxyProtocolPayloadSize && state.d_packet->size() - state.d_proxyProtocolPayloadSize > sizeof(dnsheader)) {
       vinfolog("Response received from backend %s via UDP, for query %d received from %s via DoH, is truncated, retrying over TCP", response.d_ds->getNameWithAddr(), state.d_streamID, state.origRemote.toStringWithPort());
       auto& query = *state.d_packet;
-      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-      auto* queryDH = reinterpret_cast<struct dnsheader*>(&query.at(state.d_proxyProtocolPayloadSize));
-      /* restoring the original ID */
-      queryDH->id = state.origID;
+      dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&query.at(state.d_proxyProtocolPayloadSize), [origID = state.origID](dnsheader& header) {
+        /* restoring the original ID */
+        header.id = origID;
+        return true;
+      });
 
       state.forwardedOverUDP = false;
       bool proxyProtocolPayloadAdded = state.d_proxyProtocolPayloadSize > 0;
index 3bc9aeb40cc990a6fcfad2f2e4d30de6cb0e874b..26c48ba90194dccd33bef8471c89fd1cd8271381 100644 (file)
@@ -49,7 +49,7 @@ static std::string getFirstTXTAnswer(const std::string& answer)
     throw std::runtime_error("Looking for a TXT record in an answer smaller than the DNS header");
   }
 
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(answer.data());
+  const dnsheader_aligned dh(answer.data());
   PacketReader pr(answer);
   uint16_t qdcount = ntohs(dh->qdcount);
   uint16_t ancount = ntohs(dh->ancount);
index aef6cf6ec38af31cf0a95b59379d8cd00a994abc..53c7bec2623416674bd4714a28b6934842068c14 100644 (file)
@@ -126,7 +126,8 @@ struct TCPResponse : public TCPQuery
     TCPQuery(std::move(buffer), std::move(state)), d_connection(std::move(conn)), d_ds(std::move(ds))
   {
     if (d_buffer.size() >= sizeof(dnsheader)) {
-      memcpy(&d_cleartextDH, reinterpret_cast<const dnsheader*>(d_buffer.data()), sizeof(d_cleartextDH));
+      dnsheader_aligned header(d_buffer.data());
+      memcpy(&d_cleartextDH, header.get(), sizeof(d_cleartextDH));
     }
     else {
       memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
@@ -137,7 +138,8 @@ struct TCPResponse : public TCPQuery
     TCPQuery(std::move(query))
   {
     if (d_buffer.size() >= sizeof(dnsheader)) {
-      memcpy(&d_cleartextDH, reinterpret_cast<const dnsheader*>(d_buffer.data()), sizeof(d_cleartextDH));
+      dnsheader_aligned header(d_buffer.data());
+      memcpy(&d_cleartextDH, header.get(), sizeof(d_cleartextDH));
     }
     else {
       memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
index 94782c812192c927b309229039a738620758e91e..a3ccd8016023b8b865e4da15b9acb8f873e621aa 100644 (file)
@@ -26,6 +26,7 @@
 #include "dns.hh"
 #include "dolog.hh"
 #include "dnsdist-concurrent-connections.hh"
+#include "dnsdist-dnsparser.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-metrics.hh"
 #include "dnsdist-proxy-protocol.hh"
@@ -499,7 +500,7 @@ public:
     DNSResponse dr(dohUnit->ids, dohUnit->response, dohUnit->downstream);
 
     dnsheader cleartextDH{};
-    memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
+    memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH));
 
     if (!response.isAsync()) {
       static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
@@ -716,17 +717,20 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
     {
       /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-      auto* dnsHeader = reinterpret_cast<struct dnsheader*>(unit->query.data());
+      const dnsheader_aligned dnsHeader(unit->query.data());
 
-      if (!checkQueryHeaders(dnsHeader, clientState)) {
+      if (!checkQueryHeaders(dnsHeader.get(), clientState)) {
         unit->status_code = 400;
         handleImmediateResponse(std::move(unit), "DoH invalid headers");
         return;
       }
 
       if (dnsHeader->qdcount == 0U) {
-        dnsHeader->rcode = RCode::NotImp;
-        dnsHeader->qr = true;
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
+          header.rcode = RCode::NotImp;
+          header.qr = true;
+          return true;
+        });
         unit->response = std::move(unit->query);
 
         handleImmediateResponse(std::move(unit), "DoH empty query");
@@ -751,7 +755,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false)
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), static_cast<int>(sizeof(dnsheader)), false, &ids.qtype, &ids.qclass);
     DNSQuestion dnsQuestion(ids, unit->query);
-    const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader());
+    const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader().get());
     ids.origFlags = *flags;
     ids.cs = &clientState;
     dnsQuestion.sni = std::move(unit->sni);
@@ -1322,9 +1326,10 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
         dohUnit->query.size() > dohUnit->ids.d_proxyProtocolPayloadSize &&
         (dohUnit->query.size() - dohUnit->ids.d_proxyProtocolPayloadSize) > sizeof(dnsheader)) {
       /* restoring the original ID */
-      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
-      auto* queryDH = reinterpret_cast<struct dnsheader*>(&dohUnit->query.at(dohUnit->ids.d_proxyProtocolPayloadSize));
-      queryDH->id = dohUnit->ids.origID;
+      dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&dohUnit->query.at(dohUnit->ids.d_proxyProtocolPayloadSize), [oldID=dohUnit->ids.origID](dnsheader& header) {
+        header.id = oldID;
+        return true;
+      });
       dohUnit->ids.forwardedOverUDP = false;
       dohUnit->tcp = true;
       dohUnit->truncated = false;
@@ -1645,7 +1650,7 @@ void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, InternalQueryState&&
 
     DNSResponse dnsResponse(dohUnit->ids, udpResponse, dohUnit->downstream);
     dnsheader cleartextDH{};
-    memcpy(&cleartextDH, dnsResponse.getHeader(), sizeof(cleartextDH));
+    memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH));
 
     dnsResponse.ids.du = std::move(dohUnit);
     if (!processResponse(udpResponse, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) {
index 9489120191bda6554456b15065cfb423380934ed..e2fc597d132ec43e8bc2dc5bcbde422fce10a2bf 100644 (file)
@@ -140,7 +140,7 @@ public:
     DNSResponse dnsResponse(unit->ids, unit->response, unit->downstream);
 
     dnsheader cleartextDH{};
-    memcpy(&cleartextDH, dnsResponse.getHeader(), sizeof(cleartextDH));
+    memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH));
 
     if (!response.isAsync()) {
 
index 81897a340d1ca88a152f5e45442af6c89d2e8282..b886b1fd49b3c4b82ec9c710076d3e556c864215 100644 (file)
@@ -466,7 +466,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCache)
   ids.queryRealTime.start();
   DNSQuestion dq(ids, query);
   packetCache->get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
-  packetCache->insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
+  packetCache->insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
 
   std::string poolName("test-pool");
   auto testPool = std::make_shared<ServerPool>();
index b971e4ac15173018fe92fd70808cf5c24fe4ea76..bd3f3b75fad22d7cab7c24da668d8ff56e2501b3 100644 (file)
@@ -251,7 +251,7 @@ private:
 
       auto& query = conn->d_queries.at(frame->hd.stream_id);
       BOOST_REQUIRE_GT(query.size(), sizeof(dnsheader));
-      auto dh = reinterpret_cast<const dnsheader*>(query.data());
+      const dnsheader_aligned dh(query.data());
       uint16_t id = ntohs(dh->id);
       // cerr<<"got query ID "<<id<<endl;
 
@@ -500,7 +500,7 @@ public:
     }
 
     BOOST_REQUIRE_GT(response.d_buffer.size(), sizeof(dnsheader));
-    auto dh = reinterpret_cast<const dnsheader*>(response.d_buffer.data());
+    const dnsheader_aligned dh(response.d_buffer.data());
     uint16_t id = ntohs(dh->id);
 
     BOOST_REQUIRE_EQUAL(id, d_id);
index 7b83ad55fc1b4a6d4cd85fa62c78b99bbc36c27b..e799c0d34ac0e5ff71f32f821cf0a3b8a66d0e1c 100644 (file)
@@ -768,7 +768,7 @@ static bool checkIfPacketContainsRecords(const PacketBuffer& packet, const std::
   }
 
   try {
-    auto dh = reinterpret_cast<const dnsheader*>(packet.data());
+    const dnsheader_aligned dh(packet.data());
     DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(packet.data())), length);
 
     const uint16_t qdcount = ntohs(dh->qdcount);
@@ -804,7 +804,7 @@ static int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, Pa
     return EINVAL;
   }
   try {
-    const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+    const dnsheader_aligned dh(initialPacket.data());
 
     if (ntohs(dh->qdcount) == 0)
       return ENOENT;
@@ -979,7 +979,7 @@ uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA
   }
   try
   {
-    const dnsheader* dh = (const dnsheader*) packet;
+    const dnsheader_aligned dh(packet);
     DNSPacketMangler dpm(const_cast<char*>(packet), length);
 
     const uint16_t qdcount = ntohs(dh->qdcount);
@@ -1026,7 +1026,7 @@ uint32_t getDNSPacketLength(const char* packet, size_t length)
   }
   try
   {
-    const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
+    const dnsheader_aligned dh(packet);
     DNSPacketMangler dpm(const_cast<char*>(packet), length);
 
     const uint16_t qdcount = ntohs(dh->qdcount);
@@ -1058,7 +1058,7 @@ uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t sectio
   }
   try
   {
-    const dnsheader* dh = (const dnsheader*) packet;
+    const dnsheader_aligned dh(packet);
     DNSPacketMangler dpm(const_cast<char*>(packet), length);
 
     const uint16_t qdcount = ntohs(dh->qdcount);
@@ -1148,7 +1148,7 @@ bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payl
 
   try
   {
-    const dnsheader* dh = (const dnsheader*) packet;
+    const dnsheader_aligned dh(packet);
     DNSPacketMangler dpm(const_cast<char*>(packet), length);
 
     const uint16_t qdcount = ntohs(dh->qdcount);
@@ -1191,13 +1191,12 @@ bool visitDNSPacket(const std::string_view& packet, const std::function<bool(uin
 
   try
   {
-    dnsheader dh;
-    memcpy(&dh, reinterpret_cast<const dnsheader*>(packet.data()), sizeof(dh));
-    uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
+    const dnsheader_aligned dh(packet.data());
+    uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
     PacketReader reader(packet);
 
     uint64_t n;
-    for (n = 0; n < ntohs(dh.qdcount) ; ++n) {
+    for (n = 0; n < ntohs(dh->qdcount) ; ++n) {
       (void) reader.getName();
       /* type and class */
       reader.skip(4);
@@ -1206,7 +1205,7 @@ bool visitDNSPacket(const std::string_view& packet, const std::function<bool(uin
     for (n = 0; n < numrecords; ++n) {
       (void) reader.getName();
 
-      uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
+      uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
       uint16_t dnstype = reader.get16BitInt();
       uint16_t dnsclass = reader.get16BitInt();
 
index 212c3b5f8fb88d63f03b985da4867c7a0720e097..909f08e0b791a69d08025c3d6f1dfb78b77ef868 100644 (file)
@@ -74,7 +74,7 @@ DnstapMessage::DnstapMessage(std::string& buffer, DnstapMessage::MessageType typ
   }
 
   if (packet != nullptr && len >= sizeof(dnsheader)) {
-    const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+    const dnsheader_aligned dh(packet);
     if (!dh->qr) {
       pbf_message.add_bytes(DnstapMessageFields::query_message, packet, len);
     } else {
index 6f6fcf35032332c04ebb1661aebce37060426f15..e8c1dceae7d6c131ea1567b19cd8f32628a1d3ba 100644 (file)
@@ -85,7 +85,7 @@ void pdns::ProtoZero::Message::addRRsFromPacket(const char* packet, const size_t
     return;
   }
 
-  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+  const dnsheader_aligned dh(packet);
 
   if (ntohs(dh->ancount) == 0) {
     return;
index 4bad3d0dbcf6ea46b04dc5d33e07f89ebc00a21b..63a86fef7694c1925e2621835e1a5e7791e5d071 100644 (file)
@@ -57,7 +57,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) {
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
 
-      PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
+      PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
 
       found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
       if (found == true) {
@@ -168,7 +168,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) {
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
 
-      PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::AAAA, QClass::IN, response, receivedOverUDP, 0, boost::none);
+      PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::AAAA, QClass::IN, response, receivedOverUDP, 0, boost::none);
 
       found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
       if (found == true) {
@@ -265,7 +265,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) {
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
 
-      PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none);
+      PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none);
       found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
       BOOST_CHECK_EQUAL(found, true);
       BOOST_CHECK(!subnet);
@@ -281,7 +281,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) {
       BOOST_CHECK_EQUAL(found, false);
       BOOST_CHECK(!subnet);
 
-      PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, !receivedOverUDP, RCode::NoError, boost::none);
+      PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, !receivedOverUDP, RCode::NoError, boost::none);
       found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, !receivedOverUDP, 0, true);
       BOOST_CHECK_EQUAL(found, true);
       BOOST_CHECK(!subnet);
@@ -328,13 +328,13 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) {
     BOOST_CHECK(!subnet);
 
     // Insert with failure-TTL of 0 (-> should not enter cache).
-    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional<uint32_t>(0));
+    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional<uint32_t>(0));
     found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
 
     // Insert with failure-TTL non-zero (-> should enter cache).
-    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional<uint32_t>(300));
+    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional<uint32_t>(300));
     found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
     BOOST_CHECK_EQUAL(found, true);
     BOOST_CHECK(!subnet);
@@ -383,7 +383,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) {
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
 
-    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none);
+    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none);
     found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
     BOOST_CHECK_EQUAL(found, true);
     BOOST_CHECK(!subnet);
@@ -438,7 +438,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) {
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
 
-    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none);
+    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none);
     found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
     BOOST_CHECK_EQUAL(found, true);
     BOOST_CHECK(!subnet);
@@ -492,7 +492,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTruncated) {
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
 
-    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none);
+    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none);
 
     bool allowTruncated = true;
     found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true, allowTruncated);
@@ -542,7 +542,7 @@ static void threadMangler(unsigned int offset)
       DNSQuestion dq(ids, query);
       g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
 
-      g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
+      g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
     }
   }
   catch(PDNSException& e) {
@@ -1074,7 +1074,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheXFR) {
     BOOST_CHECK_EQUAL(found, false);
     BOOST_CHECK(!subnet);
 
-    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, ids.qtype, ids.qclass, response, receivedOverUDP, 0, boost::none);
+    PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, ids.qtype, ids.qclass, response, receivedOverUDP, 0, boost::none);
     found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
     BOOST_CHECK_EQUAL(found, false);
   }