]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.recursor-dnssec/test_Protobuf.py
Merge pull request #13481 from berthubert/lua-localwho
[thirdparty/pdns.git] / regression-tests.recursor-dnssec / test_Protobuf.py
index 134b84b75c0c5468cd1d3ffc899a89543df374c6..953a9ce20ead2a57aa4516deb8597f2b72ee935a 100644 (file)
@@ -6,6 +6,7 @@ import struct
 import sys
 import threading
 import time
+import clientsubnetoption
 
 # Python2/3 compatibility hacks
 try:
@@ -93,19 +94,19 @@ class TestRecursorProtobuf(RecursorTest):
     def getFirstProtobufMessage(self, retries=1, waitTime=1):
         msg = None
 
-        print("in getFirstProtobufMessage")
+        #print("in getFirstProtobufMessage")
         for param in protobufServersParameters:
-          print(param.port)
+          #print(param.port)
           failed = 0
 
-          while param.queue.empty:
-            print(failed)
-            print(retries)
+          while param.queue.empty():
+            #print(failed)
+            #print(retries)
             if failed >= retries:
               break
 
             failed = failed + 1
-            print("waiting")
+            #print("waiting")
             time.sleep(waitTime)
 
           self.assertFalse(param.queue.empty())
@@ -117,9 +118,14 @@ class TestRecursorProtobuf(RecursorTest):
           if oldmsg is not None:
             self.assertEqual(msg, oldmsg)
 
-        print(msg)
+        #print(msg)
         return msg
 
+    def emptyProtoBufQueue(self):
+        for param in protobufServersParameters:
+            while not param.queue.empty():
+                param.queue.get(False)
+
     def checkNoRemainingMessage(self):
         for param in protobufServersParameters:
           self.assertTrue(param.queue.empty())
@@ -154,7 +160,7 @@ class TestRecursorProtobuf(RecursorTest):
             self.assertEqual(len(msg.originalRequestorSubnet), 4)
             self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')
 
-    def checkOutgoingProtobufBase(self, msg, protocol, query, initiator, length=None):
+    def checkOutgoingProtobufBase(self, msg, protocol, query, initiator, length=None, expectedECS=None):
         self.assertTrue(msg)
         self.assertTrue(msg.HasField('timeSec'))
         self.assertTrue(msg.HasField('socketFamily'))
@@ -171,6 +177,11 @@ class TestRecursorProtobuf(RecursorTest):
         else:
           # compare inBytes with length of query/response
           self.assertEqual(msg.inBytes, len(query.to_wire()))
+        if expectedECS is not None:
+            self.assertTrue(msg.HasField('originalRequestorSubnet'))
+            # v4 only for now
+            self.assertEqual(len(msg.originalRequestorSubnet), 4)
+            self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), expectedECS)
 
     def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1', to='127.0.0.1'):
         self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
@@ -221,17 +232,17 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertEqual(msg.response.appliedPolicyKind, kind)
 
     def checkProtobufTags(self, msg, tags):
-        print(tags)
-        print('---')
-        print(msg.response.tags)
+        #print(tags)
+        #print('---')
+        #print(msg.response.tags)
         self.assertEqual(len(msg.response.tags), len(tags))
         for tag in msg.response.tags:
             self.assertTrue(tag in tags)
 
     def checkProtobufMetas(self, msg, metas):
-        print(metas)
-        print('---')
-        print(msg.meta)
+        #print(metas)
+        #print('---')
+        #print(msg.meta)
         self.assertEqual(len(msg.meta), len(metas))
         for m in msg.meta:
             self.assertTrue(m.HasField('key'))
@@ -242,9 +253,9 @@ class TestRecursorProtobuf(RecursorTest):
             for s in m.value.stringVal :
               self.assertTrue(s in metas[m.key]['stringVal'])
 
-    def checkProtobufOutgoingQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1', length=None):
+    def checkProtobufOutgoingQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1', length=None, expectedECS=None):
         self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSOutgoingQueryType)
-        self.checkOutgoingProtobufBase(msg, protocol, query, initiator, length=length)
+        self.checkOutgoingProtobufBase(msg, protocol, query, initiator, length=length, expectedECS=expectedECS)
         self.assertTrue(msg.HasField('to'))
         self.assertTrue(msg.HasField('question'))
         self.assertTrue(msg.question.HasField('qClass'))
@@ -266,7 +277,7 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertEqual(msg.response.rcode, 65536)
 
     def checkProtobufIdentity(self, msg, requestorId, deviceId, deviceName):
-        print(msg)
+        #print(msg)
         self.assertTrue((requestorId == '') == (not msg.HasField('requestorId')))
         self.assertTrue((deviceId == b'') == (not msg.HasField('deviceId')))
         self.assertTrue((deviceName == '') == (not msg.HasField('deviceName')))
@@ -278,9 +289,7 @@ class TestRecursorProtobuf(RecursorTest):
         super(TestRecursorProtobuf, self).setUp()
         # Make sure the queue is empty, in case
         # a previous test failed
-        for param in protobufServersParameters:
-            while not param.queue.empty():
-                param.queue.get(False)
+        self.emptyProtoBufQueue()
         # wait long enough to be sure that the housekeeping has
         # prime the root NS
         time.sleep(1)
@@ -569,17 +578,25 @@ class OutgoingProtobufDefaultTest(TestRecursorProtobuf):
     # Switch off QName Minimization, it generates much more protobuf messages
     # (or make the test much more smart!)
     qname-minimization=no
+    max-cache-ttl=600
+    loglevel=9
 """
     _lua_config_file = """
     outgoingProtobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
     """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
 
     def testA(self):
+        # There is a race in priming (having the . DNSKEY in cache in particular) and this code.
+        # So make sure we have the . DNSKEY in cache
+        query = dns.message.make_query('.', 'A', want_dnssec=True)
+        query.flags |= dns.flags.RD
+        res = self.sendUDPQuery(query)
+        time.sleep(1)
+        self.emptyProtoBufQueue()
+
         name = 'host1.secure.example.'
         expected = list()
 
-        # the root DNSKEY has been learned with priming the root NS already
-        # ('.', dns.rdatatype.DNSKEY, dnsmessage_pb2.PBDNSMessage.UDP, 201),
         for qname, qtype, proto, responseSize in [
                 ('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 248),
                 ('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 221),
@@ -615,6 +632,111 @@ class OutgoingProtobufDefaultTest(TestRecursorProtobuf):
 
         self.checkNoRemainingMessage()
 
+class OutgoingProtobufWithECSMappingTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly export outgoing queries over protobuf.
+    It must be improved and setup env so we can check for incoming responses, but makes sure for now
+    that the recursor at least connects to the protobuf server.
+    """
+
+    _confdir = 'OutgoingProtobuffWithECSMapping'
+    _config_template = """
+    # Switch off QName Minimization, it generates much more protobuf messages
+    # (or make the test much more smart!)
+    qname-minimization=no
+    edns-subnet-allow-list=example
+    allow-from=1.2.3.4/32
+    # this is to not let . queries interfere
+    max-cache-ttl=600
+    loglevel=9
+"""
+    _lua_config_file = """
+    outgoingProtobufServer({"127.0.0.1:%d", "127.0.0.1:%d"})
+    addProxyMapping("127.0.0.0/8", "1.2.3.4", { "host1.secure.example." })
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+
+    def testA(self):
+        # There is a race in priming (having the . DNSKEY in cache in particular) and this code.
+        # So make sure we have the . DNSKEY in cache
+        query = dns.message.make_query('.', 'A', want_dnssec=True)
+        query.flags |= dns.flags.RD
+        res = self.sendUDPQuery(query)
+        time.sleep(1)
+        self.emptyProtoBufQueue()
+
+        name = 'host1.secure.example.'
+        expected = list()
+
+        for qname, qtype, proto, responseSize, ecs in [
+                ('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 248, "1.2.3.0"),
+                ('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 221, "1.2.3.0"),
+                ('example.', dns.rdatatype.DNSKEY, dnsmessage_pb2.PBDNSMessage.UDP, 219, "1.2.3.0"),
+                ('host1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 175, "1.2.3.0"),
+                ('secure.example.', dns.rdatatype.DNSKEY, dnsmessage_pb2.PBDNSMessage.UDP, 233, "1.2.3.0"),
+        ]:
+            if not qname:
+                expected.append((None, None, None, None, None, None, None))
+                continue
+            ecso = clientsubnetoption.ClientSubnetOption('9.10.11.12', 24)
+            query = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True, options=[ecso], payload=512)
+            resp = dns.message.make_response(query)
+            expected.append((
+                qname, qtype, query, resp, proto, responseSize, ecs
+            ))
+
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.RD
+        res = self.sendUDPQuery(query)
+
+        for qname, qtype, qry, ans, proto, responseSize, ecs in expected:
+            if not qname:
+                self.getFirstProtobufMessage()
+                self.getFirstProtobufMessage()
+                continue
+
+            msg = self.getFirstProtobufMessage()
+            self.checkProtobufOutgoingQuery(msg, proto, qry, dns.rdataclass.IN, qtype, qname, "127.0.0.1", None, ecs)
+            # Check the answer
+            msg = self.getFirstProtobufMessage()
+            self.checkProtobufIncomingResponse(msg, proto, ans, length=responseSize)
+
+        self.checkNoRemainingMessage()
+
+        # this query should use the unmapped ECS
+        name = 'mx1.secure.example.'
+        expected = list()
+
+        for qname, qtype, proto, responseSize, ecs in [
+                ('mx1.secure.example.', dns.rdatatype.A, dnsmessage_pb2.PBDNSMessage.UDP, 173, "127.0.0.1"),
+        ]:
+            if not qname:
+                expected.append((None, None, None, None, None, None, None))
+                continue
+            ecso = clientsubnetoption.ClientSubnetOption('127.0.0.1', 32)
+            query = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True, options=[ecso], payload=512)
+            resp = dns.message.make_response(query)
+            expected.append((
+                qname, qtype, query, resp, proto, responseSize, ecs
+            ))
+
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.RD
+        res = self.sendUDPQuery(query)
+
+        for qname, qtype, qry, ans, proto, responseSize, ecs in expected:
+            if not qname:
+                self.getFirstProtobufMessage()
+                self.getFirstProtobufMessage()
+                continue
+
+            msg = self.getFirstProtobufMessage()
+            self.checkProtobufOutgoingQuery(msg, proto, qry, dns.rdataclass.IN, qtype, qname, "127.0.0.1", None, ecs)
+            # Check the answer
+            msg = self.getFirstProtobufMessage()
+            self.checkProtobufIncomingResponse(msg, proto, ans, length=responseSize)
+
+        self.checkNoRemainingMessage()
+
 class OutgoingProtobufNoQueriesTest(TestRecursorProtobuf):
     """
     This test makes sure that we correctly export incoming responses but not outgoing queries over protobuf.
@@ -626,12 +748,23 @@ class OutgoingProtobufNoQueriesTest(TestRecursorProtobuf):
     _config_template = """
     # Switch off QName Minimization, it generates much more protobuf messages
     # (or make the test much more smart!)
-    qname-minimization=no"""
+    qname-minimization=no
+    max-cache-ttl=600
+    loglevel=9
+"""
     _lua_config_file = """
     outgoingProtobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true })
     """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
 
     def testA(self):
+        # There is a race in priming (having the . DNSKEY in cache in particular) and this code.
+        # So make sure we have the . DNSKEY in cache
+        query = dns.message.make_query('.', 'A', want_dnssec=True)
+        query.flags |= dns.flags.RD
+        res = self.sendUDPQuery(query)
+        time.sleep(1)
+        self.emptyProtoBufQueue()
+
         name = 'host1.secure.example.'
         expected = list()
         # the root DNSKEY has been learned with priming the root NS already
@@ -824,9 +957,85 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
         self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
         tags = [ self._tag_from_gettag ] + self._tags
+        #print(msg)
         self.checkProtobufTags(msg, tags)
         self.checkNoRemainingMessage()
 
+        # Again to check PC case
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        self.checkProtobufTags(msg, [ self._tag_from_gettag ])
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        tags = [ self._tag_from_gettag ] + self._tags
+        self.checkProtobufTags(msg, tags)
+        self.checkNoRemainingMessage()
+
+class ProtobufTagCacheTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly cache tags (actually not cache them)
+    """
+
+    _confdir = 'ProtobufTagCache'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _lua_config_file = """
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+    _lua_dns_script_file = """
+    function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp)
+      if qname:equal('tagged.example.') then
+        return 0, { '' .. math.random() }
+      end
+      return 0
+    end
+    """
+
+    def testTagged(self):
+        name = 'tagged.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
 class ProtobufSelectedFromLuaTest(TestRecursorProtobuf):
     """
     This test makes sure that we correctly export queries and responses but only if they have been selected from Lua.