]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Only parse EDNS Z once 15872/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 17 Jul 2025 15:11:11 +0000 (17:11 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 18 Jul 2025 15:11:27 +0000 (17:11 +0200)
Signed-off-by: Remi Gacogne <remi.gacogne@powerdns.com>
pdns/dnsdistdist/dnsdist-idstate.hh
pdns/dnsdistdist/dnsdist-self-answers.cc
pdns/dnsdistdist/dnsdist.cc

index 86db2dcbc2ee8f389513cf9414f719c753882491..10e24120511f2c39d5949b8991eaedf1477b5e78 100644 (file)
@@ -167,12 +167,12 @@ struct InternalQueryState
   uint16_t origFlags{0}; // 2
   uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2
   uint16_t udpPayloadSize{0}; // Max UDP payload size from the query // 2
+  std::optional<bool> dnssecOK;
   dnsdist::Protocol protocol; // 1
   uint8_t restartCount{0}; // 1
   bool ednsAdded{false};
   bool ecsAdded{false};
   bool skipCache{false};
-  bool dnssecOK{false};
   bool useZeroScope{false};
   bool forwardedOverUDP{false};
   bool selfGenerated{false};
index 2c1f681046ec8a63e12e185c1b80d80aa717eb81..26beb80e3432ce9e88d6adc79fcdde00fc44fe7f 100644 (file)
@@ -52,6 +52,29 @@ static void addRecordHeader(PacketBuffer& packet, size_t& position, uint16_t qcl
   position += recordstart.size();
 }
 
+static std::pair<bool, bool> getEDNSStatusInQuery(DNSQuestion& dnsQuestion)
+{
+  if (!dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses) {
+    return {false, false};
+  }
+
+  if (dnsQuestion.ids.dnssecOK) {
+    if (*dnsQuestion.ids.dnssecOK) {
+      /* DNSSECOK was set, which means the query had EDNS */
+      return {true, true};
+    }
+  }
+
+  if (queryHasEDNS(dnsQuestion)) {
+    bool dnssecOK = ((dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0);
+    dnsQuestion.ids.dnssecOK = dnssecOK;
+    return {true, dnssecOK};
+  }
+
+  dnsQuestion.ids.dnssecOK = false;
+  return {false, false};
+}
+
 bool generateAnswerFromCNAME(DNSQuestion& dnsQuestion, const DNSName& cname, const dnsdist::ResponseConfig& responseConfig)
 {
   QType qtype = QType::CNAME;
@@ -62,13 +85,7 @@ bool generateAnswerFromCNAME(DNSQuestion& dnsQuestion, const DNSName& cname, con
     return false;
   }
 
-  bool dnssecOK = false;
-  bool hadEDNS = false;
-  if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses && queryHasEDNS(dnsQuestion)) {
-    hadEDNS = true;
-    dnssecOK = ((dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0);
-  }
-
+  auto [hadEDNS, dnssecOK] = getEDNSStatusInQuery(dnsQuestion);
   auto& data = dnsQuestion.getMutableData();
   data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + totrdatalen); // there goes your EDNS
   size_t position = sizeof(dnsheader) + qnameWireLength + 4;
@@ -122,13 +139,7 @@ bool generateAnswerFromIPAddresses(DNSQuestion& dnsQuestion, const std::vector<C
     return false;
   }
 
-  bool dnssecOK = false;
-  bool hadEDNS = false;
-  if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses && queryHasEDNS(dnsQuestion)) {
-    hadEDNS = true;
-    dnssecOK = ((dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0);
-  }
-
+  auto [hadEDNS, dnssecOK] = getEDNSStatusInQuery(dnsQuestion);
   auto& data = dnsQuestion.getMutableData();
   data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + totrdatalen); // there goes your EDNS
   size_t position = sizeof(dnsheader) + qnameWireLength + 4;
@@ -187,13 +198,7 @@ bool generateAnswerFromRDataEntries(DNSQuestion& dnsQuestion, const std::vector<
     return false;
   }
 
-  bool dnssecOK = false;
-  bool hadEDNS = false;
-  if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses && queryHasEDNS(dnsQuestion)) {
-    hadEDNS = true;
-    dnssecOK = ((dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0);
-  }
-
+  auto [hadEDNS, dnssecOK] = getEDNSStatusInQuery(dnsQuestion);
   auto& data = dnsQuestion.getMutableData();
   data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords * 12 /* recordstart */ + totrdatalen); // there goes your EDNS
   size_t position = sizeof(dnsheader) + qnameWireLength + 4;
@@ -246,12 +251,7 @@ bool generateAnswerFromRawPacket(DNSQuestion& dnsQuestion, const PacketBuffer& p
 
 bool removeRecordsAndSetRCode(DNSQuestion& dnsQuestion, uint8_t rcode)
 {
-  bool dnssecOK = false;
-  bool hadEDNS = false;
-  if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses && queryHasEDNS(dnsQuestion)) {
-    hadEDNS = true;
-    dnssecOK = ((dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0);
-  }
+  auto [hadEDNS, dnssecOK] = getEDNSStatusInQuery(dnsQuestion);
 
   dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [rcode](dnsheader& header) {
     header.rcode = rcode;
index 75ed52232be8c27049815c90132e620b521e2590..e858d66ca5e9e4ed4186510b0b02119220dc8d0e 100644 (file)
@@ -526,7 +526,7 @@ bool processResponseAfterRules(PacketBuffer& response, DNSResponse& dnsResponse,
       // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
       cacheKey = dnsResponse.ids.cacheKeyNoECS;
     }
-    dnsResponse.ids.packetCache->insert(cacheKey, zeroScope ? boost::none : dnsResponse.ids.subnet, dnsResponse.ids.cacheFlags, dnsResponse.ids.dnssecOK, dnsResponse.ids.qname, dnsResponse.ids.qtype, dnsResponse.ids.qclass, response, dnsResponse.ids.forwardedOverUDP, dnsResponse.getHeader()->rcode, dnsResponse.ids.tempFailureTTL);
+    dnsResponse.ids.packetCache->insert(cacheKey, zeroScope ? boost::none : dnsResponse.ids.subnet, dnsResponse.ids.cacheFlags, dnsResponse.ids.dnssecOK ? *dnsResponse.ids.dnssecOK : false, dnsResponse.ids.qname, dnsResponse.ids.qtype, dnsResponse.ids.qclass, response, dnsResponse.ids.forwardedOverUDP, dnsResponse.getHeader()->rcode, dnsResponse.ids.tempFailureTTL);
 
     const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains;
     const auto& cacheInsertedRespRuleActions = dnsdist::rules::getResponseRuleChain(chains, dnsdist::rules::ResponseRuleChain::CacheInsertedResponseRules);
@@ -1436,7 +1436,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
 
     uint32_t allowExpired = selectedBackend ? 0 : dnsdist::configuration::getCurrentRuntimeConfiguration().d_staleCacheEntriesTTL;
 
-    if (dnsQuestion.ids.packetCache && !dnsQuestion.ids.skipCache) {
+    if (dnsQuestion.ids.packetCache && !dnsQuestion.ids.skipCache && !dnsQuestion.ids.dnssecOK) {
       dnsQuestion.ids.dnssecOK = (dnsdist::getEDNSZ(dnsQuestion) & EDNS_HEADER_FLAG_DO) != 0;
     }
 
@@ -1445,7 +1445,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
       // we need ECS parsing (parseECS) to be true so we can be sure that the initial incoming query did not have an existing
       // ECS option, which would make it unsuitable for the zero-scope feature.
       if (dnsQuestion.ids.packetCache && !dnsQuestion.ids.skipCache && (!selectedBackend || !selectedBackend->d_config.disableZeroScope) && dnsQuestion.ids.packetCache->isECSParsingEnabled()) {
-        if (dnsQuestion.ids.packetCache->get(dnsQuestion, dnsQuestion.getHeader()->id, &dnsQuestion.ids.cacheKeyNoECS, dnsQuestion.ids.subnet, dnsQuestion.ids.dnssecOK, willBeForwardedOverUDP, allowExpired, false, true, false)) {
+        if (dnsQuestion.ids.packetCache->get(dnsQuestion, dnsQuestion.getHeader()->id, &dnsQuestion.ids.cacheKeyNoECS, dnsQuestion.ids.subnet, *dnsQuestion.ids.dnssecOK, willBeForwardedOverUDP, allowExpired, false, true, false)) {
 
           vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dnsQuestion.ids.qname.toLogString(), QType(dnsQuestion.ids.qtype).toString(), dnsQuestion.ids.origRemote.toStringWithPort(), dnsQuestion.ids.protocol.toString(), dnsQuestion.getData().size());
 
@@ -1475,7 +1475,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
          For DoH, this lookup is done with the protocol set to TCP but we will retry over UDP below,
          therefore we do not record a miss for queries received over DoH and forwarded over TCP
          yet, as we will do a second-lookup */
-      if (dnsQuestion.ids.packetCache->get(dnsQuestion, dnsQuestion.getHeader()->id, dnsQuestion.ids.protocol == dnsdist::Protocol::DoH ? &dnsQuestion.ids.cacheKeyTCP : &dnsQuestion.ids.cacheKey, dnsQuestion.ids.subnet, dnsQuestion.ids.dnssecOK, dnsQuestion.ids.protocol != dnsdist::Protocol::DoH && willBeForwardedOverUDP, allowExpired, false, true, dnsQuestion.ids.protocol != dnsdist::Protocol::DoH || !willBeForwardedOverUDP)) {
+      if (dnsQuestion.ids.packetCache->get(dnsQuestion, dnsQuestion.getHeader()->id, dnsQuestion.ids.protocol == dnsdist::Protocol::DoH ? &dnsQuestion.ids.cacheKeyTCP : &dnsQuestion.ids.cacheKey, dnsQuestion.ids.subnet, *dnsQuestion.ids.dnssecOK, dnsQuestion.ids.protocol != dnsdist::Protocol::DoH && willBeForwardedOverUDP, allowExpired, false, true, dnsQuestion.ids.protocol != dnsdist::Protocol::DoH || !willBeForwardedOverUDP)) {
 
         dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [flags = dnsQuestion.ids.origFlags](dnsheader& header) {
           restoreFlags(&header, flags);
@@ -1495,7 +1495,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
       if (dnsQuestion.ids.protocol == dnsdist::Protocol::DoH && willBeForwardedOverUDP) {
         /* do a second-lookup for responses received over UDP, but we do not want TC=1 answers */
         /* we need to be careful to keep the existing cache-key (TCP) */
-        if (dnsQuestion.ids.packetCache->get(dnsQuestion, dnsQuestion.getHeader()->id, &dnsQuestion.ids.cacheKey, dnsQuestion.ids.subnet, dnsQuestion.ids.dnssecOK, true, allowExpired, false, false, true)) {
+        if (dnsQuestion.ids.packetCache->get(dnsQuestion, dnsQuestion.getHeader()->id, &dnsQuestion.ids.cacheKey, dnsQuestion.ids.subnet, *dnsQuestion.ids.dnssecOK, true, allowExpired, false, false, true)) {
           if (!prepareOutgoingResponse(*dnsQuestion.ids.cs, dnsQuestion, true)) {
             return ProcessQueryResult::Drop;
           }
@@ -1658,6 +1658,9 @@ public:
       uint16_t zValue = 0;
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
       getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(buffer.data()), buffer.size(), &ids.udpPayloadSize, &zValue);
+      if (!ids.dnssecOK) {
+        ids.dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;
+      }
       if (ids.udpPayloadSize < 512) {
         ids.udpPayloadSize = 512;
       }
@@ -1871,6 +1874,9 @@ static void processUDPQuery(ClientState& clientState, const struct msghdr* msgh,
     uint16_t zValue = 0;
     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
     getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(query.data()), query.size(), &udpPayloadSize, &zValue);
+    if (!ids.dnssecOK) {
+      ids.dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;
+    }
     if (udpPayloadSize < 512) {
       udpPayloadSize = 512;
     }