]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_Protobuf.py
dnsdist: Add a new response chain for XFR responses
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Protobuf.py
index 5f65fd31a9103cd83dcd6f6a34e24a33b78cbb15..a7d11951203622900d4092f45ba0bd9f7eca2b28 100644 (file)
@@ -871,3 +871,133 @@ class TestProtobufQUIC(DNSDistProtobufTest):
                 self.assertEqual(msg.httpVersion, dnsmessage_pb2.PBDNSMessage.HTTPVersion.HTTP3)
 
             self.checkProtobufQuery(msg, pbMessageType, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+
+class TestProtobufAXFR(DNSDistProtobufTest):
+    # this test suite uses a different responder port
+    # because, contrary to the other ones, its
+    # TCP responder allows multiple responses and we don't want
+    # to mix things up.
+    _testServerPort = pickAvailablePort()
+
+    @classmethod
+    def startResponders(cls):
+        print("Launching responders..")
+
+        cls._UDPResponder = threading.Thread(name='UDP Protobuf AXFR Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
+        cls._UDPResponder.daemon = True
+        cls._UDPResponder.start()
+        cls._TCPResponder = threading.Thread(name='TCP Protobuf AXFR Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, True, None, None, True])
+        cls._TCPResponder.daemon = True
+        cls._TCPResponder.start()
+        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
+        cls._protobufListener.daemon = True
+        cls._protobufListener.start()
+
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+    rl = newRemoteLogger('127.0.0.1:%d')
+
+    addXFRResponseAction(AllRule(), RemoteLogResponseAction(rl, nil, false, {serverID='dnsdist-server-1'}))
+    """
+    _config_params = ['_testServerPort', '_protobufServerPort']
+
+    def testProtobufAXFR(self):
+        """
+        Protobuf: Check the logging of multiple messages for AXFR responses
+        """
+        # first query is NOT an AXFR, we should not log anything
+        name = 'axfr.protobuf.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        ttl = 60
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    ttl,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)
+
+        self.assertTrue(self._protobufQueue.empty())
+
+        query = dns.message.make_query(name, 'AXFR', 'IN')
+        responses = []
+        soa = dns.rrset.from_text(name,
+                                  ttl,
+                                  dns.rdataclass.IN,
+                                  dns.rdatatype.SOA,
+                                  'ns.' + name + ' hostmaster.' + name + ' 1 3600 3600 3600 60')
+        response = dns.message.make_response(query)
+        response.answer.append(soa)
+        responses.append(response)
+
+        response = dns.message.make_response(query)
+        response.answer.append(dns.rrset.from_text(name,
+                                                   ttl,
+                                                   dns.rdataclass.IN,
+                                                   dns.rdatatype.A,
+                                                   '192.0.2.1'))
+        responses.append(response)
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    ttl,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:db8::1')
+        response.answer.append(rrset)
+        responses.append(response)
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    ttl,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.TXT,
+                                    'dummy')
+        response.answer.append(rrset)
+        responses.append(response)
+
+        response = dns.message.make_response(query)
+        response.answer.append(soa)
+        responses.append(response)
+
+        # UDP would not make sense since it does not support multiple messages
+        (receivedQuery, receivedResponses) = self.sendTCPQueryWithMultipleResponses(query, responses)
+
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponses)
+        receivedQuery.id = query.id
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(len(receivedResponses), len(responses))
+
+        if self._protobufQueue.empty():
+            # let the protobuf messages the time to get there
+            time.sleep(1)
+
+        # check the protobuf messages corresponding to the responses
+        count = 0
+        while not self._protobufQueue.empty():
+            msg = self.getFirstProtobufMessage()
+            count = count + 1
+            pbMessageType = dnsmessage_pb2.PBDNSMessage.TCP
+            self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, responses[count-1])
+
+            expected = responses[count-1].answer[0]
+            if expected.rdtype in [dns.rdatatype.A, dns.rdatatype.AAAA]:
+                rr = msg.response.rrs[0]
+                self.checkProtobufResponseRecord(rr, expected.rdclass, expected.rdtype, name, ttl)
+                if expected.rdtype == dns.rdatatype.A:
+                    self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.1')
+                else:
+                    self.assertEqual(socket.inet_ntop(socket.AF_INET6, rr.rdata), '2001:db8::1')
+
+        self.assertEqual(count, len(responses))