From fe36ae3e8df7a5bf4af1696f43e785fdeecb3cf5 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Mon, 17 Feb 2025 11:29:49 +0100 Subject: [PATCH] dnsdist: Put the dnstap messages back to the queue in the correct order As noticed by Miod, the previous solution was not removing all messages from the queue, possibly putting back old messages behind newer ones. --- regression-tests.dnsdist/test_Dnstap.py | 72 ++++++++++--------------- 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/regression-tests.dnsdist/test_Dnstap.py b/regression-tests.dnsdist/test_Dnstap.py index c40f52f497..0d902f5ef4 100644 --- a/regression-tests.dnsdist/test_Dnstap.py +++ b/regression-tests.dnsdist/test_Dnstap.py @@ -77,6 +77,29 @@ def checkDnstapResponse(testinstance, dnstap, protocol, response, initiator='127 testinstance.assertEqual(wire_message, response) +def getFirstMatchingMessageFromQueue(queue, messageType=None): + unused_messages = [] + selected = None + while True: + data = queue.get(True, timeout=2.0) + if not data: + break + decoded_message = dnstap_pb2.Dnstap() + decoded_message.ParseFromString(data) + if not selected and (not messageType or decoded_message.message.type == messageType): + selected = decoded_message + else: + unused_messages.append(data) + + if queue.empty(): + break + + # put back non-matching messages for later + for msg in reversed(unused_messages): + queue.put(msg) + + return selected + class TestDnstapOverRemoteLogger(DNSDistTest): _remoteLoggerServerPort = pickAvailablePort() _remoteLoggerQueue = Queue() @@ -155,12 +178,7 @@ class TestDnstapOverRemoteLogger(DNSDistTest): cls._remoteLoggerListener.start() def getFirstDnstap(self): - self.assertFalse(self._remoteLoggerQueue.empty()) - data = self._remoteLoggerQueue.get(False) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue) def testDnstap(self): """ @@ -381,23 +399,7 @@ class TestDnstapOverRemoteLoggerPool(DNSDistTest): cls._remoteLoggerListener.start() def getFirstDnstap(self, messageType=None): - self.assertFalse(self._remoteLoggerQueue.empty()) - unused = [] - dnstap = None - while not self._remoteLoggerQueue.empty(): - data = self._remoteLoggerQueue.get(False) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - if not messageType or dnstap.message.type == messageType: - break - unused.append(data) - - # put back non-matching messages for later - for msg in reversed(unused): - self._remoteLoggerQueue.put(msg) - - return dnstap + return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue, messageType=messageType) def testDnstap(self): """ @@ -628,11 +630,7 @@ class TestDnstapOverFrameStreamUnixLogger(DNSDistTest): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamUnix(self): """ @@ -724,11 +722,7 @@ class TestDnstapOverRemotePoolUnixLogger(DNSDistTest): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamUnix(self): """ @@ -807,11 +801,7 @@ class TestDnstapOverFrameStreamTcpLogger(DNSDistTest): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamTcp(self): """ @@ -899,11 +889,7 @@ class TestDnstapOverRemotePoolTcpLogger(DNSDistTest): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamTcp(self): """ -- 2.47.2