]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Backport of Fix two issues with chaining ECS enabled queries
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 15 Jul 2025 11:54:30 +0000 (13:54 +0200)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 15 Jul 2025 14:38:32 +0000 (16:38 +0200)
Signed-off-by: Otto Moerbeek <otto.moerbeek@open-xchange.com>
pdns/ednsoptions.cc
pdns/recursordist/pdns_recursor.cc
pdns/recursordist/syncres.hh
regression-tests.recursor-dnssec/recursortests.py
regression-tests.recursor-dnssec/test_Chain.py

index 2d0377b8a7b3c945a3edd9f5a6183cfdbdb974f3..9398ecd0f568332eea48584b7dfd64893b202763 100644 (file)
@@ -107,7 +107,7 @@ bool slowParseEDNSOptions(const PacketBuffer& packet, EDNSOptionViewMap& options
   }
 
   if (ntohs(dnsHeader->arcount) == 0) {
-    throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
+    return false;
   }
 
   try {
index 0f3c795c6ed8f931a7870c7f8fba91c034859ccd..eb887cf7bfe8a11bd8b581da74e5338042946db8 100644 (file)
@@ -302,7 +302,7 @@ LWResult::Result asendto(const void* data, size_t len, int /* flags */,
     // Line below detected an issue with the two ways of ordering PacketIDs (birthday and non-birthday)
     assert(chain.first->key->domain == pident->domain); // NOLINT
     // don't chain onto existing chained waiter or a chain already processed
-    if (chain.first->key->fd > -1 && !chain.first->key->closed) {
+    if (chain.first->key->fd > -1 && !chain.first->key->closed && pident->ecsSubnet == chain.first->key->ecsSubnet) {
       auto currentChainSize = chain.first->key->authReqChain.size();
       *fileDesc = -static_cast<int>(currentChainSize + 1); // value <= -1, gets used in waitEvent / sendEvent later on
       if (g_maxChainLength > 0 && currentChainSize >= g_maxChainLength) {
@@ -344,6 +344,8 @@ LWResult::Result asendto(const void* data, size_t len, int /* flags */,
   return LWResult::Result::Success;
 }
 
+static bool checkIncomingECSSource(const PacketBuffer& packet, const Netmask& subnet);
+
 LWResult::Result arecvfrom(PacketBuffer& packet, int /* flags */, const ComboAddress& fromAddr, size_t& len,
                            uint16_t qid, const DNSName& domain, uint16_t qtype, int fileDesc, const std::optional<EDNSSubnetOpts>& ecs, const struct timeval& now)
 {
@@ -376,9 +378,8 @@ LWResult::Result arecvfrom(PacketBuffer& packet, int /* flags */, const ComboAdd
     len = packet.size();
 
     // In ecs hardening mode, we consider a missing or a mismatched ECS in the reply as a case for
-    // retrying without ECS (matchingECSReceived only gets set if a matching ECS was received). The actual
-    // logic to do that is in Syncres::doResolveAtThisIP()
-    if (g_ECSHardening && pident->ecsSubnet && !*pident->matchingECSReceived) {
+    // retrying without ECS. The actual logic to do that is in Syncres::doResolveAtThisIP()
+    if (g_ECSHardening && pident->ecsSubnet && !checkIncomingECSSource(packet, *pident->ecsSubnet)) {
       t_Counters.at(rec::Counter::ecsMissingCount)++;
       return LWResult::Result::ECSMissing;
     }
@@ -2890,7 +2891,7 @@ void distributeAsyncFunction(const string& packet, const pipefunc_t& func)
 }
 
 // resend event to everybody chained onto it
-static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr<PacketID>& resend, const PacketBuffer& content, const std::optional<bool>& matchingECSReceived)
+static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr<PacketID>& resend, const PacketBuffer& content)
 {
   // We close the chain for new entries, since they won't be processed anyway
   iter->key->closed = true;
@@ -2899,11 +2900,6 @@ static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr<Pac
     return;
   }
 
-  // Only set if g_ECSHardening
-  if (matchingECSReceived) {
-    iter->key->matchingECSReceived = matchingECSReceived;
-  }
-
   auto maxWeight = t_Counters.at(rec::Counter::maxChainWeight);
   auto weight = iter->key->authReqChain.size() * content.size();
   if (weight > maxWeight) {
@@ -2975,7 +2971,7 @@ static void handleUDPServerResponse(int fileDesc, FDMultiplexer::funcparam_t& va
     PacketBuffer empty;
     auto iter = g_multiTasker->getWaiters().find(pid);
     if (iter != g_multiTasker->getWaiters().end()) {
-      doResends(iter, pid, empty, false);
+      doResends(iter, pid, empty);
     }
     g_multiTasker->sendEvent(pid, &empty); // this denotes error (does retry lookup using other NS)
     return;
@@ -3036,10 +3032,7 @@ static void handleUDPServerResponse(int fileDesc, FDMultiplexer::funcparam_t& va
   if (!pident->domain.empty()) {
     auto iter = g_multiTasker->getWaiters().find(pident);
     if (iter != g_multiTasker->getWaiters().end()) {
-      if (g_ECSHardening) {
-        iter->key->matchingECSReceived = iter->key->ecsSubnet && checkIncomingECSSource(packet, *iter->key->ecsSubnet);
-      }
-      doResends(iter, pident, packet, iter->key->matchingECSReceived);
+      doResends(iter, pident, packet);
     }
   }
 
index 6d22624b47af81fa8ebbf51d62bee815835339be..cfa081f6235f7898818258c7f2aac2a4fdfe857f 100644 (file)
@@ -779,7 +779,6 @@ struct PacketID
   bool inIncompleteOkay{false};
   uint16_t id{0}; // wait for a specific id/remote pair
   uint16_t type{0}; // and this is its type
-  std::optional<bool> matchingECSReceived; // only set in ecsHardened mode
   TCPAction highState{TCPAction::DoingRead};
   IOState lowState{IOState::NeedRead};
 
@@ -792,7 +791,7 @@ struct PacketID
 
 inline ostream& operator<<(ostream& ostr, const PacketID& pid)
 {
-  return ostr << "PacketID(id=" << pid.id << ",remote=" << pid.remote.toString() << ",type=" << pid.type << ",tcpsock=" << pid.tcpsock << ",fd=" << pid.fd << ',' << pid.domain << ')';
+  return ostr << "PacketID(id=" << pid.id << ",remote=" << pid.remote.toString() << ",type=" << pid.type << ",tcpsock=" << pid.tcpsock << ",fd=" << pid.fd << ",name=" << pid.domain << ",ecs=" << (pid.ecsSubnet ? pid.ecsSubnet->toString() : "") << ')';
 }
 
 inline ostream& operator<<(ostream& ostr, const shared_ptr<PacketID>& pid)
@@ -824,10 +823,10 @@ struct PacketIDBirthdayCompare
 {
   bool operator()(const std::shared_ptr<PacketID>& lhs, const std::shared_ptr<PacketID>& rhs) const
   {
-    if (std::tie(lhs->remote, lhs->tcpsock, lhs->type, lhs->ecsSubnet) < std::tie(rhs->remote, rhs->tcpsock, rhs->type, rhs->ecsSubnet)) {
+    if (std::tie(lhs->remote, lhs->tcpsock, lhs->type) < std::tie(rhs->remote, rhs->tcpsock, rhs->type)) {
       return true;
     }
-    if (std::tie(lhs->remote, lhs->tcpsock, lhs->type, lhs->ecsSubnet) > std::tie(rhs->remote, rhs->tcpsock, rhs->type, rhs->ecsSubnet)) {
+    if (std::tie(lhs->remote, lhs->tcpsock, lhs->type) > std::tie(rhs->remote, rhs->tcpsock, rhs->type)) {
       return false;
     }
     return lhs->domain < rhs->domain;
index 97a2cfc246c3aa3d2112488797a427f4cdda3a0a..52a3427d89a1b0ba1bb10fa9402628060c3cd426 100644 (file)
@@ -428,9 +428,9 @@ PrivateKey: Ep9uo6+wwjb4MaOmqq7LHav2FLrjotVOeZg8JT1Qk04=
                'zones': ['optout.example']},
         '15': {'threads': 1,
                'zones': ['insecure.optout.example', 'secure.optout.example', 'cname-secure.example']},
-        '16': {'threads': 2,
+        '16': {'threads': 10,
                'zones': ['delay1.example']},
-        '17': {'threads': 2,
+        '17': {'threads': 10,
                'zones': ['delay2.example']},
         '18': {'threads': 1,
                'zones': ['example']}
index ae436a4799f843087c1d3e018de595a380fee866..bffb317143ebf0c6d73662436e069d5851d50234 100644 (file)
@@ -2,6 +2,7 @@ import dns
 import os
 import time
 import requests
+import clientsubnetoption
 from recursortests import RecursorTest
 
 class ChainTest(RecursorTest):
@@ -73,3 +74,149 @@ class ChainTest(RecursorTest):
             'servfail-answers': 0,
             'noerror-answers': (lambda x: x <= count),
         })
+
+class ChainECSTest(RecursorTest):
+    """
+    These regression tests test the chaining of outgoing requests with ECS
+    """
+    _auth_zones = RecursorTest._default_auth_zones
+    _chainSize = 200
+    _confdir = 'ChainECS'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+
+    _config_template = """dnssec=validate
+    trace=no
+    edns-subnet-allow-list=0.0.0.0/0
+    use-incoming-edns-subnet=yes
+    edns-subnet-allow-list=0.0.0.0/0
+    devonly-regression-test-mode
+    webserver=yes
+    webserver-port=%d
+    webserver-address=127.0.0.1
+    webserver-password=%s
+    api-key=%s
+    max-concurrent-requests-per-tcp-connection=%s
+""" % (_wsPort, _wsPassword, _apiKey, _chainSize)
+
+    def testBasic(self):
+        """
+        Tests the case of #14624. Sending many equal requests could lead to ServFail because of
+        clashing waiter ids.
+        """
+        count = self._chainSize
+        name1 = '1.delay1.example.'
+        name2 = '2.delay1.example.'
+        exp1 = dns.rrset.from_text(name1, 0, dns.rdataclass.IN, 'TXT', 'a')
+        exp2 = dns.rrset.from_text(name2, 0, dns.rdataclass.IN, 'TXT', 'a')
+        queries = []
+        for i in range(count):
+            if i % 3 == 0:
+                name = name1
+            else:
+               name = name2
+            if i % 2 == 0:
+                ecso = clientsubnetoption.ClientSubnetOption('192.0.2.0', 24)
+            else:
+                ecso = clientsubnetoption.ClientSubnetOption('192.0.3.0', 24)
+            query = dns.message.make_query(name, 'TXT', use_edns=True, options=[ecso], want_dnssec=True)
+            query.flags |= dns.flags.AD
+            queries.append(query)
+
+        answers = self.sendTCPQueries(queries)
+        self.assertEqual(len(answers), count)
+
+        for i in range(count):
+            res = answers[i]
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertMessageIsAuthenticated(res)
+            if res.question[0].name.to_text() == name1:
+                self.assertRRsetInAnswer(res, exp1)
+                self.assertMatchingRRSIGInAnswer(res, exp1)
+            elif res.question[0].name.to_text() == name2:
+                self.assertRRsetInAnswer(res, exp2)
+                self.assertMatchingRRSIGInAnswer(res, exp2)
+            else:
+                print("?? " + res.question[0].name.to_text())
+                self.assertEqual(0, 1)
+
+        self.checkMetrics({
+            'servfail-answers': 0,
+            'noerror-answers': count,
+        })
+
+class ChainECSHardenedTest(RecursorTest):
+    """
+    These regression tests test the chaining of outgoing requests with ECS
+    """
+    _auth_zones = RecursorTest._default_auth_zones
+    _chainSize = 200
+    _confdir = 'ChainECSHardened'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+
+    _config_template = """dnssec=validate
+    trace=no
+    edns-subnet-allow-list=0.0.0.0/0
+    use-incoming-edns-subnet=yes
+    edns-subnet-allow-list=0.0.0.0/0
+    edns-subnet-harden=yes
+    devonly-regression-test-mode
+    webserver=yes
+    webserver-port=%d
+    webserver-address=127.0.0.1
+    webserver-password=%s
+    api-key=%s
+    max-concurrent-requests-per-tcp-connection=%s
+""" % (_wsPort, _wsPassword, _apiKey, _chainSize)
+
+    def testBasic(self):
+        """
+        Tests the case of #14624. Sending many equal requests could lead to ServFail because of
+        clashing waiter ids.
+        """
+        count = self._chainSize
+        name1 = '1.delay1.example.'
+        name2 = '2.delay1.example.'
+        exp1 = dns.rrset.from_text(name1, 0, dns.rdataclass.IN, 'TXT', 'a')
+        exp2 = dns.rrset.from_text(name2, 0, dns.rdataclass.IN, 'TXT', 'a')
+        queries = []
+        for i in range(count):
+            if i % 3 == 0:
+                name = name1
+            else:
+               name = name2
+            if i % 2 == 0:
+                ecso = clientsubnetoption.ClientSubnetOption('192.0.2.0', 24)
+            else:
+                ecso = clientsubnetoption.ClientSubnetOption('192.0.3.0', 24)
+            query = dns.message.make_query(name, 'TXT', use_edns=True, options=[ecso], want_dnssec=True)
+            query.flags |= dns.flags.AD
+            queries.append(query)
+
+        answers = self.sendTCPQueries(queries)
+        self.assertEqual(len(answers), count)
+
+        for i in range(count):
+            res = answers[i]
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertMessageIsAuthenticated(res)
+            if res.question[0].name.to_text() == name1:
+                self.assertRRsetInAnswer(res, exp1)
+                self.assertMatchingRRSIGInAnswer(res, exp1)
+            elif res.question[0].name.to_text() == name2:
+                self.assertRRsetInAnswer(res, exp2)
+                self.assertMatchingRRSIGInAnswer(res, exp2)
+            else:
+                print("?? " + res.question[0].name.to_text())
+                self.assertEqual(0, 1)
+
+        self.checkMetrics({
+            'servfail-answers': 0,
+            'noerror-answers': count,
+        })
+