]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Put the dnstap messages back to the queue in the correct order 15156/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 17 Feb 2025 10:29:49 +0000 (11:29 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 17 Feb 2025 10:29:49 +0000 (11:29 +0100)
As noticed by Miod, the previous solution was not removing all messages
from the queue, possibly putting back old messages behind newer ones.

regression-tests.dnsdist/test_Dnstap.py

index c40f52f497db2b6de513a54fb5aa0e65575a9891..0d902f5ef46cb9785ba8ca3a0574d61b03d70b58 100644 (file)
@@ -77,6 +77,29 @@ def checkDnstapResponse(testinstance, dnstap, protocol, response, initiator='127
     testinstance.assertEqual(wire_message, response)
 
 
+def getFirstMatchingMessageFromQueue(queue, messageType=None):
+    unused_messages = []
+    selected = None
+    while True:
+        data = queue.get(True, timeout=2.0)
+        if not data:
+            break
+        decoded_message = dnstap_pb2.Dnstap()
+        decoded_message.ParseFromString(data)
+        if not selected and (not messageType or decoded_message.message.type == messageType):
+            selected = decoded_message
+        else:
+            unused_messages.append(data)
+
+        if queue.empty():
+            break
+
+    # put back non-matching messages for later
+    for msg in reversed(unused_messages):
+        queue.put(msg)
+
+    return selected
+
 class TestDnstapOverRemoteLogger(DNSDistTest):
     _remoteLoggerServerPort = pickAvailablePort()
     _remoteLoggerQueue = Queue()
@@ -155,12 +178,7 @@ class TestDnstapOverRemoteLogger(DNSDistTest):
         cls._remoteLoggerListener.start()
 
     def getFirstDnstap(self):
-        self.assertFalse(self._remoteLoggerQueue.empty())
-        data = self._remoteLoggerQueue.get(False)
-        self.assertTrue(data)
-        dnstap = dnstap_pb2.Dnstap()
-        dnstap.ParseFromString(data)
-        return dnstap
+        return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue)
 
     def testDnstap(self):
         """
@@ -381,23 +399,7 @@ class TestDnstapOverRemoteLoggerPool(DNSDistTest):
         cls._remoteLoggerListener.start()
 
     def getFirstDnstap(self, messageType=None):
-        self.assertFalse(self._remoteLoggerQueue.empty())
-        unused = []
-        dnstap = None
-        while not self._remoteLoggerQueue.empty():
-            data = self._remoteLoggerQueue.get(False)
-            self.assertTrue(data)
-            dnstap = dnstap_pb2.Dnstap()
-            dnstap.ParseFromString(data)
-            if not messageType or dnstap.message.type == messageType:
-                break
-            unused.append(data)
-
-        # put back non-matching messages for later
-        for msg in reversed(unused):
-            self._remoteLoggerQueue.put(msg)
-
-        return dnstap
+        return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue, messageType=messageType)
 
     def testDnstap(self):
         """
@@ -628,11 +630,7 @@ class TestDnstapOverFrameStreamUnixLogger(DNSDistTest):
         cls._fstrmLoggerListener.start()
 
     def getFirstDnstap(self):
-        data = self._fstrmLoggerQueue.get(True, timeout=2.0)
-        self.assertTrue(data)
-        dnstap = dnstap_pb2.Dnstap()
-        dnstap.ParseFromString(data)
-        return dnstap
+        return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue)
 
     def testDnstapOverFrameStreamUnix(self):
         """
@@ -724,11 +722,7 @@ class TestDnstapOverRemotePoolUnixLogger(DNSDistTest):
         cls._fstrmLoggerListener.start()
 
     def getFirstDnstap(self):
-        data = self._fstrmLoggerQueue.get(True, timeout=2.0)
-        self.assertTrue(data)
-        dnstap = dnstap_pb2.Dnstap()
-        dnstap.ParseFromString(data)
-        return dnstap
+        return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue)
 
     def testDnstapOverFrameStreamUnix(self):
         """
@@ -807,11 +801,7 @@ class TestDnstapOverFrameStreamTcpLogger(DNSDistTest):
         cls._fstrmLoggerListener.start()
 
     def getFirstDnstap(self):
-        data = self._fstrmLoggerQueue.get(True, timeout=2.0)
-        self.assertTrue(data)
-        dnstap = dnstap_pb2.Dnstap()
-        dnstap.ParseFromString(data)
-        return dnstap
+        return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue)
 
     def testDnstapOverFrameStreamTcp(self):
         """
@@ -899,11 +889,7 @@ class TestDnstapOverRemotePoolTcpLogger(DNSDistTest):
         cls._fstrmLoggerListener.start()
 
     def getFirstDnstap(self):
-        data = self._fstrmLoggerQueue.get(True, timeout=2.0)
-        self.assertTrue(data)
-        dnstap = dnstap_pb2.Dnstap()
-        dnstap.ParseFromString(data)
-        return dnstap
+        return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue)
 
     def testDnstapOverFrameStreamTcp(self):
         """