]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add EDNS to responses generated from raw record data 14728/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 30 Sep 2024 10:01:27 +0000 (12:01 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 30 Sep 2024 10:06:43 +0000 (12:06 +0200)
My reasoning is that it makes sense to add EDNS to responses generated
from DNSdist provided that:
- the initial query had EDNS
- `setAddEDNSToSelfGeneratedResponses` has not been set to `false`
- we are only provided part of the response and not a full response
  packet

pdns/dnsdistdist/dnsdist-lua-actions.cc
regression-tests.dnsdist/test_Spoofing.py

index 5d6be53d5234548463b582f03498efbe57bce65c..e4f730d77cabddc9547c0c87a0e85f2bcae452bb 100644 (file)
@@ -936,7 +936,6 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string*
   static_assert(recordstart.size() == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid");
   memcpy(&recordstart[4], &qclass, sizeof(qclass));
   memcpy(&recordstart[6], &ttl, sizeof(ttl));
-  bool raw = false;
 
   if (qtype == QType::CNAME) {
     const auto& wireData = d_cname.getStorage(); // Note! This doesn't do compression!
@@ -977,7 +976,6 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string*
         return true;
       });
     }
-    raw = true;
   }
   else {
     for (const auto& addr : addrs) {
@@ -1009,7 +1007,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string*
     return true;
   });
 
-  if (hadEDNS && !raw) {
+  if (hadEDNS) {
     addEDNS(dnsquestion->getMutableData(), dnsquestion->getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, 0);
   }
 
index fafc94e1007ad8a2fb78884734330ace0afc7981..7234e941128f05b6044efb3d26c23adcd17a4348 100644 (file)
@@ -47,7 +47,33 @@ class TestSpoofingSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+
+    def testSpoofActionAWithEDNS(self):
+        """
+        Spoofing: Spoof A via Action (EDNS)
+
+        Send an A query to "spoofaction.spoofing.tests.powerdns.com.",
+        check that dnsdist sends a spoofed result.
+        """
+        name = 'spoofaction.spoofing.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True)
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.use_edns(edns=True, payload=1232)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
 
     def testSpoofActionAAAA(self):
         """
@@ -101,7 +127,7 @@ class TestSpoofingSpoof(DNSDistTest):
 
     def testSpoofActionMultiA(self):
         """
-        Spoofing: Spoof multiple IPv4 addresses via AddDomainSpoof
+        Spoofing: Spoof multiple IPv4 addresses
 
         Send an A query for "multispoof.spoofing.tests.powerdns.com.",
         check that dnsdist sends a spoofed result.
@@ -126,7 +152,7 @@ class TestSpoofingSpoof(DNSDistTest):
 
     def testSpoofActionMultiAAAA(self):
         """
-        Spoofing: Spoof multiple IPv6 addresses via AddDomainSpoof
+        Spoofing: Spoof multiple IPv6 addresses
 
         Send an AAAA query for "multispoof.spoofing.tests.powerdns.com.",
         check that dnsdist sends a spoofed result.
@@ -151,7 +177,7 @@ class TestSpoofingSpoof(DNSDistTest):
 
     def testSpoofActionMultiANY(self):
         """
-        Spoofing: Spoof multiple addresses via AddDomainSpoof
+        Spoofing: Spoof multiple addresses
 
         Send an ANY query for "multispoof.spoofing.tests.powerdns.com.",
         check that dnsdist sends a spoofed result.
@@ -320,7 +346,27 @@ class TestSpoofingSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
+            self.assertEqual(receivedResponse.answer[0].ttl, 60)
+
+        # A with EDNS
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=True)
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.use_edns(edns=True, payload=1232)
+        expectedResponse.flags &= ~dns.flags.AA
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
         # TXT
@@ -339,7 +385,7 @@ class TestSpoofingSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
         # SRV
@@ -359,7 +405,7 @@ class TestSpoofingSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 3600)
 
     def testSpoofRawChaosAction(self):
@@ -384,7 +430,7 @@ class TestSpoofingSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
     def testSpoofRawANYAction(self):
@@ -408,7 +454,7 @@ class TestSpoofingSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
     def testSpoofRawActionMulti(self):
@@ -433,7 +479,7 @@ class TestSpoofingSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
         # TXT
@@ -452,7 +498,7 @@ class TestSpoofingSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
 class TestSpoofingLuaSpoof(DNSDistTest):
@@ -617,7 +663,7 @@ class TestSpoofingLuaSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
         # TXT
@@ -636,7 +682,7 @@ class TestSpoofingLuaSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
         # SRV
@@ -656,7 +702,7 @@ class TestSpoofingLuaSpoof(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             # sorry, we can't set the TTL from the Lua API right now
             #self.assertEqual(receivedResponse.answer[0].ttl, 3600)
 
@@ -769,7 +815,7 @@ class TestSpoofingLuaSpoofMulti(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
         # TXT
@@ -788,7 +834,7 @@ class TestSpoofingLuaSpoofMulti(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
         # SRV
@@ -808,7 +854,7 @@ class TestSpoofingLuaSpoofMulti(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             # sorry, we can't set the TTL from the Lua API right now
             #self.assertEqual(receivedResponse.answer[0].ttl, 3600)
 
@@ -878,7 +924,7 @@ class TestSpoofingLuaFFISpoofMulti(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
         # TXT
@@ -897,7 +943,7 @@ class TestSpoofingLuaFFISpoofMulti(DNSDistTest):
             sender = getattr(self, method)
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertTrue(receivedResponse)
-            self.assertEqual(expectedResponse, receivedResponse)
+            self.checkMessageNoEDNS(expectedResponse, receivedResponse)
             self.assertEqual(receivedResponse.answer[0].ttl, 60)
 
 class TestSpoofingLuaWithStatistics(DNSDistTest):