]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.recursor-dnssec/test_Protobuf.py
Merge pull request #13481 from berthubert/lua-localwho
[thirdparty/pdns.git] / regression-tests.recursor-dnssec / test_Protobuf.py
index 5c6fd0b451608554fd13b96a1d543eaf1ba1683a..953a9ce20ead2a57aa4516deb8597f2b72ee935a 100644 (file)
@@ -94,19 +94,19 @@ class TestRecursorProtobuf(RecursorTest):
     def getFirstProtobufMessage(self, retries=1, waitTime=1):
         msg = None
 
-        print("in getFirstProtobufMessage")
+        #print("in getFirstProtobufMessage")
         for param in protobufServersParameters:
-          print(param.port)
+          #print(param.port)
           failed = 0
 
-          while param.queue.empty:
-            print(failed)
-            print(retries)
+          while param.queue.empty():
+            #print(failed)
+            #print(retries)
             if failed >= retries:
               break
 
             failed = failed + 1
-            print("waiting")
+            #print("waiting")
             time.sleep(waitTime)
 
           self.assertFalse(param.queue.empty())
@@ -118,7 +118,7 @@ class TestRecursorProtobuf(RecursorTest):
           if oldmsg is not None:
             self.assertEqual(msg, oldmsg)
 
-        print(msg)
+        #print(msg)
         return msg
 
     def emptyProtoBufQueue(self):
@@ -232,17 +232,17 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertEqual(msg.response.appliedPolicyKind, kind)
 
     def checkProtobufTags(self, msg, tags):
-        print(tags)
-        print('---')
-        print(msg.response.tags)
+        #print(tags)
+        #print('---')
+        #print(msg.response.tags)
         self.assertEqual(len(msg.response.tags), len(tags))
         for tag in msg.response.tags:
             self.assertTrue(tag in tags)
 
     def checkProtobufMetas(self, msg, metas):
-        print(metas)
-        print('---')
-        print(msg.meta)
+        #print(metas)
+        #print('---')
+        #print(msg.meta)
         self.assertEqual(len(msg.meta), len(metas))
         for m in msg.meta:
             self.assertTrue(m.HasField('key'))
@@ -277,7 +277,7 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertEqual(msg.response.rcode, 65536)
 
     def checkProtobufIdentity(self, msg, requestorId, deviceId, deviceName):
-        print(msg)
+        #print(msg)
         self.assertTrue((requestorId == '') == (not msg.HasField('requestorId')))
         self.assertTrue((deviceId == b'') == (not msg.HasField('deviceId')))
         self.assertTrue((deviceName == '') == (not msg.HasField('deviceName')))
@@ -957,9 +957,85 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
         self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
         tags = [ self._tag_from_gettag ] + self._tags
+        #print(msg)
         self.checkProtobufTags(msg, tags)
         self.checkNoRemainingMessage()
 
+        # Again to check PC case
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        self.checkProtobufTags(msg, [ self._tag_from_gettag ])
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        tags = [ self._tag_from_gettag ] + self._tags
+        self.checkProtobufTags(msg, tags)
+        self.checkNoRemainingMessage()
+
+class ProtobufTagCacheTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly cache tags (actually not cache them)
+    """
+
+    _confdir = 'ProtobufTagCache'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _lua_config_file = """
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+    _lua_dns_script_file = """
+    function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp)
+      if qname:equal('tagged.example.') then
+        return 0, { '' .. math.random() }
+      end
+      return 0
+    end
+    """
+
+    def testTagged(self):
+        name = 'tagged.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
 class ProtobufSelectedFromLuaTest(TestRecursorProtobuf):
     """
     This test makes sure that we correctly export queries and responses but only if they have been selected from Lua.