]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
completed test for mods to dnsdist in pdns/regression-tests.dnsdist/test_ProtobufTag.py
authorSeth Ornstein <sornstein@globalcyberalliance.org>
Fri, 23 Jun 2017 04:12:27 +0000 (00:12 -0400)
committerNick Saika <nicksaika@gmail.com>
Wed, 2 Aug 2017 15:06:43 +0000 (11:06 -0400)
script to execute it in pdns/zzz-gca-example/test-protobuf-tag.sh

regression-tests.dnsdist/test_ProtobufTag.py [new file with mode: 0644]
zzz-gca-example/dnsdist.conf
zzz-gca-example/test-protobuf-tag.sh [new file with mode: 0755]

diff --git a/regression-tests.dnsdist/test_ProtobufTag.py b/regression-tests.dnsdist/test_ProtobufTag.py
new file mode 100644 (file)
index 0000000..976cecf
--- /dev/null
@@ -0,0 +1,363 @@
+#!/usr/bin/env python
+import Queue
+import threading
+import socket
+import struct
+import sys
+import time
+from dnsdisttests import DNSDistTest
+
+import dns
+import dnsmessage_pb2
+
+
+class TestProtobuf(DNSDistTest):
+    _protobufServerPort = 4242
+    _protobufQueue = Queue.Queue()
+    _protobufCounter = 0
+    _config_params = ['_testServerPort', '_protobufServerPort']
+    _config_template = """
+    luasmn = newSuffixMatchNode()                              
+    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) 
+
+    function alterProtobufResponse(dq, protobuf)               
+      if luasmn:check(dq.qname) then                           
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()  
+        if requestor:isIPv4() then                             
+          requestor:truncate(24)                               
+        else
+          requestor:truncate(56)                               
+        end
+        protobuf:setRequestor(requestor)                       
+
+       local tableTags = {}                                     
+        tableTags["TestLabel2"] = "TestData2"                    
+        tableTags["TestLabel1"] = "TestData1"                    
+       protobuf:setTagArray(tableTags)                         -- setTagArray 
+
+       protobuf:setTag('TestLabel3', 'TestData3')              -- setTag
+
+       protobuf:setTag("Response", "456")                      -- setTag
+      else                                                     
+       local tableTags = {}                                    -- called by testProtobuf()                                   
+        tableTags["TestLabel2"] = "TestData2"                   
+        tableTags["TestLabel1"] = "TestData1"                   
+       protobuf:setTagArray(tableTags)                         -- setTagArray
+
+       protobuf:setTag('TestLabel3', 'TestData3')              -- setTag
+
+       protobuf:setTag("Response", "456")                      -- setTag
+      end
+    end
+
+    function alterProtobufQuery(dq, protobuf)                  
+      if luasmn:check(dq.qname) then                           
+        requestor = newCA(dq.remoteaddr:toString())            -- called by testLuaProtobuf()
+        if requestor:isIPv4() then                             
+          requestor:truncate(24)                               
+        else
+          requestor:truncate(56)                               
+        end
+        protobuf:setRequestor(requestor)                       
+
+       local tableTags = {}                                    -- declare table
+       tableTags = dq:getTagArray()                            -- get table from DNSQuery
+
+       protobuf:setTagArray(tableTags)                         -- store table in protobuf
+       protobuf:setTag("Query", "123")                         -- add another tag entry in protobuf
+
+       protobuf:setResponseCode(dnsdist.NXDOMAIN)              -- set protobuf response code to be NXDOMAIN
+
+       local strReqName = dq.qname:toString()                  -- get request dns name
+
+       protobuf:setProtobufResponseType(strReqName)            -- set protobuf to look like a response and not a query
+                                                               
+      else
+       local tableTags = {}                                    -- called by testProtobuf()
+        tableTags["TestLabel2"] = "TestData2"                   
+        tableTags["TestLabel1"] = "TestData1"                   
+       protobuf:setTagArray(tableTags)                         -- setTagArray
+       protobuf:setTag('TestLabel3', 'TestData3')              -- setTag
+       protobuf:setTag("Query", "123")                         -- setTag
+      end
+    end
+
+    function alterLuaFirst(dq)                                 -- called when dnsdist receives new request
+
+       local tt = {}                                   
+        tt["TestLabel2"] = "TestData2"                  
+        tt["TestLabel1"] = "TestData1"                  
+
+       dq:setTagArray(tt)                                      -- setTagArray
+
+       dq:setTag('TestLabel3', 'TestData3')                    -- setTag
+
+       return DNSAction.None, ""                               -- continue to the next rule
+    end
+
+
+    newServer{address="127.0.0.1:%s", useClientSubnet=true}
+    rl = newRemoteLogger('127.0.0.1:%s')
+
+    addLuaAction(AllRule(), alterLuaFirst)                                                     -- Add tags to DNSQuery first
+
+    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery))                              -- Send protobuf message before lookup
+
+    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true))     -- Send protobuf message after lookup
+
+    """
+
+    @classmethod
+    def ProtobufListener(cls, port):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        try:
+            sock.bind(("127.0.0.1", port))
+        except socket.error as e:
+            print("Error binding in the protbuf listener: %s" % str(e))
+            sys.exit(1)
+
+        sock.listen(100)
+        while True:
+            (conn, _) = sock.accept()
+            data = None
+            while True:
+                data = conn.recv(2)
+                if not data:
+                    break
+                (datalen,) = struct.unpack("!H", data)
+                data = conn.recv(datalen)
+                if not data:
+                    break
+
+                cls._protobufQueue.put(data, True, timeout=2.0)
+
+            conn.close()
+        sock.close()
+
+    @classmethod
+    def startResponders(cls):
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
+        cls._UDPResponder.setDaemon(True)
+        cls._UDPResponder.start()
+
+        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
+        cls._TCPResponder.setDaemon(True)
+        cls._TCPResponder.start()
+
+        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
+        cls._protobufListener.setDaemon(True)
+        cls._protobufListener.start()
+
+    def getFirstProtobufMessage(self):
+        self.assertFalse(self._protobufQueue.empty())
+        data = self._protobufQueue.get(False)
+        self.assertTrue(data)
+        msg = dnsmessage_pb2.PBDNSMessage()
+        msg.ParseFromString(data)
+        return msg
+
+    def checkProtobufBase(self, msg, protocol, query, initiator):
+        self.assertTrue(msg)
+        self.assertTrue(msg.HasField('timeSec'))
+        self.assertTrue(msg.HasField('socketFamily'))
+        self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
+        self.assertTrue(msg.HasField('from'))
+        fromvalue = getattr(msg, 'from')
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)              
+        self.assertTrue(msg.HasField('socketProtocol'))
+        self.assertEquals(msg.socketProtocol, protocol)
+        self.assertTrue(msg.HasField('messageId'))
+        self.assertTrue(msg.HasField('id'))
+        self.assertEquals(msg.id, query.id)                                                    
+        self.assertTrue(msg.HasField('inBytes'))
+        self.assertEquals(msg.inBytes, len(query.to_wire()))           
+        # dnsdist doesn't set the existing EDNS Subnet for now,
+        # although it might be set from Lua
+        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
+        # self.assertEquals(len(msg.originalRequestorSubnet), 4)
+        # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
+
+
+    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
+
+       if initiator == '127.0.0.1':
+               self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)           # testProtobuf()
+       else:
+               self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)        # testLuaProtobuf()
+
+        self.checkProtobufBase(msg, protocol, query, initiator)
+        # dnsdist doesn't fill the responder field for responses
+        # because it doesn't keep the information around.
+        self.assertTrue(msg.HasField('to'))
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
+        self.assertTrue(msg.HasField('question'))
+        self.assertTrue(msg.question.HasField('qClass'))
+        self.assertEquals(msg.question.qClass, qclass)
+        self.assertTrue(msg.question.HasField('qType'))
+        self.assertEquals(msg.question.qClass, qtype)
+        self.assertTrue(msg.question.HasField('qName'))
+        self.assertEquals(msg.question.qName, qname)
+
+
+
+       testList = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"]
+       listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
+       self.assertEqual(len(listx), 0, "Lists don't match up in Protobuf Query")               # exclusive or of lists should be empty
+
+    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
+        self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)       
+        self.checkProtobufBase(msg, protocol, response, initiator)
+        self.assertTrue(msg.HasField('response'))                      
+        self.assertTrue(msg.response.HasField('queryTimeSec')) 
+
+       testList = [ u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"]
+       listx = set(msg.response.tags) ^ set(testList)                                          # only differences will be in new list
+       self.assertEqual(len(listx), 0, "List's don't match up in Protobuf Response")           # exclusive or of lists should be empty
+
+    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
+        self.assertTrue(record.HasField('class'))
+        self.assertEquals(getattr(record, 'class'), rclass)
+        self.assertTrue(record.HasField('type'))
+        self.assertEquals(record.type, rtype)
+        self.assertTrue(record.HasField('name'))
+        self.assertEquals(record.name, rname)
+        self.assertTrue(record.HasField('ttl'))
+        self.assertEquals(record.ttl, rttl)
+        self.assertTrue(record.HasField('rdata'))
+
+    def testProtobuf(self):
+        """
+        Protobuf: Send data to a protobuf server
+        """
+        name = 'query.protobuf.tests.powerdns.com.'            
+
+
+        target = 'target.protobuf.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    target)
+        response.answer.append(rrset)
+
+        rrset = dns.rrset.from_text(target,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the UDP query
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)         
+
+        # check the protobuf message corresponding to the UDP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)     # check UDP response
+        self.assertEquals(len(msg.response.rrs), 2)                    
+        rr = msg.response.rrs[0]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
+        self.assertEquals(rr.rdata, target)
+        rr = msg.response.rrs[1]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the TCP query
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+
+        # check the protobuf message corresponding to the TCP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)     # check TCP response
+        self.assertEquals(len(msg.response.rrs), 2)
+        rr = msg.response.rrs[0]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
+        self.assertEquals(rr.rdata, target)
+        rr = msg.response.rrs[1]
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
+        self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+    def testLuaProtobuf(self):
+
+        """
+        Protobuf: Check that the Lua callback rewrote the initiator
+        """
+
+        name = 'lua.protobuf.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+
+        self.assertTrue(receivedQuery)                 
+        self.assertTrue(receivedResponse)              
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the UDP query
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
+
+        # check the protobuf message corresponding to the UDP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')                # check UDP Response
+        self.assertEquals(len(msg.response.rrs), 1)
+        for rr in msg.response.rrs:
+            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
+            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # let the protobuf messages the time to get there
+        time.sleep(1)
+
+        # check the protobuf message corresponding to the TCP query
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0')
+
+        # check the protobuf message corresponding to the TCP response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')                # check TCP response
+        self.assertEquals(len(msg.response.rrs), 1)
+        for rr in msg.response.rrs:
+            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
+            self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
index a79521a6af28c4b5ee6dda8820f1aae156fba15a..a547b2dbdb07f93249f8f981aadeaf97f4d6949a 100644 (file)
@@ -178,7 +178,6 @@ function luaLogForward(dr, pbMsg)
 \r
 \r
 \r
-\r
    local tableTags = {}                 -- create a table\r
    tableTags["Trans"] = "FWD"           -- add transaction type to table\r
 \r
@@ -200,7 +199,7 @@ end
 \r
 function luaLogCache(dr, pbMsg)                -- this is the lua code that executes after a cache hit\r
 \r
-\r
\r
 \r
    local tableTags = {}                 -- create a table\r
    tableTags["Trans"] = "CACHE"         -- add transaction type to table\r
diff --git a/zzz-gca-example/test-protobuf-tag.sh b/zzz-gca-example/test-protobuf-tag.sh
new file mode 100755 (executable)
index 0000000..9a0103c
--- /dev/null
@@ -0,0 +1,14 @@
+cd ../regression-tests.dnsdist
+DNSDISTBIN=../pdns/dnsdistdist/dnsdist ./runtests test_ProtobufTag.py
+
+
+echo "-----------------------------------------------------------"
+echo "-----------------------------------------------------------"
+echo "-----------------------------------------------------------"
+echo "-----------------------------------------------------------"
+echo "-----------------------------------------------------------"
+echo "-----------------------------------------------------------"
+
+cat nosetests.xml
+
+