]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for DNS flags in ProtoBuf messages
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 5 Feb 2026 15:56:13 +0000 (16:56 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 5 Feb 2026 15:59:38 +0000 (16:59 +0100)
Signed-off-by: Remi Gacogne <remi.gacogne@powerdns.com>
regression-tests.dnsdist/test_Protobuf.py

index 4258db7cf9e873e15fcd3be472e28d84d7e608a9..551a8ac85fe8fd913441302c2376cc870066bb10 100644 (file)
@@ -80,7 +80,7 @@ class DNSDistProtobufTest(DNSDistTest):
         msg.ParseFromString(data)
         return msg
 
-    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True, v6=False):
+    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True, v6=False, flags=None):
         self.assertTrue(msg)
         self.assertTrue(msg.HasField('timeSec'))
         self.assertTrue(msg.HasField('socketFamily'))
@@ -100,6 +100,9 @@ class DNSDistProtobufTest(DNSDistTest):
         self.assertTrue(msg.HasField('id'))
         self.assertEqual(msg.id, query.id)
         self.assertTrue(msg.HasField('inBytes'))
+        self.assertTrue(msg.HasField('headerFlags'))
+        queryFlags = flags or int.from_bytes(query.to_wire()[2:4], byteorder=sys.byteorder)
+        self.assertEqual(msg.headerFlags, queryFlags)
         self.assertTrue(msg.HasField('serverIdentity'))
         self.assertEqual(msg.serverIdentity, self._protobufServerID.encode('utf-8'))
 
@@ -134,10 +137,10 @@ class DNSDistProtobufTest(DNSDistTest):
         # exclusive or of lists should be empty
         self.assertEqual(len(listx), 0, "Protobuf tags don't match")
 
-    def checkProtobufQueryConvertedToResponse(self, msg, protocol, response, initiator='127.0.0.0'):
+    def checkProtobufQueryConvertedToResponse(self, msg, protocol, response, initiator='127.0.0.0', flags=None):
         self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
         # skip comparing inBytes (size of the query) with the length of the generated response
-        self.checkProtobufBase(msg, protocol, response, initiator, False)
+        self.checkProtobufBase(msg, protocol, response, initiator, False, flags=flags)
         self.assertTrue(msg.HasField('response'))
         self.assertTrue(msg.response.HasField('queryTimeSec'))
 
@@ -379,8 +382,8 @@ class TestProtobuf(DNSDistProtobufTest):
 
         # check the protobuf message corresponding to the UDP query
         msg = self.getFirstProtobufMessage()
-
-        self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
+        flags = int.from_bytes(query.to_wire()[2:4], byteorder=sys.byteorder)
+        self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0', flags=flags)
         self.checkProtobufTags(msg.response.tags, [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])
 
         # check the protobuf message corresponding to the UDP response
@@ -405,7 +408,8 @@ class TestProtobuf(DNSDistProtobufTest):
 
         # check the protobuf message corresponding to the TCP query
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
+        flags = int.from_bytes(query.to_wire()[2:4], byteorder=sys.byteorder)
+        self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0', flags=flags)
         self.checkProtobufTags(msg.response.tags, [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])
 
         # check the protobuf message corresponding to the TCP response