]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Cache based on the DNS flags of the query after applying the rules 10696/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 24 Aug 2021 09:23:54 +0000 (11:23 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 7 Sep 2021 14:51:18 +0000 (16:51 +0200)
The tentative fix in dbadb4d272a3317407e6bc934f55c2d41a87c0ac actually
introduced an issue, because the backend might not perfectly echo the
RD and CD flags as they were in the query.
We can't use the "original" (before applying rules) flags either, so
we need to store the flags as they were sent to the backend to be
able to correctly store them in the cache.

(cherry picked from commit 29d9661fe21a2c4e96f94db1735629363dc07b2e)

pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-idstate.cc
regression-tests.dnsdist/test_Caching.py

index a61b0b687aba283907cd7bc4ea20be76fcfae069..48ba2f7817ee0a122de84432a8254041f771ed28 100644 (file)
@@ -292,11 +292,6 @@ static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
   *flags |= origFlags;
 }
 
-static uint16_t getRDAndCDFlagsFromDNSHeader(const struct dnsheader* dh)
-{
-  return static_cast<uint16_t>(dh->rd) << FLAGS_RD_OFFSET | static_cast<uint16_t>(dh->cd) << FLAGS_CD_OFFSET;
-}
-
 static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
 {
   restoreFlags(dq.getHeader(), origFlags);
@@ -448,9 +443,6 @@ bool processResponse(PacketBuffer& response, LocalStateHolder<vector<DNSDistResp
     return false;
   }
 
-  /* We need to get the flags for the packet cache before restoring the original ones, otherwise it might not match later queries */
-  const auto cacheFlags = getRDAndCDFlagsFromDNSHeader(dr.getHeader());
-
   bool zeroScope = false;
   if (!fixUpResponse(response, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, dr.useZeroScope ? &zeroScope : nullptr)) {
     return false;
@@ -469,7 +461,7 @@ bool processResponse(PacketBuffer& response, LocalStateHolder<vector<DNSDistResp
       zeroScope = false;
     }
     // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
-    dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, cacheFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, response, dr.tcp, dr.getHeader()->rcode, dr.tempFailureTTL);
+    dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.cacheFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, response, dr.tcp, dr.getHeader()->rcode, dr.tempFailureTTL);
   }
 
 #ifdef HAVE_DNSCRYPT
@@ -1277,6 +1269,9 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders&
       return ProcessQueryResult::Drop;
     }
 
+    /* save the DNS flags as sent to the backend so we can cache the answer with the right flags later */
+    dq.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader());
+
     if (dq.addXPF && selectedBackend->xpfRRCode != 0) {
       addXPF(dq, selectedBackend->xpfRRCode);
     }
index e2448fab46268b169967e5e3a59fd9a8a96630cc..69d6ec2ae605f5aa93d5ef9b8f2eb304634c1717 100644 (file)
@@ -144,6 +144,7 @@ public:
   const uint16_t qclass;
   uint16_t ecsPrefixLength;
   uint16_t origFlags;
+  uint16_t cacheFlags; /* DNS flags as sent to the backend */
   uint8_t ednsRCode{0};
   const bool tcp;
   bool skipCache{false};
@@ -559,7 +560,7 @@ struct IDState
 {
   IDState(): sentTime(true), delayMsec(0), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;}
   IDState(const IDState& orig) = delete;
-  IDState(IDState&& rhs): origRemote(rhs.origRemote), origDest(rhs.origDest), sentTime(rhs.sentTime), qname(std::move(rhs.qname)), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), subnet(rhs.subnet), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), origFD(rhs.origFD), delayMsec(rhs.delayMsec), tempFailureTTL(rhs.tempFailureTTL), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope)
+  IDState(IDState&& rhs): origRemote(rhs.origRemote), origDest(rhs.origDest), sentTime(rhs.sentTime), qname(std::move(rhs.qname)), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), subnet(rhs.subnet), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), origFD(rhs.origFD), delayMsec(rhs.delayMsec), tempFailureTTL(rhs.tempFailureTTL), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope)
   {
     if (rhs.isInUse()) {
       throw std::runtime_error("Trying to move an in-use IDState");
@@ -599,6 +600,7 @@ struct IDState
     qclass = rhs.qclass;
     origID = rhs.origID;
     origFlags = rhs.origFlags;
+    cacheFlags = rhs.cacheFlags;
     origFD = rhs.origFD;
     delayMsec = rhs.delayMsec;
     tempFailureTTL = rhs.tempFailureTTL;
@@ -710,6 +712,7 @@ struct IDState
   uint16_t qclass{0};                                         // 2
   uint16_t origID{0};                                         // 2
   uint16_t origFlags{0};                                      // 2
+  uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2
   int origFD{-1};
   int delayMsec{0};
   boost::optional<uint32_t> tempFailureTTL;
index ad124179f4b9f8cfb7ab6efeaebf56582fdbc4cc..21b7d0663e16c29a56e249f2b5c122a48903a4df 100644 (file)
@@ -5,6 +5,7 @@ DNSResponse makeDNSResponseFromIDState(IDState& ids, PacketBuffer& data, bool is
 {
   DNSResponse dr(&ids.qname, ids.qtype, ids.qclass, &ids.origDest, &ids.origRemote, data, isTCP, &ids.sentTime.d_start);
   dr.origFlags = ids.origFlags;
+  dr.cacheFlags = ids.cacheFlags;
   dr.ecsAdded = ids.ecsAdded;
   dr.ednsAdded = ids.ednsAdded;
   dr.useZeroScope = ids.useZeroScope;
@@ -40,6 +41,7 @@ void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname)
   ids.delayMsec = dq.delayMsec;
   ids.tempFailureTTL = dq.tempFailureTTL;
   ids.origFlags = dq.origFlags;
+  ids.cacheFlags = dq.cacheFlags;
   ids.cacheKey = dq.cacheKey;
   ids.cacheKeyNoECS = dq.cacheKeyNoECS;
   ids.subnet = dq.subnet;
index 28c517843f254006966a44117c1a09af5c2eeb43..af74c564f9eb4cff0946d3364c81a24bffc604ec 100644 (file)
@@ -2567,18 +2567,19 @@ class TestCachingAlteredHeader(DNSDistTest):
                                     '192.0.2.1')
         expectedResponse.answer.append(rrset)
 
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = expectedQuery.id
-        self.assertEqual(expectedQuery, receivedQuery)
-        self.assertEqual(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEqual(expectedQuery, receivedQuery)
+            self.assertEqual(receivedResponse, expectedResponse)
 
         # next query should hit the cache
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertFalse(receivedQuery)
-        self.assertTrue(receivedResponse)
-        self.assertEqual(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
 
         # same query with RD=0, should hit the cache as well
         query.flags &= ~dns.flags.RD
@@ -2589,7 +2590,78 @@ class TestCachingAlteredHeader(DNSDistTest):
                                     dns.rdatatype.A,
                                     '192.0.2.1')
         expectedResponse.answer.append(rrset)
-        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
-        self.assertFalse(receivedQuery)
-        self.assertTrue(receivedResponse)
-        self.assertEqual(receivedResponse, expectedResponse)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)
+
+class TestCachingBackendSettingRD(DNSDistTest):
+
+    _config_template = """
+    pc = newPacketCache(100)
+    getPool(""):setCache(pc)
+    newServer{address="127.0.0.1:%d"}
+    """
+
+    def testCachingBackendSetRD(self):
+        """
+        Cache: The backend sets RD=1 in the response even if the query had RD=0
+        """
+        name = 'backend-sets-rd.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
+        expectedQuery = dns.message.make_query(name, 'A', 'IN')
+        expectedQuery.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        response.flags |= dns.flags.RD
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.flags &= ~dns.flags.RD
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = expectedQuery.id
+            self.assertEqual(expectedQuery, receivedQuery)
+            self.assertEqual(receivedResponse, expectedResponse)
+
+        # exact same query should be cached
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)
+
+        # same query with RD=1, should NOT hit the cache
+        query.flags |= dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(receivedResponse, expectedResponse)