]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.recursor-dnssec/test_Protobuf.py
dnsdist doc typo fix
[thirdparty/pdns.git] / regression-tests.recursor-dnssec / test_Protobuf.py
index 3f86caf804308365974d140e7e6407f3cadb0355..50a6449d4b45487175a9d1813e3d973eefd0ed67 100644 (file)
@@ -8,18 +8,18 @@ import threading
 import time
 
 # Python2/3 compatibility hacks
-if sys.version_info[0] == 2:
+try:
+  from queue import Queue
+except ImportError:
   from Queue import Queue
+
+try:
   range = xrange
-else:
-  from queue import Queue
-  range = range  # allow re-export of the builtin name
+except NameError:
+  pass
 
 from recursortests import RecursorTest
 
-protobufQueue = Queue()
-protobufServerPort = 4243
-
 def ProtobufConnectionHandler(queue, conn):
     data = None
     while True:
@@ -35,8 +35,7 @@ def ProtobufConnectionHandler(queue, conn):
 
     conn.close()
 
-def ProtobufListener(port):
-    global protobufQueue
+def ProtobufListener(queue, port):
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
     try:
@@ -51,7 +50,7 @@ def ProtobufListener(port):
             (conn, _) = sock.accept()
             thread = threading.Thread(name='Connection Handler',
                                       target=ProtobufConnectionHandler,
-                                      args=[protobufQueue, conn])
+                                      args=[queue, conn])
             thread.setDaemon(True)
             thread.start()
 
@@ -61,41 +60,60 @@ def ProtobufListener(port):
     sock.close()
 
 
-protobufListener = threading.Thread(name='Protobuf Listener', target=ProtobufListener, args=[protobufServerPort])
-protobufListener.setDaemon(True)
-protobufListener.start()
+class ProtobufServerParams:
+  def __init__(self, port):
+    self.queue = Queue()
+    self.port = port
+
+protobufServersParameters = [ProtobufServerParams(4243), ProtobufServerParams(4244)]
+protobufListeners = []
+for param in protobufServersParameters:
+  listener = threading.Thread(name='Protobuf Listener', target=ProtobufListener, args=[param.queue, param.port])
+  listener.setDaemon(True)
+  listener.start()
+  protobufListeners.append(listener)
 
 class TestRecursorProtobuf(RecursorTest):
 
-    global protobufServerPort
     _lua_config_file = """
-    protobufServer("127.0.0.1:%d")
-    """ % (protobufServerPort)
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
 
 
     def getFirstProtobufMessage(self, retries=1, waitTime=1):
-        global protobufQueue
-        failed = 0
-
-        while protobufQueue.empty:
-          if failed >= retries:
-            break
-
-          failed = failed + 1
-          time.sleep(waitTime)
+        msg = None
+
+        print("in getFirstProtobufMessage")
+        for param in protobufServersParameters:
+          print(param.port)
+          failed = 0
+
+          while param.queue.empty:
+            print(failed)
+            print(retries)
+            if failed >= retries:
+              break
+
+            failed = failed + 1
+            print("waiting")
+            time.sleep(waitTime)
+
+          self.assertFalse(param.queue.empty())
+          data = param.queue.get(False)
+          self.assertTrue(data)
+          oldmsg = msg
+          msg = dnsmessage_pb2.PBDNSMessage()
+          msg.ParseFromString(data)
+          if oldmsg is not None:
+            self.assertEquals(msg, oldmsg)
 
-        self.assertFalse(protobufQueue.empty())
-        data = protobufQueue.get(False)
-        self.assertTrue(data)
-        msg = dnsmessage_pb2.PBDNSMessage()
-        msg.ParseFromString(data)
         return msg
 
     def checkNoRemainingMessage(self):
-        global protobufQueue
-        self.assertTrue(protobufQueue.empty())
+        for param in protobufServersParameters:
+          self.assertTrue(param.queue.empty())
 
-    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True, expectedECS=None):
+    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True, expectedECS=None, receivedSize=None):
         self.assertTrue(msg)
         self.assertTrue(msg.HasField('timeSec'))
         self.assertTrue(msg.HasField('socketFamily'))
@@ -112,7 +130,13 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertTrue(msg.HasField('inBytes'))
         if normalQueryResponse:
             # compare inBytes with length of query/response
-            self.assertEquals(msg.inBytes, len(query.to_wire()))
+            # Note that for responses, the size we received might differ
+            # because dnspython might compress labels differently from
+            # the recursor
+            if receivedSize:
+                self.assertEquals(msg.inBytes, receivedSize)
+            else:
+                self.assertEquals(msg.inBytes, len(query.to_wire()))
         if expectedECS is not None:
             self.assertTrue(msg.HasField('originalRequestorSubnet'))
             # v4 only for now
@@ -149,13 +173,13 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertTrue(msg.question.HasField('qName'))
         self.assertEquals(msg.question.qName, qname)
 
-    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
+    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1', receivedSize=None):
         self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
-        self.checkProtobufBase(msg, protocol, response, initiator)
+        self.checkProtobufBase(msg, protocol, response, initiator, receivedSize=receivedSize)
         self.assertTrue(msg.HasField('response'))
         self.assertTrue(msg.response.HasField('queryTimeSec'))
 
-    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
+    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl, checkTTL=True):
         self.assertTrue(record.HasField('class'))
         self.assertEquals(getattr(record, 'class'), rclass)
         self.assertTrue(record.HasField('type'))
@@ -163,7 +187,8 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertTrue(record.HasField('name'))
         self.assertEquals(record.name, rname)
         self.assertTrue(record.HasField('ttl'))
-        self.assertEquals(record.ttl, rttl)
+        if checkTTL:
+            self.assertEquals(record.ttl, rttl)
         self.assertTrue(record.HasField('rdata'))
 
     def checkProtobufPolicy(self, msg, policyType, reason):
@@ -199,14 +224,6 @@ class TestRecursorProtobuf(RecursorTest):
     @classmethod
     def setUpClass(cls):
 
-        global protobufListener
-        global protobufServerPort
-        global ProtobufListener
-        if protobufListener is None or not protobufListener.isAlive():
-            protobufListener = threading.Thread(name='Protobuf Listener', target=ProtobufListener, args=[protobufServerPort])
-            protobufListener.setDaemon(True)
-            protobufListener.start()
-
         cls.setUpSockets()
 
         cls.startResponders()
@@ -220,9 +237,9 @@ class TestRecursorProtobuf(RecursorTest):
     def setUp(self):
       # Make sure the queue is empty, in case
       # a previous test failed
-      global protobufQueue
-      while not protobufQueue.empty():
-        protobufQueue.get(False)
+      for param in protobufServersParameters:
+        while not param.queue.empty():
+          param.queue.get(False)
 
     @classmethod
     def generateRecursorConfig(cls, confdir):
@@ -234,6 +251,14 @@ a 3600 IN A 192.0.2.42
 tagged 3600 IN A 192.0.2.84
 query-selected 3600 IN A 192.0.2.84
 answer-selected 3600 IN A 192.0.2.84
+types 3600 IN A 192.0.2.84
+types 3600 IN AAAA 2001:DB8::1
+types 3600 IN TXT "Lorem ipsum dolor sit amet"
+types 3600 IN MX 10 a.example.
+types 3600 IN SPF "v=spf1 -all"
+types 3600 IN SRV 10 20 443 a.example.
+cname 3600 IN CNAME a.example.
+
 """.format(soa=cls._SOA))
         super(TestRecursorProtobuf, cls).generateRecursorConfig(confdir)
 
@@ -256,6 +281,7 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         query = dns.message.make_query(name, 'A', want_dnssec=True)
         query.flags |= dns.flags.CD
         res = self.sendUDPQuery(query)
+
         self.assertRRsetInAnswer(res, expected)
 
         # check the protobuf messages corresponding to the UDP query and answer
@@ -263,7 +289,7 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
         # then the response
         msg = self.getFirstProtobufMessage()
-        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '127.0.0.1')
         self.assertEquals(len(msg.response.rrs), 1)
         rr = msg.response.rrs[0]
         # we have max-cache-ttl set to 15
@@ -271,6 +297,35 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.42')
         self.checkNoRemainingMessage()
 
+    def testCNAME(self):
+        name = 'cname.example.'
+        expectedCNAME = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'CNAME', 'a.example.')
+        expectedA = dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.42')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        raw = self.sendUDPQuery(query, decode=False)
+        res = dns.message.from_wire(raw)
+        self.assertRRsetInAnswer(res, expectedCNAME)
+        self.assertRRsetInAnswer(res, expectedA)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        # but first let the protobuf messages the time to get there
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '127.0.0.1', receivedSize=len(raw))
+        self.assertEquals(len(msg.response.rrs), 2)
+        rr = msg.response.rrs[0]
+        # we don't want to check the TTL for the A record, it has been cached by the previous test
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 15)
+        self.assertEquals(rr.rdata, 'a.example.')
+        rr = msg.response.rrs[1]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, 'a.example.', 15, checkTTL=False)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.42')
+        self.checkNoRemainingMessage()
+
 class OutgoingProtobufDefaultTest(TestRecursorProtobuf):
     """
     This test makes sure that we correctly export outgoing queries over protobuf.
@@ -282,8 +337,8 @@ class OutgoingProtobufDefaultTest(TestRecursorProtobuf):
     _config_template = """
 auth-zones=example=configs/%s/example.zone""" % _confdir
     _lua_config_file = """
-    outgoingProtobufServer("127.0.0.1:%d")
-    """ % (protobufServerPort)
+    outgoingProtobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
 
     def testA(self):
         name = 'www.example.org.'
@@ -308,13 +363,12 @@ class ProtobufMasksTest(TestRecursorProtobuf):
     _confdir = 'ProtobufMasks'
     _config_template = """
 auth-zones=example=configs/%s/example.zone""" % _confdir
-    global protobufServerPort
     _protobufMaskV4 = 4
     _protobufMaskV6 = 128
     _lua_config_file = """
-    protobufServer("127.0.0.1:%d")
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
     setProtobufMasks(%d, %d)
-    """ % (protobufServerPort, _protobufMaskV4, _protobufMaskV6)
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port, _protobufMaskV4, _protobufMaskV6)
 
     def testA(self):
         name = 'a.example.'
@@ -346,10 +400,9 @@ class ProtobufQueriesOnlyTest(TestRecursorProtobuf):
     _confdir = 'ProtobufQueriesOnly'
     _config_template = """
 auth-zones=example=configs/%s/example.zone""" % _confdir
-    global protobufServerPort
     _lua_config_file = """
-    protobufServer("127.0.0.1:%d", { logQueries=true, logResponses=false } )
-    """ % (protobufServerPort)
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=true, logResponses=false } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
 
     def testA(self):
         name = 'a.example.'
@@ -373,10 +426,9 @@ class ProtobufResponsesOnlyTest(TestRecursorProtobuf):
     _confdir = 'ProtobufResponsesOnly'
     _config_template = """
 auth-zones=example=configs/%s/example.zone""" % _confdir
-    global protobufServerPort
     _lua_config_file = """
-    protobufServer("127.0.0.1:%d", { logQueries=false, logResponses=true } )
-    """ % (protobufServerPort)
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
 
     def testA(self):
         name = 'a.example.'
@@ -405,10 +457,9 @@ class ProtobufTaggedOnlyTest(TestRecursorProtobuf):
     _confdir = 'ProtobufTaggedOnly'
     _config_template = """
 auth-zones=example=configs/%s/example.zone""" % _confdir
-    global protobufServerPort
     _lua_config_file = """
-    protobufServer("127.0.0.1:%d", { logQueries=true, logResponses=true, taggedOnly=true } )
-    """ % (protobufServerPort)
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=true, logResponses=true, taggedOnly=true } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
     _tags = ['tag1', 'tag2']
     _tag_from_gettag = 'tag-from-gettag'
     _lua_dns_script_file = """
@@ -472,10 +523,9 @@ class ProtobufSelectedFromLuaTest(TestRecursorProtobuf):
     _confdir = 'ProtobufSelectedFromLua'
     _config_template = """
 auth-zones=example=configs/%s/example.zone""" % _confdir
-    global protobufServerPort
     _lua_config_file = """
-    protobufServer("127.0.0.1:%d", { logQueries=false, logResponses=false } )
-    """ % (protobufServerPort)
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=false } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
     _lua_dns_script_file = """
     local ffi = require("ffi")
 
@@ -545,3 +595,60 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
         self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
         self.checkNoRemainingMessage()
+
+class ProtobufExportTypesTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly export other types than A, AAAA and CNAME over protobuf.
+    """
+
+    _confdir = 'ProtobufExportTypes'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _lua_config_file = """
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { exportTypes={"AAAA", "MX", "SPF", "SRV", "TXT"} } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+
+    def testA(self):
+        name = 'types.example.'
+        expected = [dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84'),
+                    dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'AAAA', '2001:DB8::1'),
+                    dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'MX', '10 a.example.'),
+                    dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'SPF', '"v=spf1 -all"'),
+                    dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'SRV', '10 20 443 a.example.'),
+                    dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'TXT', '"Lorem ipsum dolor sit amet"'),
+                   ]
+        query = dns.message.make_query(name, 'ANY', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        raw = self.sendUDPQuery(query, decode=False)
+        res = dns.message.from_wire(raw)
+
+        for rrset in expected:
+            self.assertRRsetInAnswer(res, rrset)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '127.0.0.1', receivedSize=len(raw))
+        self.assertEquals(len(msg.response.rrs), 5)
+        for rr in msg.response.rrs:
+            self.assertTrue(rr.type in [dns.rdatatype.AAAA, dns.rdatatype.TXT, dns.rdatatype.MX, dns.rdatatype.SPF, dns.rdatatype.SRV])
+
+            if rr.type == dns.rdatatype.AAAA:
+                self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.AAAA, name, 15)
+                self.assertEquals(socket.inet_ntop(socket.AF_INET6, rr.rdata), '2001:db8::1')
+            elif rr.type == dns.rdatatype.TXT:
+                self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.TXT, name, 15)
+                self.assertEquals(rr.rdata, '"Lorem ipsum dolor sit amet"')
+            elif rr.type == dns.rdatatype.MX:
+                self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.MX, name, 15)
+                self.assertEquals(rr.rdata, 'a.example.')
+            elif rr.type == dns.rdatatype.SPF:
+                self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.SPF, name, 15)
+                self.assertEquals(rr.rdata, '"v=spf1 -all"')
+            elif rr.type == dns.rdatatype.SRV:
+                self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.SRV, name, 15)
+                self.assertEquals(rr.rdata, 'a.example.')
+
+        self.checkNoRemainingMessage()