testinstance.assertEqual(socket.inet_ntop(socket.AF_INET, dnstap.message.response_address), initiator)
testinstance.assertTrue(dnstap.message.HasField('response_port'))
testinstance.assertEqual(dnstap.message.response_port, testinstance._dnsDistPort)
-
+
def checkDnstapQuery(testinstance, dnstap, protocol, query, initiator='127.0.0.1'):
testinstance.assertEqual(dnstap.message.type, dnstap_pb2.Message.CLIENT_QUERY)
return cft
-def fstrm_handle_bidir_connection(conn, on_data):
+def fstrm_handle_bidir_connection(conn, on_data, exit_early=False):
data = None
while True:
data = conn.recv(4)
break
on_data(data)
+ if exit_early:
+ break
class TestDnstapOverFrameStreamUnixLogger(DNSDistTest):
checkDnstapQuery(self, dnstap, dnstap_pb2.UDP, query)
checkDnstapNoExtra(self, dnstap)
+
+class TestDnstapOverRemotePoolTcpLogger(DNSDistTest):
+ _fstrmLoggerPort = pickAvailablePort()
+ _fstrmLoggerQueue = Queue()
+ _fstrmLoggerCounter = 0
+ _poolConnectionCount = 5
+ _config_params = ['_testServerPort', '_fstrmLoggerPort', '_poolConnectionCount']
+ _config_template = """
+ newServer{address="127.0.0.1:%s", useClientSubnet=true}
+ fslu = newFrameStreamTcpLogger('127.0.0.1:%s', { connectionCount = %s })
+
+ addAction(AllRule(), DnstapLogAction("a.server", fslu))
+ """
+
+ @classmethod
+ def FrameStreamUnixListener(cls, port):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ try:
+ sock.bind(("127.0.0.1", port))
+ except socket.error as e:
+ print("Error binding in the framestream listener: %s" % str(e))
+ sys.exit(1)
+
+ sock.listen(100)
+
+ def handle_connection(conn):
+ fstrm_handle_bidir_connection(conn, lambda data: \
+ cls._fstrmLoggerQueue.put(data, True, timeout=2.0), exit_early=True)
+ conn.close()
+
+ threads = []
+ while True:
+ (conn, _) = sock.accept()
+ thread = threading.Thread(target=handle_connection, args=[conn])
+ threads.append(thread)
+ thread.start()
+
+ for thread in threads:
+ thread.join()
+ sock.close()
+
+ @classmethod
+ def startResponders(cls):
+ DNSDistTest.startResponders()
+
+ cls._fstrmLoggerListener = threading.Thread(name='FrameStreamUnixListener', target=cls.FrameStreamUnixListener, args=[cls._fstrmLoggerPort])
+ cls._fstrmLoggerListener.daemon = True
+ cls._fstrmLoggerListener.start()
+
+ def getFirstDnstap(self):
+ data = self._fstrmLoggerQueue.get(True, timeout=2.0)
+ self.assertTrue(data)
+ dnstap = dnstap_pb2.Dnstap()
+ dnstap.ParseFromString(data)
+ return dnstap
+
+ def testDnstapOverFrameStreamTcp(self):
+ """
+ Dnstap: Send query packed in dnstap to a tcp socket fstrmlogger server
+ """
+ for i in range(self._poolConnectionCount):
+ name = 'query.dnstap.tests.powerdns.com.'
+
+ target = 'target.dnstap.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.CNAME,
+ target)
+ response.answer.append(rrset)
+
+ rrset = dns.rrset.from_text(target,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(response, receivedResponse)
+
+ # check the dnstap message corresponding to the UDP query
+ dnstap = self.getFirstDnstap()
+
+ checkDnstapQuery(self, dnstap, dnstap_pb2.UDP, query)
+ checkDnstapNoExtra(self, dnstap)