]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Set up the dns header for timeout response rules
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 10 Oct 2025 09:45:25 +0000 (11:45 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 13 Oct 2025 07:19:33 +0000 (09:19 +0200)
Response actions expect that there is a DNS payload containing at
least a DNS header, as an incoming packet smaller than a DNS header
would have been discarded early in the processing path.
Unfortunately this is not true for timeout response rules, where we
no longer have the DNS payload from the query and obviously don't
have a response either. This commit restores a DNS header from the
information we have (query ID, flags) so that most actions can
proceed normally.

Signed-off-by: Remi Gacogne <remi.gacogne@powerdns.com>
pdns/dnsdistdist/dnsdist.cc
regression-tests.dnsdist/test_Protobuf.py

index c06c42acf5b2d591cb10dff2be16cad6629f7b4a..8d6ab47e593e799b0bf7ed49dff813864689ad98 100644 (file)
@@ -1592,8 +1592,15 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dnsQuestion, std::shared_
 
 bool handleTimeoutResponseRules(const std::vector<dnsdist::rules::ResponseRuleAction>& rules, InternalQueryState& ids, const std::shared_ptr<DownstreamState>& d_ds, const std::shared_ptr<TCPQuerySender>& sender)
 {
-  PacketBuffer empty;
-  DNSResponse dnsResponse(ids, empty, d_ds);
+  /* let's be nice and restore the original DNS header as well as we can with what we have */
+  PacketBuffer payload(sizeof(dnsheader));
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(payload, [&ids](dnsheader& header) {
+    header.id = ids.origID;
+    header.qdcount = htons(1);
+    restoreFlags(&header, ids.origFlags);
+    return true;
+  });
+  DNSResponse dnsResponse(ids, payload, d_ds);
   auto protocol = dnsResponse.getProtocol();
 
   vinfolog("Handling timeout response rules for incoming protocol = %s", protocol.toString());
index eb51483961b3cd294a7dc2a34613bc38d12b9f77..13d70c097ab79c48a54a038eab736d2afb3d8902 100644 (file)
@@ -1210,3 +1210,95 @@ query_rules:
         # TCP query
         msg = self.getFirstProtobufMessage()
         self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+
+class TestTimeoutResponseRuleProtobuf(DNSDistProtobufTest):
+
+    _yaml_config_template = """---
+binds:
+  - listen_address: "127.0.0.1:%d"
+    protocol: Do53
+
+backends:
+  - address: "127.0.0.1:%d"
+    protocol: Do53
+    health_checks:
+      mode: "up"
+
+remote_logging:
+  protobuf_loggers:
+    - name: "my-logger"
+      address: "127.0.0.1:%d"
+      timeout: 1
+
+timeout_response_rules:
+  - name: "my-rule"
+    selector:
+      type: "All"
+    action:
+      type: "RemoteLog"
+      logger_name: "my-logger"
+      server_id: "%s"
+"""
+    _dnsDistPort = pickAvailablePort()
+    _testServerPortNotListening = pickAvailablePort()
+    _testServerPort = pickAvailablePort()
+    _yaml_config_params = ['_dnsDistPort', '_testServerPortNotListening', '_protobufServerPort', '_protobufServerID']
+    _config_params = []
+
+    def testProtobuf(self):
+        """
+        Yaml: Remote logging via protobuf of timeouts
+        """
+        name = 'remote-logging-timeout.protobuf.yaml.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False, timeout=2)
+            self.assertEqual(receivedQuery, None)
+            self.assertEqual(receivedResponse, None)
+
+
+        # the UDP timeout usually takes longer to be detected
+        # than the TCP one
+        gotUDP = False
+        gotTCP = False
+        waited = 0
+        while (not gotUDP or not gotTCP) and waited <= 4:
+            if self._protobufQueue.empty():
+                # let the protobuf messages the time to get there
+                time.sleep(1)
+                waited += 1
+                if self._protobufQueue.empty():
+                    continue
+
+            # check the protobuf message
+            msg = self.getFirstProtobufMessage()
+            if msg.socketProtocol == dnsmessage_pb2.PBDNSMessage.UDP:
+                gotUDP = True
+                protocol = dnsmessage_pb2.PBDNSMessage.UDP
+            else:
+                gotTCP = True
+                protocol = dnsmessage_pb2.PBDNSMessage.TCP
+
+            self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+            self.checkProtobufBase(msg, protocol, response, '127.0.0.1', normalQueryResponse=False, v6=False)
+            self.assertTrue(msg.HasField('response'))
+            self.assertTrue(msg.response.HasField('queryTimeSec'))
+            self.assertTrue(msg.HasField('question'))
+            self.assertTrue(msg.question.HasField('qClass'))
+            self.assertTrue(msg.question.HasField('qType'))
+            self.assertTrue(msg.question.HasField('qName'))
+            self.assertEqual(msg.question.qName, name)
+
+        self.assertTrue(gotUDP)
+        self.assertTrue(gotTCP)