From cae561a92410c29053e297bc99d6cb5000665245 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Mon, 30 Sep 2024 12:01:27 +0200 Subject: [PATCH] dnsdist: Add EDNS to responses generated from raw record data My reasoning is that it makes sense to add EDNS to responses generated from DNSdist provided that: - the initial query had EDNS - `setAddEDNSToSelfGeneratedResponses` has not been set to `false` - we are only provided part of the response and not a full response packet --- pdns/dnsdistdist/dnsdist-lua-actions.cc | 4 +- regression-tests.dnsdist/test_Spoofing.py | 84 ++++++++++++++++++----- 2 files changed, 66 insertions(+), 22 deletions(-) diff --git a/pdns/dnsdistdist/dnsdist-lua-actions.cc b/pdns/dnsdistdist/dnsdist-lua-actions.cc index 5d6be53d52..e4f730d77c 100644 --- a/pdns/dnsdistdist/dnsdist-lua-actions.cc +++ b/pdns/dnsdistdist/dnsdist-lua-actions.cc @@ -936,7 +936,6 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string* static_assert(recordstart.size() == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid"); memcpy(&recordstart[4], &qclass, sizeof(qclass)); memcpy(&recordstart[6], &ttl, sizeof(ttl)); - bool raw = false; if (qtype == QType::CNAME) { const auto& wireData = d_cname.getStorage(); // Note! This doesn't do compression! @@ -977,7 +976,6 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string* return true; }); } - raw = true; } else { for (const auto& addr : addrs) { @@ -1009,7 +1007,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dnsquestion, std::string* return true; }); - if (hadEDNS && !raw) { + if (hadEDNS) { addEDNS(dnsquestion->getMutableData(), dnsquestion->getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, 0); } diff --git a/regression-tests.dnsdist/test_Spoofing.py b/regression-tests.dnsdist/test_Spoofing.py index fafc94e100..7234e94112 100644 --- a/regression-tests.dnsdist/test_Spoofing.py +++ b/regression-tests.dnsdist/test_Spoofing.py @@ -47,7 +47,33 @@ class TestSpoofingSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) + + def testSpoofActionAWithEDNS(self): + """ + Spoofing: Spoof A via Action (EDNS) + + Send an A query to "spoofaction.spoofing.tests.powerdns.com.", + check that dnsdist sends a spoofed result. + """ + name = 'spoofaction.spoofing.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN', use_edns=True) + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + expectedResponse.use_edns(edns=True, payload=1232) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + expectedResponse.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse) def testSpoofActionAAAA(self): """ @@ -101,7 +127,7 @@ class TestSpoofingSpoof(DNSDistTest): def testSpoofActionMultiA(self): """ - Spoofing: Spoof multiple IPv4 addresses via AddDomainSpoof + Spoofing: Spoof multiple IPv4 addresses Send an A query for "multispoof.spoofing.tests.powerdns.com.", check that dnsdist sends a spoofed result. @@ -126,7 +152,7 @@ class TestSpoofingSpoof(DNSDistTest): def testSpoofActionMultiAAAA(self): """ - Spoofing: Spoof multiple IPv6 addresses via AddDomainSpoof + Spoofing: Spoof multiple IPv6 addresses Send an AAAA query for "multispoof.spoofing.tests.powerdns.com.", check that dnsdist sends a spoofed result. @@ -151,7 +177,7 @@ class TestSpoofingSpoof(DNSDistTest): def testSpoofActionMultiANY(self): """ - Spoofing: Spoof multiple addresses via AddDomainSpoof + Spoofing: Spoof multiple addresses Send an ANY query for "multispoof.spoofing.tests.powerdns.com.", check that dnsdist sends a spoofed result. @@ -320,7 +346,27 @@ class TestSpoofingSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) + self.assertEqual(receivedResponse.answer[0].ttl, 60) + + # A with EDNS + query = dns.message.make_query(name, 'A', 'IN', use_edns=True) + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + expectedResponse.use_edns(edns=True, payload=1232) + expectedResponse.flags &= ~dns.flags.AA + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + expectedResponse.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.checkMessageEDNSWithoutOptions(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) # TXT @@ -339,7 +385,7 @@ class TestSpoofingSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) # SRV @@ -359,7 +405,7 @@ class TestSpoofingSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 3600) def testSpoofRawChaosAction(self): @@ -384,7 +430,7 @@ class TestSpoofingSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) def testSpoofRawANYAction(self): @@ -408,7 +454,7 @@ class TestSpoofingSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) def testSpoofRawActionMulti(self): @@ -433,7 +479,7 @@ class TestSpoofingSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) # TXT @@ -452,7 +498,7 @@ class TestSpoofingSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) class TestSpoofingLuaSpoof(DNSDistTest): @@ -617,7 +663,7 @@ class TestSpoofingLuaSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) # TXT @@ -636,7 +682,7 @@ class TestSpoofingLuaSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) # SRV @@ -656,7 +702,7 @@ class TestSpoofingLuaSpoof(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) # sorry, we can't set the TTL from the Lua API right now #self.assertEqual(receivedResponse.answer[0].ttl, 3600) @@ -769,7 +815,7 @@ class TestSpoofingLuaSpoofMulti(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) # TXT @@ -788,7 +834,7 @@ class TestSpoofingLuaSpoofMulti(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) # SRV @@ -808,7 +854,7 @@ class TestSpoofingLuaSpoofMulti(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) # sorry, we can't set the TTL from the Lua API right now #self.assertEqual(receivedResponse.answer[0].ttl, 3600) @@ -878,7 +924,7 @@ class TestSpoofingLuaFFISpoofMulti(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) # TXT @@ -897,7 +943,7 @@ class TestSpoofingLuaFFISpoofMulti(DNSDistTest): sender = getattr(self, method) (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertTrue(receivedResponse) - self.assertEqual(expectedResponse, receivedResponse) + self.checkMessageNoEDNS(expectedResponse, receivedResponse) self.assertEqual(receivedResponse.answer[0].ttl, 60) class TestSpoofingLuaWithStatistics(DNSDistTest): -- 2.47.2