testinstance.assertEqual(wire_message, response)
+def getFirstMatchingMessageFromQueue(queue, messageType=None):
+ unused_messages = []
+ selected = None
+ while True:
+ data = queue.get(True, timeout=2.0)
+ if not data:
+ break
+ decoded_message = dnstap_pb2.Dnstap()
+ decoded_message.ParseFromString(data)
+ if not selected and (not messageType or decoded_message.message.type == messageType):
+ selected = decoded_message
+ else:
+ unused_messages.append(data)
+
+ if queue.empty():
+ break
+
+ # put back non-matching messages for later
+ for msg in reversed(unused_messages):
+ queue.put(msg)
+
+ return selected
+
class TestDnstapOverRemoteLogger(DNSDistTest):
_remoteLoggerServerPort = pickAvailablePort()
_remoteLoggerQueue = Queue()
cls._remoteLoggerListener.start()
def getFirstDnstap(self):
- self.assertFalse(self._remoteLoggerQueue.empty())
- data = self._remoteLoggerQueue.get(False)
- self.assertTrue(data)
- dnstap = dnstap_pb2.Dnstap()
- dnstap.ParseFromString(data)
- return dnstap
+ return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue)
def testDnstap(self):
"""
cls._remoteLoggerListener.start()
def getFirstDnstap(self, messageType=None):
- self.assertFalse(self._remoteLoggerQueue.empty())
- unused = []
- dnstap = None
- while not self._remoteLoggerQueue.empty():
- data = self._remoteLoggerQueue.get(False)
- self.assertTrue(data)
- dnstap = dnstap_pb2.Dnstap()
- dnstap.ParseFromString(data)
- if not messageType or dnstap.message.type == messageType:
- break
- unused.append(data)
-
- # put back non-matching messages for later
- for msg in reversed(unused):
- self._remoteLoggerQueue.put(msg)
-
- return dnstap
+ return getFirstMatchingMessageFromQueue(self._remoteLoggerQueue, messageType=messageType)
def testDnstap(self):
"""
cls._fstrmLoggerListener.start()
def getFirstDnstap(self):
- data = self._fstrmLoggerQueue.get(True, timeout=2.0)
- self.assertTrue(data)
- dnstap = dnstap_pb2.Dnstap()
- dnstap.ParseFromString(data)
- return dnstap
+ return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue)
def testDnstapOverFrameStreamUnix(self):
"""
cls._fstrmLoggerListener.start()
def getFirstDnstap(self):
- data = self._fstrmLoggerQueue.get(True, timeout=2.0)
- self.assertTrue(data)
- dnstap = dnstap_pb2.Dnstap()
- dnstap.ParseFromString(data)
- return dnstap
+ return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue)
def testDnstapOverFrameStreamUnix(self):
"""
cls._fstrmLoggerListener.start()
def getFirstDnstap(self):
- data = self._fstrmLoggerQueue.get(True, timeout=2.0)
- self.assertTrue(data)
- dnstap = dnstap_pb2.Dnstap()
- dnstap.ParseFromString(data)
- return dnstap
+ return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue)
def testDnstapOverFrameStreamTcp(self):
"""
cls._fstrmLoggerListener.start()
def getFirstDnstap(self):
- data = self._fstrmLoggerQueue.get(True, timeout=2.0)
- self.assertTrue(data)
- dnstap = dnstap_pb2.Dnstap()
- dnstap.ParseFromString(data)
- return dnstap
+ return getFirstMatchingMessageFromQueue(self._fstrmLoggerQueue)
def testDnstapOverFrameStreamTcp(self):
"""