]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for protobuf export
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 6 Sep 2016 15:21:25 +0000 (17:21 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 6 Sep 2016 15:21:25 +0000 (17:21 +0200)
regression-tests.dnsdist/requirements.txt
regression-tests.dnsdist/runtests
regression-tests.dnsdist/test_Protobuf.py [new file with mode: 0644]

index 2063456d4fe5bc5c4864ecab582231aec0f1b47c..1a878ab9ad017a0e6979f240e14e653edf2113e1 100644 (file)
@@ -2,3 +2,4 @@ dnspython>=1.11
 nose>=1.3.7
 libnacl>=1.4.3
 requests>=2.1.0
+protobuf>=2.5,<3.0
index 12951e26df707ded5a8bed581234d044486c2950..6bf1b906abd7cf60f313d4660f7306b069349d8f 100755 (executable)
@@ -9,6 +9,7 @@ fi
 . .venv/bin/activate
 python -V
 pip install -r requirements.txt
+protoc -I=../pdns/ --python_out=. ../pdns/dnsmessage.proto
 
 if [ -z "${DNSDISTBIN}" ]; then
     DNSDISTBIN=$(ls ../pdns/dnsdistdist/dnsdist-*/dnsdist)
diff --git a/regression-tests.dnsdist/test_Protobuf.py b/regression-tests.dnsdist/test_Protobuf.py
new file mode 100644 (file)
index 0000000..c5722ec
--- /dev/null
@@ -0,0 +1,189 @@
+#!/usr/bin/env python
+import Queue
+import threading
+import socket
+import struct
+import sys
+import time
+from dnsdisttests import DNSDistTest
+
+import dns
+import dnsmessage_pb2
+
+class TestProtobuf(DNSDistTest):
+
+    _protobufServerPort = 4242
+    _protobufQueue = Queue.Queue()
+    _protobufCounter = 0
+    _config_params = ['_testServerPort', '_protobufServerPort']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    rl = newRemoteLogger('127.0.0.1:%s')
+    addAction(AllRule(), RemoteLogAction(rl))
+    addResponseAction(AllRule(), RemoteLogResponseAction(rl))
+    """
+
+    @classmethod
+    def ProtobufListener(cls, port):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        try:
+            sock.bind(("127.0.0.1", port))
+        except socket.error as e:
+            print("Error binding in the protbuf listener: %s" % str(e))
+            sys.exit(1)
+
+        sock.listen(100)
+        while True:
+            (conn, _) = sock.accept()
+            data = None
+            while True:
+                data = conn.recv(2)
+                if not data:
+                    break
+                (datalen,) = struct.unpack("!H", data)
+                data = conn.recv(datalen)
+                if not data:
+                    break
+
+                cls._protobufQueue.put(data, True, timeout=2.0)
+
+            conn.close()
+        sock.close()
+
+    @classmethod
+    def startResponders(cls):
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
+        cls._UDPResponder.setDaemon(True)
+        cls._UDPResponder.start()
+        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
+        cls._TCPResponder.setDaemon(True)
+        cls._TCPResponder.start()
+
+        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
+        cls._protobufListener.setDaemon(True)
+        cls._protobufListener.start()
+
+    def getFirstProtobufMessage(self):
+        self.assertFalse(self._protobufQueue.empty())
+        data = self._protobufQueue.get(False)
+        self.assertTrue(data)
+        msg = dnsmessage_pb2.PBDNSMessage()
+        msg.ParseFromString(data)
+        return msg
+
+    def checkProtobufBase(self, msg, protocol, query):
+        self.assertTrue(msg)
+        self.assertTrue(msg.HasField('timeSec'))
+        self.assertTrue(msg.HasField('socketFamily'))
+        self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
+        self.assertTrue(msg.HasField('from'))
+        fromvalue = getattr(msg, 'from')
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), '127.0.0.1')
+        self.assertTrue(msg.HasField('socketProtocol'))
+        self.assertEquals(msg.socketProtocol, protocol)
+        self.assertTrue(msg.HasField('messageId'))
+        self.assertTrue(msg.HasField('id'))
+        self.assertEquals(msg.id, query.id)
+        self.assertTrue(msg.HasField('inBytes'))
+        self.assertEquals(msg.inBytes, len(query.to_wire()))
+        # dnsdist doesn't set the existing EDNS Subnet for now,
+        # although it might be set from Lua
+        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
+        # self.assertEquals(len(msg.originalRequestorSubnet), 4)
+        # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
+
+    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
+        self.checkProtobufBase(msg, protocol, query)
+        # dnsdist doesn't fill the responder field for responses
+        # because it doesn't keep the information around.
+        self.assertTrue(msg.HasField('to'))
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
+        self.assertTrue(msg.HasField('question'))
+        self.assertTrue(msg.question.HasField('qClass'))
+        self.assertEquals(msg.question.qClass, qclass)
+        self.assertTrue(msg.question.HasField('qType'))
+        self.assertEquals(msg.question.qClass, qtype)
+        self.assertTrue(msg.question.HasField('qName'))
+        self.assertEquals(msg.question.qName, qname)
+
+    def checkProtobufResponse(self, msg, protocol, response):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+        self.checkProtobufBase(msg, protocol, response)
+        self.assertTrue(msg.HasField('response'))
+        self.assertTrue(msg.response.HasField('queryTimeSec'))
+
+    def testProtobuf(self):
+        """
+        Protobuf: Send data to a protobuf server
+        """
+        name = 'query.protobuf.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the UDP query
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+
+        # check the protobuf message corresponding to the UDP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)
+        self.assertEquals(len(msg.response.rrs), 1)
+        for rr in msg.response.rrs:
+            self.assertTrue(rr.HasField('class'))
+            self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN)
+            self.assertTrue(rr.HasField('type'))
+            self.assertEquals(rr.type, dns.rdatatype.A)
+            self.assertTrue(rr.HasField('name'))
+            self.assertEquals(rr.name, name)
+            self.assertTrue(rr.HasField('ttl'))
+            self.assertEquals(rr.ttl, 3600)
+            self.assertTrue(rr.HasField('rdata'))
+            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the TCP query
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+
+        # check the protobuf message corresponding to the TCP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)
+        self.assertEquals(len(msg.response.rrs), 1)
+        for rr in msg.response.rrs:
+            self.assertTrue(rr.HasField('class'))
+            self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN)
+            self.assertTrue(rr.HasField('type'))
+            self.assertEquals(rr.type, dns.rdatatype.A)
+            self.assertTrue(rr.HasField('name'))
+            self.assertEquals(rr.name, name)
+            self.assertTrue(rr.HasField('ttl'))
+            self.assertEquals(rr.ttl, 3600)
+            self.assertTrue(rr.HasField('rdata'))
+            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')