From: Remi Gacogne Date: Mon, 17 Feb 2025 10:29:49 +0000 (+0100) Subject: dnsdist: Put the dnstap messages back to the queue in the correct order X-Git-Tag: dnsdist-2.0.0-alpha1~81^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fe36ae3e8df7a5bf4af1696f43e785fdeecb3cf5;p=thirdparty%2Fpdns.git dnsdist: Put the dnstap messages back to the queue in the correct order As noticed by Miod, the previous solution was not removing all messages from the queue, possibly putting back old messages behind newer ones. --- diff --git a/regression-tests.dnsdist/test_Dnstap.py b/regression-tests.dnsdist/test_Dnstap.py index c40f52f497..0d902f5ef4 100644 --- a/regression-tests.dnsdist/test_Dnstap.py +++ b/regression-tests.dnsdist/test_Dnstap.py @@ -77,6 +77,29 @@ def checkDnstapResponse(testinstance, dnstap, protocol, response, initiator='127 testinstance.assertEqual(wire_message, response) +def getFirstMatchingMessageFromQueue(queue, messageType=None): + unused_messages = [] + selected = None + while True: + data = queue.get(True, timeout=2.0) + if not data: + break + decoded_message = dnstap_pb2.Dnstap() + decoded_message.ParseFromString(data) + if not selected and (not messageType or decoded_message.message.type == messageType): + selected = decoded_message + else: + unused_messages.append(data) + + if queue.empty(): + break + + # put back non-matching messages for later + for msg in reversed(unused_messages): + queue.put(msg) + + return selected + class TestDnstapOverRemoteLogger(DNSDistTest): _remoteLoggerServerPort = pickAvailablePort() _remoteLoggerQueue = Queue() @@ -155,12 +178,7 @@ class TestDnstapOverRemoteLogger(DNSDistTest): cls._remoteLoggerListener.start() def getFirstDnstap(self): - self.assertFalse(self._remoteLoggerQueue.empty()) - data = self._remoteLoggerQueue.get(False) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue) def testDnstap(self): """ @@ -381,23 +399,7 @@ class TestDnstapOverRemoteLoggerPool(DNSDistTest): cls._remoteLoggerListener.start() def getFirstDnstap(self, messageType=None): - self.assertFalse(self._remoteLoggerQueue.empty()) - unused = [] - dnstap = None - while not self._remoteLoggerQueue.empty(): - data = self._remoteLoggerQueue.get(False) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - if not messageType or dnstap.message.type == messageType: - break - unused.append(data) - - # put back non-matching messages for later - for msg in reversed(unused): - self._remoteLoggerQueue.put(msg) - - return dnstap + return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue, messageType=messageType) def testDnstap(self): """ @@ -628,11 +630,7 @@ class TestDnstapOverFrameStreamUnixLogger(DNSDistTest): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamUnix(self): """ @@ -724,11 +722,7 @@ class TestDnstapOverRemotePoolUnixLogger(DNSDistTest): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamUnix(self): """ @@ -807,11 +801,7 @@ class TestDnstapOverFrameStreamTcpLogger(DNSDistTest): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamTcp(self): """ @@ -899,11 +889,7 @@ class TestDnstapOverRemotePoolTcpLogger(DNSDistTest): cls._fstrmLoggerListener.start() def getFirstDnstap(self): - data = self._fstrmLoggerQueue.get(True, timeout=2.0) - self.assertTrue(data) - dnstap = dnstap_pb2.Dnstap() - dnstap.ParseFromString(data) - return dnstap + return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue) def testDnstapOverFrameStreamTcp(self): """