]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Cache based on the DNS flags of the query after applying the rules 10664/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 24 Aug 2021 09:23:54 +0000 (11:23 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 24 Aug 2021 09:32:03 +0000 (11:32 +0200)
The tentative fix in dbadb4d272a3317407e6bc934f55c2d41a87c0ac actually
introduced an issue, because the backend might not perfectly echo the
RD and CD flags as they were in the query.
We can't use the "original" (before applying rules) flags either, so
we need to store the flags as they were sent to the backend to be
able to correctly store them in the cache.

pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-idstate.cc
regression-tests.dnsdist/test_Caching.py

index 41af3461f32242a99d86e3a7c7bfd5984b5f7a5b..cf96c599ad06a313bbe803d4df735c8b118babf1 100644 (file)
@@ -292,11 +292,6 @@ static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
   *flags |= origFlags;
 }
 
-static uint16_t getRDAndCDFlagsFromDNSHeader(const struct dnsheader* dh)
-{
-  return static_cast<uint16_t>(dh->rd) << FLAGS_RD_OFFSET | static_cast<uint16_t>(dh->cd) << FLAGS_CD_OFFSET;
-}
-
 static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
 {
   restoreFlags(dq.getHeader(), origFlags);
@@ -450,9 +445,6 @@ bool processResponse(PacketBuffer& response, LocalStateHolder<vector<DNSDistResp
     return false;
   }
 
-  /* We need to get the flags for the packet cache before restoring the original ones, otherwise it might not match later queries */
-  const auto cacheFlags = getRDAndCDFlagsFromDNSHeader(dr.getHeader());
-
   bool zeroScope = false;
   if (!fixUpResponse(response, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, dr.useZeroScope ? &zeroScope : nullptr)) {
     return false;
@@ -471,7 +463,7 @@ bool processResponse(PacketBuffer& response, LocalStateHolder<vector<DNSDistResp
       zeroScope = false;
     }
     // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-    dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, cacheFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, response, receivedOverUDP, dr.getHeader()->rcode, dr.tempFailureTTL);
+    dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.cacheFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, response, receivedOverUDP, dr.getHeader()->rcode, dr.tempFailureTTL);
   }
 
 #ifdef HAVE_DNSCRYPT
@@ -1273,6 +1265,9 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
       return ProcessQueryResult::Drop;
     }
 
+    /* save the DNS flags as sent to the backend so we can cache the answer with the right flags later */
+    dq.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader());
+
     if (dq.addXPF && selectedBackend->xpfRRCode != 0) {
       addXPF(dq, selectedBackend->xpfRRCode);
     }
index 30ef1ad0ce2292ff0224618b9e5de5844bdf018a..801f35f55f6c4381b05b53ed919310dc2ead4c24 100644 (file)
@@ -161,6 +161,7 @@ public:
   const uint16_t qclass;
   uint16_t ecsPrefixLength;
   uint16_t origFlags;
+  uint16_t cacheFlags; /* DNS flags as sent to the backend */
   const Protocol protocol;
   uint8_t ednsRCode{0};
   bool skipCache{false};
@@ -582,7 +583,7 @@ struct IDState
 {
   IDState(): sentTime(true), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;}
   IDState(const IDState& orig) = delete;
-  IDState(IDState&& rhs): subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), origFD(rhs.origFD), delayMsec(rhs.delayMsec), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope)
+  IDState(IDState&& rhs): subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), origFD(rhs.origFD), delayMsec(rhs.delayMsec), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope)
   {
     if (rhs.isInUse()) {
       throw std::runtime_error("Trying to move an in-use IDState");
@@ -632,6 +633,7 @@ struct IDState
     qclass = rhs.qclass;
     origID = rhs.origID;
     origFlags = rhs.origFlags;
+    cacheFlags = rhs.cacheFlags;
     protocol = rhs.protocol;
     uniqueId = std::move(rhs.uniqueId);
     ednsAdded = rhs.ednsAdded;
@@ -738,6 +740,7 @@ struct IDState
   uint16_t qclass{0};                                         // 2
   uint16_t origID{0};                                         // 2
   uint16_t origFlags{0};                                      // 2
+  uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2
   DNSQuestion::Protocol protocol;                             // 1
   boost::optional<boost::uuids::uuid> uniqueId{boost::none};  // 17 (placed here to reduce the space lost to padding)
   bool ednsAdded{false};
index a88ab219d954eb435b5e55bc7b3c4ce2441f5b3c..f34e79de587293badb7877903d0b690362188605 100644 (file)
@@ -5,6 +5,7 @@ DNSResponse makeDNSResponseFromIDState(IDState& ids, PacketBuffer& data)
 {
   DNSResponse dr(&ids.qname, ids.qtype, ids.qclass, &ids.origDest, &ids.origRemote, data, ids.protocol, &ids.sentTime.d_start);
   dr.origFlags = ids.origFlags;
+  dr.cacheFlags = ids.cacheFlags;
   dr.ecsAdded = ids.ecsAdded;
   dr.ednsAdded = ids.ednsAdded;
   dr.useZeroScope = ids.useZeroScope;
@@ -41,6 +42,7 @@ void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname)
   ids.delayMsec = dq.delayMsec;
   ids.tempFailureTTL = dq.tempFailureTTL;
   ids.origFlags = dq.origFlags;
+  ids.cacheFlags = dq.cacheFlags;
   ids.cacheKey = dq.cacheKey;
   ids.cacheKeyNoECS = dq.cacheKeyNoECS;
   ids.subnet = dq.subnet;
index f5a6d4683813a8f81fd8df989c73be200104a404..ac4061afe8a1ca08850d591375d84dc90f0eb2e4 100644 (file)
@@ -2567,18 +2567,19 @@ class TestCachingAlteredHeader(DNSDistTest):
                                     '192.0.2.1')
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEqual(expectedQuery, receivedQuery)
-        self.assertEqual(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEqual(expectedQuery, receivedQuery)
+            self.assertEqual(receivedResponse, expectedResponse)
 
         # next query should hit the cache
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertFalse(receivedQuery)
-        self.assertTrue(receivedResponse)
-        self.assertEqual(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
 
         # same query with RD=0, should hit the cache as well
         query.flags &= ~dns.flags.RD
@@ -2589,7 +2590,78 @@ class TestCachingAlteredHeader(DNSDistTest):
                                     dns.rdatatype.A,
                                     '192.0.2.1')
         expectedResponse.answer.append(rrset)
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertFalse(receivedQuery)
-        self.assertTrue(receivedResponse)
-        self.assertEqual(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)
+
+class TestCachingBackendSettingRD(DNSDistTest):
+
+    _config_template = """
+    pc = newPacketCache(100)
+    getPool(""):setCache(pc)
+    newServer{address="127.0.0.1:%d"}
+    """
+
+    def testCachingBackendSetRD(self):
+        """
+        Cache: The backend sets RD=1 in the response even if the query had RD=0
+        """
+        name = 'backend-sets-rd.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
+        expectedQuery = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        response.flags |= dns.flags.RD
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags &= ~dns.flags.RD
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEqual(expectedQuery, receivedQuery)
+            self.assertEqual(receivedResponse, expectedResponse)
+
+        # exact same query should be cached
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)
+
+        # same query with RD=1, should NOT hit the cache
+        query.flags |= dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)