From: Otto Moerbeek Date: Tue, 15 Jul 2025 11:54:30 +0000 (+0200) Subject: Fix two issues with chaining ECS enabled queries X-Git-Tag: rec-5.4.0-alpha0~18^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=19fdc9dfd9eaa279f2b5fb7a1053281973d350eb;p=thirdparty%2Fpdns.git Fix two issues with chaining ECS enabled queries 1. The main index does not sort on subnet, so we cannot assume any ordering in the birthday compare used for chains. 2. The lookup key is overwritten by the matched key from the waiters, meaning that we cannot use it to pass values. This means we have to recompute the ECS info in the incoming path for each chain member. Signed-off-by: Otto Moerbeek --- diff --git a/pdns/recursordist/pdns_recursor.cc b/pdns/recursordist/pdns_recursor.cc index 36d7526eea..bb34ebfefb 100644 --- a/pdns/recursordist/pdns_recursor.cc +++ b/pdns/recursordist/pdns_recursor.cc @@ -299,7 +299,7 @@ LWResult::Result asendto(const void* data, size_t len, int /* flags */, // Line below detected an issue with the two ways of ordering PacketIDs (birthday and non-birthday) assert(chain.first->key->domain == pident->domain); // NOLINT // don't chain onto existing chained waiter or a chain already processed - if (chain.first->key->fd > -1 && !chain.first->key->closed) { + if (chain.first->key->fd > -1 && !chain.first->key->closed && pident->ecsSubnet == chain.first->key->ecsSubnet) { auto currentChainSize = chain.first->key->authReqChain.size(); *fileDesc = -static_cast(currentChainSize + 1); // value <= -1, gets used in waitEvent / sendEvent later on if (g_maxChainLength > 0 && currentChainSize >= g_maxChainLength) { @@ -341,6 +341,8 @@ LWResult::Result asendto(const void* data, size_t len, int /* flags */, return LWResult::Result::Success; } +static bool checkIncomingECSSource(const PacketBuffer& packet, const Netmask& subnet); + LWResult::Result arecvfrom(PacketBuffer& packet, int /* flags */, const ComboAddress& fromAddr, size_t& len, uint16_t qid, const DNSName& domain, uint16_t qtype, int fileDesc, const std::optional& ecs, const struct timeval& now) { @@ -373,9 +375,8 @@ LWResult::Result arecvfrom(PacketBuffer& packet, int /* flags */, const ComboAdd len = packet.size(); // In ecs hardening mode, we consider a missing or a mismatched ECS in the reply as a case for - // retrying without ECS (matchingECSReceived only gets set if a matching ECS was received). The actual - // logic to do that is in Syncres::doResolveAtThisIP() - if (g_ECSHardening && pident->ecsSubnet && !*pident->matchingECSReceived) { + // retrying without ECS. The actual logic to do that is in Syncres::doResolveAtThisIP() + if (g_ECSHardening && pident->ecsSubnet && !checkIncomingECSSource(packet, *pident->ecsSubnet)) { t_Counters.at(rec::Counter::ecsMissingCount)++; return LWResult::Result::ECSMissing; } @@ -2914,7 +2915,7 @@ void distributeAsyncFunction(const string& packet, const pipefunc_t& func) } // resend event to everybody chained onto it -static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr& resend, const PacketBuffer& content, const std::optional& matchingECSReceived) +static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr& resend, const PacketBuffer& content) { // We close the chain for new entries, since they won't be processed anyway iter->key->closed = true; @@ -2923,11 +2924,6 @@ static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptrkey->matchingECSReceived = matchingECSReceived; - } - auto maxWeight = t_Counters.at(rec::Counter::maxChainWeight); auto weight = iter->key->authReqChain.size() * content.size(); if (weight > maxWeight) { @@ -2999,7 +2995,7 @@ static void handleUDPServerResponse(int fileDesc, FDMultiplexer::funcparam_t& va PacketBuffer empty; auto iter = g_multiTasker->getWaiters().find(pid); if (iter != g_multiTasker->getWaiters().end()) { - doResends(iter, pid, empty, false); + doResends(iter, pid, empty); } g_multiTasker->sendEvent(pid, &empty); // this denotes error (does retry lookup using other NS) return; @@ -3060,10 +3056,7 @@ static void handleUDPServerResponse(int fileDesc, FDMultiplexer::funcparam_t& va if (!pident->domain.empty()) { auto iter = g_multiTasker->getWaiters().find(pident); if (iter != g_multiTasker->getWaiters().end()) { - if (g_ECSHardening) { - iter->key->matchingECSReceived = iter->key->ecsSubnet && checkIncomingECSSource(packet, *iter->key->ecsSubnet); - } - doResends(iter, pident, packet, iter->key->matchingECSReceived); + doResends(iter, pident, packet); } } diff --git a/pdns/recursordist/syncres.hh b/pdns/recursordist/syncres.hh index 32a358ec39..1427776d60 100644 --- a/pdns/recursordist/syncres.hh +++ b/pdns/recursordist/syncres.hh @@ -804,7 +804,6 @@ struct PacketID bool inIncompleteOkay{false}; uint16_t id{0}; // wait for a specific id/remote pair uint16_t type{0}; // and this is its type - std::optional matchingECSReceived; // only set in ecsHardened mode TCPAction highState{TCPAction::DoingRead}; IOState lowState{IOState::NeedRead}; @@ -817,7 +816,7 @@ struct PacketID inline ostream& operator<<(ostream& ostr, const PacketID& pid) { - return ostr << "PacketID(id=" << pid.id << ",remote=" << pid.remote.toString() << ",type=" << pid.type << ",tcpsock=" << pid.tcpsock << ",fd=" << pid.fd << ',' << pid.domain << ')'; + return ostr << "PacketID(id=" << pid.id << ",remote=" << pid.remote.toString() << ",type=" << pid.type << ",tcpsock=" << pid.tcpsock << ",fd=" << pid.fd << ",name=" << pid.domain << ",ecs=" << (pid.ecsSubnet ? pid.ecsSubnet->toString() : "") << ')'; } inline ostream& operator<<(ostream& ostr, const shared_ptr& pid) @@ -849,10 +848,10 @@ struct PacketIDBirthdayCompare { bool operator()(const std::shared_ptr& lhs, const std::shared_ptr& rhs) const { - if (std::tie(lhs->remote, lhs->tcpsock, lhs->type, lhs->ecsSubnet) < std::tie(rhs->remote, rhs->tcpsock, rhs->type, rhs->ecsSubnet)) { + if (std::tie(lhs->remote, lhs->tcpsock, lhs->type) < std::tie(rhs->remote, rhs->tcpsock, rhs->type)) { return true; } - if (std::tie(lhs->remote, lhs->tcpsock, lhs->type, lhs->ecsSubnet) > std::tie(rhs->remote, rhs->tcpsock, rhs->type, rhs->ecsSubnet)) { + if (std::tie(lhs->remote, lhs->tcpsock, lhs->type) > std::tie(rhs->remote, rhs->tcpsock, rhs->type)) { return false; } return lhs->domain < rhs->domain; diff --git a/regression-tests.recursor-dnssec/recursortests.py b/regression-tests.recursor-dnssec/recursortests.py index 0215c2a379..ef2b265eae 100644 --- a/regression-tests.recursor-dnssec/recursortests.py +++ b/regression-tests.recursor-dnssec/recursortests.py @@ -428,9 +428,9 @@ PrivateKey: Ep9uo6+wwjb4MaOmqq7LHav2FLrjotVOeZg8JT1Qk04= 'zones': ['optout.example']}, '15': {'threads': 1, 'zones': ['insecure.optout.example', 'secure.optout.example', 'cname-secure.example']}, - '16': {'threads': 2, + '16': {'threads': 10, 'zones': ['delay1.example']}, - '17': {'threads': 2, + '17': {'threads': 10, 'zones': ['delay2.example']}, '18': {'threads': 1, 'zones': ['example']} diff --git a/regression-tests.recursor-dnssec/test_Chain.py b/regression-tests.recursor-dnssec/test_Chain.py index 107c66ac72..4672df103e 100644 --- a/regression-tests.recursor-dnssec/test_Chain.py +++ b/regression-tests.recursor-dnssec/test_Chain.py @@ -2,6 +2,7 @@ import pytest import dns import os import time +import clientsubnetoption from recursortests import RecursorTest class ChainTest(RecursorTest): @@ -57,3 +58,149 @@ class ChainTest(RecursorTest): 'servfail-answers': 0, 'noerror-answers': count, }) + +class ChainECSTest(RecursorTest): + """ + These regression tests test the chaining of outgoing requests with ECS + """ + _auth_zones = RecursorTest._default_auth_zones + _chainSize = 200 + _confdir = 'ChainECS' + _wsPort = 8042 + _wsTimeout = 2 + _wsPassword = 'secretpassword' + _apiKey = 'secretapikey' + + _config_template = """dnssec=validate + trace=no + edns-subnet-allow-list=0.0.0.0/0 + use-incoming-edns-subnet=yes + edns-subnet-allow-list=0.0.0.0/0 + devonly-regression-test-mode + webserver=yes + webserver-port=%d + webserver-address=127.0.0.1 + webserver-password=%s + api-key=%s + max-concurrent-requests-per-tcp-connection=%s +""" % (_wsPort, _wsPassword, _apiKey, _chainSize) + + def testBasic(self): + """ + Tests the case of #14624. Sending many equal requests could lead to ServFail because of + clashing waiter ids. + """ + count = self._chainSize + name1 = '1.delay1.example.' + name2 = '2.delay1.example.' + exp1 = dns.rrset.from_text(name1, 0, dns.rdataclass.IN, 'TXT', 'a') + exp2 = dns.rrset.from_text(name2, 0, dns.rdataclass.IN, 'TXT', 'a') + queries = [] + for i in range(count): + if i % 3 == 0: + name = name1 + else: + name = name2 + if i % 2 == 0: + ecso = clientsubnetoption.ClientSubnetOption('192.0.2.0', 24) + else: + ecso = clientsubnetoption.ClientSubnetOption('192.0.3.0', 24) + query = dns.message.make_query(name, 'TXT', use_edns=True, options=[ecso], want_dnssec=True) + query.flags |= dns.flags.AD + queries.append(query) + + answers = self.sendTCPQueries(queries) + self.assertEqual(len(answers), count) + + for i in range(count): + res = answers[i] + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertMessageIsAuthenticated(res) + if res.question[0].name.to_text() == name1: + self.assertRRsetInAnswer(res, exp1) + self.assertMatchingRRSIGInAnswer(res, exp1) + elif res.question[0].name.to_text() == name2: + self.assertRRsetInAnswer(res, exp2) + self.assertMatchingRRSIGInAnswer(res, exp2) + else: + print("?? " + res.question[0].name.to_text()) + self.assertEqual(0, 1) + + self.checkMetrics({ + 'servfail-answers': 0, + 'noerror-answers': count, + }) + +class ChainECSHardenedTest(RecursorTest): + """ + These regression tests test the chaining of outgoing requests with ECS + """ + _auth_zones = RecursorTest._default_auth_zones + _chainSize = 200 + _confdir = 'ChainECSHardened' + _wsPort = 8042 + _wsTimeout = 2 + _wsPassword = 'secretpassword' + _apiKey = 'secretapikey' + + _config_template = """dnssec=validate + trace=no + edns-subnet-allow-list=0.0.0.0/0 + use-incoming-edns-subnet=yes + edns-subnet-allow-list=0.0.0.0/0 + edns-subnet-harden=yes + devonly-regression-test-mode + webserver=yes + webserver-port=%d + webserver-address=127.0.0.1 + webserver-password=%s + api-key=%s + max-concurrent-requests-per-tcp-connection=%s +""" % (_wsPort, _wsPassword, _apiKey, _chainSize) + + def testBasic(self): + """ + Tests the case of #14624. Sending many equal requests could lead to ServFail because of + clashing waiter ids. + """ + count = self._chainSize + name1 = '1.delay1.example.' + name2 = '2.delay1.example.' + exp1 = dns.rrset.from_text(name1, 0, dns.rdataclass.IN, 'TXT', 'a') + exp2 = dns.rrset.from_text(name2, 0, dns.rdataclass.IN, 'TXT', 'a') + queries = [] + for i in range(count): + if i % 3 == 0: + name = name1 + else: + name = name2 + if i % 2 == 0: + ecso = clientsubnetoption.ClientSubnetOption('192.0.2.0', 24) + else: + ecso = clientsubnetoption.ClientSubnetOption('192.0.3.0', 24) + query = dns.message.make_query(name, 'TXT', use_edns=True, options=[ecso], want_dnssec=True) + query.flags |= dns.flags.AD + queries.append(query) + + answers = self.sendTCPQueries(queries) + self.assertEqual(len(answers), count) + + for i in range(count): + res = answers[i] + self.assertRcodeEqual(res, dns.rcode.NOERROR) + self.assertMessageIsAuthenticated(res) + if res.question[0].name.to_text() == name1: + self.assertRRsetInAnswer(res, exp1) + self.assertMatchingRRSIGInAnswer(res, exp1) + elif res.question[0].name.to_text() == name2: + self.assertRRsetInAnswer(res, exp2) + self.assertMatchingRRSIGInAnswer(res, exp2) + else: + print("?? " + res.question[0].name.to_text()) + self.assertEqual(0, 1) + + self.checkMetrics({ + 'servfail-answers': 0, + 'noerror-answers': count, + }) +