From 1d0bd88ab934b869a0355ab8ee3a123c23cc9ca7 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Tue, 6 Sep 2016 17:21:25 +0200 Subject: [PATCH] dnsdist: Add regression tests for protobuf export --- regression-tests.dnsdist/requirements.txt | 1 + regression-tests.dnsdist/runtests | 1 + regression-tests.dnsdist/test_Protobuf.py | 189 ++++++++++++++++++++++ 3 files changed, 191 insertions(+) create mode 100644 regression-tests.dnsdist/test_Protobuf.py diff --git a/regression-tests.dnsdist/requirements.txt b/regression-tests.dnsdist/requirements.txt index 2063456d4f..1a878ab9ad 100644 --- a/regression-tests.dnsdist/requirements.txt +++ b/regression-tests.dnsdist/requirements.txt @@ -2,3 +2,4 @@ dnspython>=1.11 nose>=1.3.7 libnacl>=1.4.3 requests>=2.1.0 +protobuf>=2.5,<3.0 diff --git a/regression-tests.dnsdist/runtests b/regression-tests.dnsdist/runtests index 12951e26df..6bf1b906ab 100755 --- a/regression-tests.dnsdist/runtests +++ b/regression-tests.dnsdist/runtests @@ -9,6 +9,7 @@ fi . .venv/bin/activate python -V pip install -r requirements.txt +protoc -I=../pdns/ --python_out=. ../pdns/dnsmessage.proto if [ -z "${DNSDISTBIN}" ]; then DNSDISTBIN=$(ls ../pdns/dnsdistdist/dnsdist-*/dnsdist) diff --git a/regression-tests.dnsdist/test_Protobuf.py b/regression-tests.dnsdist/test_Protobuf.py new file mode 100644 index 0000000000..c5722ec73b --- /dev/null +++ b/regression-tests.dnsdist/test_Protobuf.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python +import Queue +import threading +import socket +import struct +import sys +import time +from dnsdisttests import DNSDistTest + +import dns +import dnsmessage_pb2 + +class TestProtobuf(DNSDistTest): + + _protobufServerPort = 4242 + _protobufQueue = Queue.Queue() + _protobufCounter = 0 + _config_params = ['_testServerPort', '_protobufServerPort'] + _config_template = """ + newServer{address="127.0.0.1:%s", useClientSubnet=true} + rl = newRemoteLogger('127.0.0.1:%s') + addAction(AllRule(), RemoteLogAction(rl)) + addResponseAction(AllRule(), RemoteLogResponseAction(rl)) + """ + + @classmethod + def ProtobufListener(cls, port): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + try: + sock.bind(("127.0.0.1", port)) + except socket.error as e: + print("Error binding in the protbuf listener: %s" % str(e)) + sys.exit(1) + + sock.listen(100) + while True: + (conn, _) = sock.accept() + data = None + while True: + data = conn.recv(2) + if not data: + break + (datalen,) = struct.unpack("!H", data) + data = conn.recv(datalen) + if not data: + break + + cls._protobufQueue.put(data, True, timeout=2.0) + + conn.close() + sock.close() + + @classmethod + def startResponders(cls): + cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort]) + cls._UDPResponder.setDaemon(True) + cls._UDPResponder.start() + cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort]) + cls._TCPResponder.setDaemon(True) + cls._TCPResponder.start() + + cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort]) + cls._protobufListener.setDaemon(True) + cls._protobufListener.start() + + def getFirstProtobufMessage(self): + self.assertFalse(self._protobufQueue.empty()) + data = self._protobufQueue.get(False) + self.assertTrue(data) + msg = dnsmessage_pb2.PBDNSMessage() + msg.ParseFromString(data) + return msg + + def checkProtobufBase(self, msg, protocol, query): + self.assertTrue(msg) + self.assertTrue(msg.HasField('timeSec')) + self.assertTrue(msg.HasField('socketFamily')) + self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET) + self.assertTrue(msg.HasField('from')) + fromvalue = getattr(msg, 'from') + self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), '127.0.0.1') + self.assertTrue(msg.HasField('socketProtocol')) + self.assertEquals(msg.socketProtocol, protocol) + self.assertTrue(msg.HasField('messageId')) + self.assertTrue(msg.HasField('id')) + self.assertEquals(msg.id, query.id) + self.assertTrue(msg.HasField('inBytes')) + self.assertEquals(msg.inBytes, len(query.to_wire())) + # dnsdist doesn't set the existing EDNS Subnet for now, + # although it might be set from Lua + # self.assertTrue(msg.HasField('originalRequestorSubnet')) + # self.assertEquals(len(msg.originalRequestorSubnet), 4) + # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1') + + def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname): + self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType) + self.checkProtobufBase(msg, protocol, query) + # dnsdist doesn't fill the responder field for responses + # because it doesn't keep the information around. + self.assertTrue(msg.HasField('to')) + self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1') + self.assertTrue(msg.HasField('question')) + self.assertTrue(msg.question.HasField('qClass')) + self.assertEquals(msg.question.qClass, qclass) + self.assertTrue(msg.question.HasField('qType')) + self.assertEquals(msg.question.qClass, qtype) + self.assertTrue(msg.question.HasField('qName')) + self.assertEquals(msg.question.qName, qname) + + def checkProtobufResponse(self, msg, protocol, response): + self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) + self.checkProtobufBase(msg, protocol, response) + self.assertTrue(msg.HasField('response')) + self.assertTrue(msg.response.HasField('queryTimeSec')) + + def testProtobuf(self): + """ + Protobuf: Send data to a protobuf server + """ + name = 'query.protobuf.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # let the protobuf messages the time to get there + time.sleep(1) + + # check the protobuf message corresponding to the UDP query + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name) + + # check the protobuf message corresponding to the UDP response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response) + self.assertEquals(len(msg.response.rrs), 1) + for rr in msg.response.rrs: + self.assertTrue(rr.HasField('class')) + self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN) + self.assertTrue(rr.HasField('type')) + self.assertEquals(rr.type, dns.rdatatype.A) + self.assertTrue(rr.HasField('name')) + self.assertEquals(rr.name, name) + self.assertTrue(rr.HasField('ttl')) + self.assertEquals(rr.ttl, 3600) + self.assertTrue(rr.HasField('rdata')) + self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # let the protobuf messages the time to get there + time.sleep(1) + + # check the protobuf message corresponding to the TCP query + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name) + + # check the protobuf message corresponding to the TCP response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response) + self.assertEquals(len(msg.response.rrs), 1) + for rr in msg.response.rrs: + self.assertTrue(rr.HasField('class')) + self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN) + self.assertTrue(rr.HasField('type')) + self.assertEquals(rr.type, dns.rdatatype.A) + self.assertTrue(rr.HasField('name')) + self.assertEquals(rr.name, name) + self.assertTrue(rr.HasField('ttl')) + self.assertEquals(rr.ttl, 3600) + self.assertTrue(rr.HasField('rdata')) + self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') -- 2.47.2