]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
rec-4.1.x: Backport some protobuf regression tests 7430/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 29 Jan 2019 16:10:55 +0000 (17:10 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 29 Jan 2019 16:10:55 +0000 (17:10 +0100)
regression-tests.recursor-dnssec/.gitignore
regression-tests.recursor-dnssec/requirements.txt
regression-tests.recursor-dnssec/runtests
regression-tests.recursor-dnssec/test_Protobuf.py [new file with mode: 0644]

index 7103d74161e8a3311a96385d19e45f7e4453da6e..4b621b2f5e0b04401450e1495161132d0adffe2c 100644 (file)
@@ -3,3 +3,4 @@
 /.venv
 /configs
 /vars
+/*_pb2.py
index 8bd274ed75f571fcd7575d0d8e2d73a13bee7df8..90c10a35bd9a73fba2c39dadb6fce163c1bacfb5 100644 (file)
@@ -1,4 +1,6 @@
 dnspython>=1.11
 nose>=1.3.7
+protobuf>=2.5; sys_platform != 'darwin'
+protobuf>=3.0; sys_platform == 'darwin'
 pysnmp>=4.3.4
 Twisted>0.15.0
index 840b045f93b4d489476e508fc5f2d4bdf6d70d07..658e42e63d57224788f17769b5fee32891484d07 100755 (executable)
@@ -9,6 +9,8 @@ fi
 python -V
 pip install -r requirements.txt
 
+protoc -I=../pdns/ --python_out=. ../pdns/dnsmessage.proto
+
 mkdir -p configs
 
 [ -f ./vars ] && . ./vars
diff --git a/regression-tests.recursor-dnssec/test_Protobuf.py b/regression-tests.recursor-dnssec/test_Protobuf.py
new file mode 100644 (file)
index 0000000..b58138d
--- /dev/null
@@ -0,0 +1,459 @@
+import dns
+import dnsmessage_pb2
+import os
+import socket
+import struct
+import sys
+import threading
+import time
+
+# Python2/3 compatibility hacks
+try:
+  from queue import Queue
+except ImportError:
+  from Queue import Queue
+
+try:
+  range = xrange
+except NameError:
+  pass
+
+from recursortests import RecursorTest
+
+def ProtobufConnectionHandler(queue, conn):
+    data = None
+    while True:
+        data = conn.recv(2)
+        if not data:
+            break
+        (datalen,) = struct.unpack("!H", data)
+        data = conn.recv(datalen)
+        if not data:
+            break
+
+        queue.put(data, True, timeout=2.0)
+
+    conn.close()
+
+def ProtobufListener(queue, port):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    try:
+        sock.bind(("127.0.0.1", port))
+    except socket.error as e:
+        print("Error binding in the protobuf listener: %s" % str(e))
+        sys.exit(1)
+
+    sock.listen(100)
+    while True:
+        try:
+            (conn, _) = sock.accept()
+            thread = threading.Thread(name='Connection Handler',
+                                      target=ProtobufConnectionHandler,
+                                      args=[queue, conn])
+            thread.setDaemon(True)
+            thread.start()
+
+        except socket.error as e:
+            print('Error in protobuf socket: %s' % str(e))
+
+    sock.close()
+
+
+class ProtobufServerParams:
+  def __init__(self, port):
+    self.queue = Queue()
+    self.port = port
+
+protobufServersParameters = [ProtobufServerParams(4244)]
+protobufListeners = []
+for param in protobufServersParameters:
+  listener = threading.Thread(name='Protobuf Listener', target=ProtobufListener, args=[param.queue, param.port])
+  listener.setDaemon(True)
+  listener.start()
+  protobufListeners.append(listener)
+
+class TestRecursorProtobuf(RecursorTest):
+
+    _lua_config_file = """
+    protobufServer("127.0.0.1:%d")
+    """ % (protobufServersParameters[0].port)
+
+
+    def getFirstProtobufMessage(self, retries=1, waitTime=1):
+        msg = None
+
+        print("in getFirstProtobufMessage")
+        for param in protobufServersParameters:
+          print(param.port)
+          failed = 0
+
+          while param.queue.empty:
+            print(failed)
+            print(retries)
+            if failed >= retries:
+              break
+
+            failed = failed + 1
+            print("waiting")
+            time.sleep(waitTime)
+
+          self.assertFalse(param.queue.empty())
+          data = param.queue.get(False)
+          self.assertTrue(data)
+          oldmsg = msg
+          msg = dnsmessage_pb2.PBDNSMessage()
+          msg.ParseFromString(data)
+          if oldmsg is not None:
+            self.assertEquals(msg, oldmsg)
+
+        return msg
+
+    def checkNoRemainingMessage(self):
+        for param in protobufServersParameters:
+          self.assertTrue(param.queue.empty())
+
+    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True, expectedECS=None, receivedSize=None):
+        self.assertTrue(msg)
+        self.assertTrue(msg.HasField('timeSec'))
+        self.assertTrue(msg.HasField('socketFamily'))
+        self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
+        self.assertTrue(msg.HasField('from'))
+        fromvalue = getattr(msg, 'from')
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
+        self.assertTrue(msg.HasField('socketProtocol'))
+        self.assertEquals(msg.socketProtocol, protocol)
+        self.assertTrue(msg.HasField('messageId'))
+        self.assertTrue(msg.HasField('serverIdentity'))
+        self.assertTrue(msg.HasField('id'))
+        self.assertEquals(msg.id, query.id)
+        self.assertTrue(msg.HasField('inBytes'))
+        if normalQueryResponse:
+            # compare inBytes with length of query/response
+            # Note that for responses, the size we received might differ
+            # because dnspython might compress labels differently from
+            # the recursor
+            if receivedSize:
+                self.assertEquals(msg.inBytes, receivedSize)
+            else:
+                self.assertEquals(msg.inBytes, len(query.to_wire()))
+        if expectedECS is not None:
+            self.assertTrue(msg.HasField('originalRequestorSubnet'))
+            # v4 only for now
+            self.assertEquals(len(msg.originalRequestorSubnet), 4)
+            self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
+
+    def checkOutgoingProtobufBase(self, msg, protocol, query, initiator):
+        self.assertTrue(msg)
+        self.assertTrue(msg.HasField('timeSec'))
+        self.assertTrue(msg.HasField('socketFamily'))
+        self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
+        self.assertTrue(msg.HasField('socketProtocol'))
+        self.assertEquals(msg.socketProtocol, protocol)
+        self.assertTrue(msg.HasField('messageId'))
+        self.assertTrue(msg.HasField('serverIdentity'))
+        self.assertTrue(msg.HasField('id'))
+        self.assertNotEquals(msg.id, query.id)
+        self.assertTrue(msg.HasField('inBytes'))
+        # compare inBytes with length of query/response
+        self.assertEquals(msg.inBytes, len(query.to_wire()))
+
+    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
+        self.checkProtobufBase(msg, protocol, query, initiator)
+        # dnsdist doesn't fill the responder field for responses
+        # because it doesn't keep the information around.
+        self.assertTrue(msg.HasField('to'))
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
+        self.assertTrue(msg.HasField('question'))
+        self.assertTrue(msg.question.HasField('qClass'))
+        self.assertEquals(msg.question.qClass, qclass)
+        self.assertTrue(msg.question.HasField('qType'))
+        self.assertEquals(msg.question.qClass, qtype)
+        self.assertTrue(msg.question.HasField('qName'))
+        self.assertEquals(msg.question.qName, qname)
+
+    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1', receivedSize=None):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+        self.checkProtobufBase(msg, protocol, response, initiator, receivedSize=receivedSize)
+        self.assertTrue(msg.HasField('response'))
+        self.assertTrue(msg.response.HasField('queryTimeSec'))
+
+    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl, checkTTL=True):
+        self.assertTrue(record.HasField('class'))
+        self.assertEquals(getattr(record, 'class'), rclass)
+        self.assertTrue(record.HasField('type'))
+        self.assertEquals(record.type, rtype)
+        self.assertTrue(record.HasField('name'))
+        self.assertEquals(record.name, rname)
+        self.assertTrue(record.HasField('ttl'))
+        if checkTTL:
+            self.assertEquals(record.ttl, rttl)
+        self.assertTrue(record.HasField('rdata'))
+
+    def checkProtobufPolicy(self, msg, policyType, reason):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+        self.assertTrue(msg.response.HasField('appliedPolicyType'))
+        self.assertTrue(msg.response.HasField('appliedPolicy'))
+        self.assertEquals(msg.response.appliedPolicy, reason)
+        self.assertEquals(msg.response.appliedPolicyType, policyType)
+
+    def checkProtobufTags(self, msg, tags):
+        self.assertEquals(len(msg.response.tags), len(tags))
+        for tag in msg.response.tags:
+            self.assertTrue(tag in tags)
+
+    def checkProtobufOutgoingQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSOutgoingQueryType)
+        self.checkOutgoingProtobufBase(msg, protocol, query, initiator)
+        self.assertTrue(msg.HasField('to'))
+        self.assertTrue(msg.HasField('question'))
+        self.assertTrue(msg.question.HasField('qClass'))
+        self.assertEquals(msg.question.qClass, qclass)
+        self.assertTrue(msg.question.HasField('qType'))
+        self.assertEquals(msg.question.qClass, qtype)
+        self.assertTrue(msg.question.HasField('qName'))
+        self.assertEquals(msg.question.qName, qname)
+
+    def checkProtobufIncomingResponse(self, msg, protocol, response, initiator='127.0.0.1'):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSIncomingResponseType)
+        self.checkOutgoingProtobufBase(msg, protocol, response, initiator)
+        self.assertTrue(msg.HasField('response'))
+        self.assertTrue(msg.response.HasField('queryTimeSec'))
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.setUpSockets()
+
+        cls.startResponders()
+
+        confdir = os.path.join('configs', cls._confdir)
+        cls.createConfigDir(confdir)
+
+        cls.generateRecursorConfig(confdir)
+        cls.startRecursor(confdir, cls._recursorPort)
+
+    def setUp(self):
+      # Make sure the queue is empty, in case
+      # a previous test failed
+      for param in protobufServersParameters:
+        while not param.queue.empty():
+          param.queue.get(False)
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, 'example.zone')
+        with open(authzonepath, 'w') as authzone:
+            authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+tagged 3600 IN A 192.0.2.84
+query-selected 3600 IN A 192.0.2.84
+answer-selected 3600 IN A 192.0.2.84
+types 3600 IN A 192.0.2.84
+types 3600 IN AAAA 2001:DB8::1
+types 3600 IN TXT "Lorem ipsum dolor sit amet"
+types 3600 IN MX 10 a.example.
+types 3600 IN SPF "v=spf1 -all"
+types 3600 IN SRV 10 20 443 a.example.
+cname 3600 IN CNAME a.example.
+
+""".format(soa=cls._SOA))
+        super(TestRecursorProtobuf, cls).generateRecursorConfig(confdir)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.tearDownRecursor()
+
+class ProtobufDefaultTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly export queries and response over protobuf.
+    """
+
+    _confdir = 'ProtobufDefault'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+
+    def testA(self):
+        name = 'a.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.42')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+
+        self.assertRRsetInAnswer(res, expected)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '127.0.0.1')
+        self.assertEquals(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.42')
+        self.checkNoRemainingMessage()
+
+    def testCNAME(self):
+        name = 'cname.example.'
+        expectedCNAME = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'CNAME', 'a.example.')
+        expectedA = dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.42')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        raw = self.sendUDPQuery(query, decode=False)
+        res = dns.message.from_wire(raw)
+        self.assertRRsetInAnswer(res, expectedCNAME)
+        self.assertRRsetInAnswer(res, expectedA)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        # but first let the protobuf messages the time to get there
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '127.0.0.1', receivedSize=len(raw))
+        self.assertEquals(len(msg.response.rrs), 2)
+        rr = msg.response.rrs[0]
+        # we don't want to check the TTL for the A record, it has been cached by the previous test
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 15)
+        self.assertEquals(rr.rdata, 'a.example.')
+        rr = msg.response.rrs[1]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, 'a.example.', 15, checkTTL=False)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.42')
+        self.checkNoRemainingMessage()
+
+class OutgoingProtobufDefaultTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly export outgoing queries over protobuf.
+    It must be improved and setup env so we can check for incoming responses, but makes sure for now
+    that the recursor at least connects to the protobuf server.
+    """
+
+    _confdir = 'OutgoingProtobufDefault'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _lua_config_file = """
+    outgoingProtobufServer("127.0.0.1:%d")
+    """ % (protobufServersParameters[0].port)
+
+    def testA(self):
+        name = 'www.example.org.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.42')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.RD
+        res = self.sendUDPQuery(query)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufOutgoingQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+#        # then the response
+#        msg = self.getFirstProtobufMessage()
+#        self.checkProtobufIncomingResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.checkNoRemainingMessage()
+
+class ProtobufMasksTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly export queries and response over protobuf, respecting the configured initiator masking.
+    """
+
+    _confdir = 'ProtobufMasks'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _protobufMaskV4 = 4
+    _protobufMaskV6 = 128
+    _lua_config_file = """
+    protobufServer("127.0.0.1:%d", 2, 100, 1, %d, %d)
+    """ % (protobufServersParameters[0].port, _protobufMaskV4, _protobufMaskV6)
+
+    def testA(self):
+        name = 'a.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.42')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        # but first let the protobuf messages the time to get there
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '112.0.0.0')
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res, '112.0.0.0')
+        self.assertEquals(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.42')
+        self.checkNoRemainingMessage()
+
+class ProtobufTaggedOnlyTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly export queries and responses but only if they have been tagged.
+    """
+
+    _confdir = 'ProtobufTaggedOnly'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _lua_config_file = """
+    protobufServer("127.0.0.1:%d", 2, 100, 1, 32, 128, false, true)
+    """ % (protobufServersParameters[0].port)
+    _tags = ['tag1', 'tag2']
+    _tag_from_gettag = 'tag-from-gettag'
+    _lua_dns_script_file = """
+    function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp)
+      if qname:equal('tagged.example.') then
+        return 0, { '%s' }
+      end
+      return 0
+    end
+    function preresolve(dq)
+      if dq.qname:equal('tagged.example.') then
+        dq:addPolicyTag('%s')
+        dq:addPolicyTag('%s')
+      end
+      return false
+    end
+    """ % (_tag_from_gettag, _tags[0], _tags[1])
+
+    def testA(self):
+        name = 'a.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.42')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        # check the protobuf message corresponding to the UDP response
+        # the first query and answer are not tagged, so there is nothing in the queue
+        time.sleep(1)
+        self.checkNoRemainingMessage()
+
+    def testTagged(self):
+        name = 'tagged.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        self.checkProtobufTags(msg, [self._tag_from_gettag])
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEquals(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        tags = [self._tag_from_gettag] + self._tags
+        self.checkProtobufTags(msg, tags)
+        self.checkNoRemainingMessage()
+