]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Fix two issues with chaining ECS enabled queries
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 15 Jul 2025 11:54:30 +0000 (13:54 +0200)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 15 Jul 2025 12:13:01 +0000 (14:13 +0200)
1. The main index does not sort on subnet, so the birthday compare used for chains
cannot assume any ordering on subnet. The subnet is therefore removed from that
compare and checked for equality when deciding whether to chain.
2. The lookup key is overwritten by the matched key from the waiters, meaning
that we cannot use it to pass values (such as matchingECSReceived). This means we
have to recompute the ECS info in the incoming path for each chain member.

Signed-off-by: Otto Moerbeek <otto.moerbeek@open-xchange.com>
pdns/recursordist/pdns_recursor.cc
pdns/recursordist/syncres.hh
regression-tests.recursor-dnssec/recursortests.py
regression-tests.recursor-dnssec/test_Chain.py

index 36d7526eeae4ea91ca147e53e6a0fda769655a4b..bb34ebfefbf49690ef9e26eb09bb0260d650d619 100644 (file)
@@ -299,7 +299,7 @@ LWResult::Result asendto(const void* data, size_t len, int /* flags */,
     // Line below detected an issue with the two ways of ordering PacketIDs (birthday and non-birthday)
     assert(chain.first->key->domain == pident->domain); // NOLINT
     // don't chain onto existing chained waiter or a chain already processed
-    if (chain.first->key->fd > -1 && !chain.first->key->closed) {
+    if (chain.first->key->fd > -1 && !chain.first->key->closed && pident->ecsSubnet == chain.first->key->ecsSubnet) {
       auto currentChainSize = chain.first->key->authReqChain.size();
       *fileDesc = -static_cast<int>(currentChainSize + 1); // value <= -1, gets used in waitEvent / sendEvent later on
       if (g_maxChainLength > 0 && currentChainSize >= g_maxChainLength) {
@@ -341,6 +341,8 @@ LWResult::Result asendto(const void* data, size_t len, int /* flags */,
   return LWResult::Result::Success;
 }
 
+static bool checkIncomingECSSource(const PacketBuffer& packet, const Netmask& subnet);
+
 LWResult::Result arecvfrom(PacketBuffer& packet, int /* flags */, const ComboAddress& fromAddr, size_t& len,
                            uint16_t qid, const DNSName& domain, uint16_t qtype, int fileDesc, const std::optional<EDNSSubnetOpts>& ecs, const struct timeval& now)
 {
@@ -373,9 +375,8 @@ LWResult::Result arecvfrom(PacketBuffer& packet, int /* flags */, const ComboAdd
     len = packet.size();
 
     // In ecs hardening mode, we consider a missing or a mismatched ECS in the reply as a case for
-    // retrying without ECS (matchingECSReceived only gets set if a matching ECS was received). The actual
-    // logic to do that is in Syncres::doResolveAtThisIP()
-    if (g_ECSHardening && pident->ecsSubnet && !*pident->matchingECSReceived) {
+    // retrying without ECS. The actual logic to do that is in Syncres::doResolveAtThisIP()
+    if (g_ECSHardening && pident->ecsSubnet && !checkIncomingECSSource(packet, *pident->ecsSubnet)) {
       t_Counters.at(rec::Counter::ecsMissingCount)++;
       return LWResult::Result::ECSMissing;
     }
@@ -2914,7 +2915,7 @@ void distributeAsyncFunction(const string& packet, const pipefunc_t& func)
 }
 
 // resend event to everybody chained onto it
-static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr<PacketID>& resend, const PacketBuffer& content, const std::optional<bool>& matchingECSReceived)
+static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr<PacketID>& resend, const PacketBuffer& content)
 {
   // We close the chain for new entries, since they won't be processed anyway
   iter->key->closed = true;
@@ -2923,11 +2924,6 @@ static void doResends(MT_t::waiters_t::iterator& iter, const std::shared_ptr<Pac
     return;
   }
 
-  // Only set if g_ECSHardening
-  if (matchingECSReceived) {
-    iter->key->matchingECSReceived = matchingECSReceived;
-  }
-
   auto maxWeight = t_Counters.at(rec::Counter::maxChainWeight);
   auto weight = iter->key->authReqChain.size() * content.size();
   if (weight > maxWeight) {
@@ -2999,7 +2995,7 @@ static void handleUDPServerResponse(int fileDesc, FDMultiplexer::funcparam_t& va
     PacketBuffer empty;
     auto iter = g_multiTasker->getWaiters().find(pid);
     if (iter != g_multiTasker->getWaiters().end()) {
-      doResends(iter, pid, empty, false);
+      doResends(iter, pid, empty);
     }
     g_multiTasker->sendEvent(pid, &empty); // this denotes error (does retry lookup using other NS)
     return;
@@ -3060,10 +3056,7 @@ static void handleUDPServerResponse(int fileDesc, FDMultiplexer::funcparam_t& va
   if (!pident->domain.empty()) {
     auto iter = g_multiTasker->getWaiters().find(pident);
     if (iter != g_multiTasker->getWaiters().end()) {
-      if (g_ECSHardening) {
-        iter->key->matchingECSReceived = iter->key->ecsSubnet && checkIncomingECSSource(packet, *iter->key->ecsSubnet);
-      }
-      doResends(iter, pident, packet, iter->key->matchingECSReceived);
+      doResends(iter, pident, packet);
     }
   }
 
index 32a358ec399b509a74d7a06e413f24ad998e7510..1427776d60042bb89c1f1fb81fe1e85b081ce685 100644 (file)
@@ -804,7 +804,6 @@ struct PacketID
   bool inIncompleteOkay{false};
   uint16_t id{0}; // wait for a specific id/remote pair
   uint16_t type{0}; // and this is its type
-  std::optional<bool> matchingECSReceived; // only set in ecsHardened mode
   TCPAction highState{TCPAction::DoingRead};
   IOState lowState{IOState::NeedRead};
 
@@ -817,7 +816,7 @@ struct PacketID
 
 inline ostream& operator<<(ostream& ostr, const PacketID& pid)
 {
-  return ostr << "PacketID(id=" << pid.id << ",remote=" << pid.remote.toString() << ",type=" << pid.type << ",tcpsock=" << pid.tcpsock << ",fd=" << pid.fd << ',' << pid.domain << ')';
+  return ostr << "PacketID(id=" << pid.id << ",remote=" << pid.remote.toString() << ",type=" << pid.type << ",tcpsock=" << pid.tcpsock << ",fd=" << pid.fd << ",name=" << pid.domain << ",ecs=" << (pid.ecsSubnet ? pid.ecsSubnet->toString() : "") << ')';
 }
 
 inline ostream& operator<<(ostream& ostr, const shared_ptr<PacketID>& pid)
@@ -849,10 +848,10 @@ struct PacketIDBirthdayCompare
 {
   bool operator()(const std::shared_ptr<PacketID>& lhs, const std::shared_ptr<PacketID>& rhs) const
   {
-    if (std::tie(lhs->remote, lhs->tcpsock, lhs->type, lhs->ecsSubnet) < std::tie(rhs->remote, rhs->tcpsock, rhs->type, rhs->ecsSubnet)) {
+    if (std::tie(lhs->remote, lhs->tcpsock, lhs->type) < std::tie(rhs->remote, rhs->tcpsock, rhs->type)) {
       return true;
     }
-    if (std::tie(lhs->remote, lhs->tcpsock, lhs->type, lhs->ecsSubnet) > std::tie(rhs->remote, rhs->tcpsock, rhs->type, rhs->ecsSubnet)) {
+    if (std::tie(lhs->remote, lhs->tcpsock, lhs->type) > std::tie(rhs->remote, rhs->tcpsock, rhs->type)) {
       return false;
     }
     return lhs->domain < rhs->domain;
index 0215c2a37997e57c6a3f5ce03a385877b08ccef1..ef2b265eaecc5201a783be50959e7421aa00e990 100644 (file)
@@ -428,9 +428,9 @@ PrivateKey: Ep9uo6+wwjb4MaOmqq7LHav2FLrjotVOeZg8JT1Qk04=
                'zones': ['optout.example']},
         '15': {'threads': 1,
                'zones': ['insecure.optout.example', 'secure.optout.example', 'cname-secure.example']},
-        '16': {'threads': 2,
+        '16': {'threads': 10,
                'zones': ['delay1.example']},
-        '17': {'threads': 2,
+        '17': {'threads': 10,
                'zones': ['delay2.example']},
         '18': {'threads': 1,
                'zones': ['example']}
index 107c66ac729270f54a2d881ca006901b0f3ff220..4672df103e6d0ec42c027d3649d3e0921c80346c 100644 (file)
@@ -2,6 +2,7 @@ import pytest
 import dns
 import os
 import time
+import clientsubnetoption
 from recursortests import RecursorTest
 
 class ChainTest(RecursorTest):
@@ -57,3 +58,149 @@ class ChainTest(RecursorTest):
             'servfail-answers': 0,
             'noerror-answers': count,
         })
+
+class ChainECSTest(RecursorTest):
+    """
+    These regression tests test the chaining of outgoing requests with ECS
+    """
+    _auth_zones = RecursorTest._default_auth_zones
+    _chainSize = 200
+    _confdir = 'ChainECS'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+
+    _config_template = """dnssec=validate
+    trace=no
+    edns-subnet-allow-list=0.0.0.0/0
+    use-incoming-edns-subnet=yes
+    edns-subnet-allow-list=0.0.0.0/0
+    devonly-regression-test-mode
+    webserver=yes
+    webserver-port=%d
+    webserver-address=127.0.0.1
+    webserver-password=%s
+    api-key=%s
+    max-concurrent-requests-per-tcp-connection=%s
+""" % (_wsPort, _wsPassword, _apiKey, _chainSize)
+
+    def testBasic(self):
+        """
+        Tests the case of #14624. Sending many equal requests could lead to ServFail because of
+        clashing waiter ids.
+        """
+        count = self._chainSize
+        name1 = '1.delay1.example.'
+        name2 = '2.delay1.example.'
+        exp1 = dns.rrset.from_text(name1, 0, dns.rdataclass.IN, 'TXT', 'a')
+        exp2 = dns.rrset.from_text(name2, 0, dns.rdataclass.IN, 'TXT', 'a')
+        queries = []
+        for i in range(count):
+            if i % 3 == 0:
+                name = name1
+            else:
+               name = name2
+            if i % 2 == 0:
+                ecso = clientsubnetoption.ClientSubnetOption('192.0.2.0', 24)
+            else:
+                ecso = clientsubnetoption.ClientSubnetOption('192.0.3.0', 24)
+            query = dns.message.make_query(name, 'TXT', use_edns=True, options=[ecso], want_dnssec=True)
+            query.flags |= dns.flags.AD
+            queries.append(query)
+
+        answers = self.sendTCPQueries(queries)
+        self.assertEqual(len(answers), count)
+
+        for i in range(count):
+            res = answers[i]
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertMessageIsAuthenticated(res)
+            if res.question[0].name.to_text() == name1:
+                self.assertRRsetInAnswer(res, exp1)
+                self.assertMatchingRRSIGInAnswer(res, exp1)
+            elif res.question[0].name.to_text() == name2:
+                self.assertRRsetInAnswer(res, exp2)
+                self.assertMatchingRRSIGInAnswer(res, exp2)
+            else:
+                print("?? " + res.question[0].name.to_text())
+                self.assertEqual(0, 1)
+
+        self.checkMetrics({
+            'servfail-answers': 0,
+            'noerror-answers': count,
+        })
+
+class ChainECSHardenedTest(RecursorTest):
+    """
+    These regression tests test the chaining of outgoing requests with ECS and edns-subnet-harden=yes
+    """
+    _auth_zones = RecursorTest._default_auth_zones
+    _chainSize = 200
+    _confdir = 'ChainECSHardened'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+
+    _config_template = """dnssec=validate
+    trace=no
+    edns-subnet-allow-list=0.0.0.0/0
+    use-incoming-edns-subnet=yes
+    edns-subnet-allow-list=0.0.0.0/0
+    edns-subnet-harden=yes
+    devonly-regression-test-mode
+    webserver=yes
+    webserver-port=%d
+    webserver-address=127.0.0.1
+    webserver-password=%s
+    api-key=%s
+    max-concurrent-requests-per-tcp-connection=%s
+""" % (_wsPort, _wsPassword, _apiKey, _chainSize)
+
+    def testBasic(self):
+        """
+        Tests the case of #14624. Sending many equal requests could lead to ServFail because of
+        clashing waiter ids.
+        """
+        count = self._chainSize
+        name1 = '1.delay1.example.'
+        name2 = '2.delay1.example.'
+        exp1 = dns.rrset.from_text(name1, 0, dns.rdataclass.IN, 'TXT', 'a')
+        exp2 = dns.rrset.from_text(name2, 0, dns.rdataclass.IN, 'TXT', 'a')
+        queries = []
+        for i in range(count):
+            if i % 3 == 0:
+                name = name1
+            else:
+               name = name2
+            if i % 2 == 0:
+                ecso = clientsubnetoption.ClientSubnetOption('192.0.2.0', 24)
+            else:
+                ecso = clientsubnetoption.ClientSubnetOption('192.0.3.0', 24)
+            query = dns.message.make_query(name, 'TXT', use_edns=True, options=[ecso], want_dnssec=True)
+            query.flags |= dns.flags.AD
+            queries.append(query)
+
+        answers = self.sendTCPQueries(queries)
+        self.assertEqual(len(answers), count)
+
+        for i in range(count):
+            res = answers[i]
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertMessageIsAuthenticated(res)
+            if res.question[0].name.to_text() == name1:
+                self.assertRRsetInAnswer(res, exp1)
+                self.assertMatchingRRSIGInAnswer(res, exp1)
+            elif res.question[0].name.to_text() == name2:
+                self.assertRRsetInAnswer(res, exp2)
+                self.assertMatchingRRSIGInAnswer(res, exp2)
+            else:
+                print("?? " + res.question[0].name.to_text())
+                self.assertEqual(0, 1)
+
+        self.checkMetrics({
+            'servfail-answers': 0,
+            'noerror-answers': count,
+        })
+