]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for outgoing DoT support
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 9 Jun 2021 14:50:24 +0000 (16:50 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 26 Aug 2021 14:30:28 +0000 (16:30 +0200)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_OutgoingTLS.py [new file with mode: 0644]

index d71b9d05693d8ce08269f5d5298407eb01c8b10f..479adeaa87f26373b48eff4b6da457ae8fd6f554 100644 (file)
@@ -231,7 +231,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         sock.close()
 
     @classmethod
-    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
+    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None):
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
@@ -248,8 +248,14 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             sys.exit(1)
 
         sock.listen(100)
+        if tlsContext:
+          sock = tlsContext.wrap_socket(sock, server_side=True)
+
         while True:
-            (conn, _) = sock.accept()
+            try:
+              (conn, _) = sock.accept()
+            except ssl.SSLError:
+              continue
             conn.settimeout(5.0)
             data = conn.recv(2)
             if not data:
diff --git a/regression-tests.dnsdist/test_OutgoingTLS.py b/regression-tests.dnsdist/test_OutgoingTLS.py
new file mode 100644 (file)
index 0000000..5f86597
--- /dev/null
@@ -0,0 +1,190 @@
+#!/usr/bin/env python
+import dns
+import ssl
+import threading
+import time
+
+from dnsdisttests import DNSDistTest
+
+class OutgoingTLSTests(object):
+
+    def checkOnlyTLSResponderHit(self):
+        self.assertNotIn('UDP Responder', self._responsesCounter)
+        self.assertNotIn('TCP Responder', self._responsesCounter)
+        self.assertEqual(self._responsesCounter['TLS Responder'], 1)
+        
+    def testUDP(self):
+        """
+        Outgoing TLS: UDP query is sent via TLS
+        """
+        name = 'udp.outgoing-tls.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(receivedResponse, expectedResponse)
+        self.checkOnlyTLSResponderHit()
+
+    def testTCP(self):
+        """
+        Outgoing TLS: TCP query is sent via TLS
+        """
+        name = 'tcp.outgoing-tls.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
+        self.assertEqual(query, receivedQuery)
+        self.assertEqual(receivedResponse, expectedResponse)
+        self.checkOnlyTLSResponderHit()
+
+class BrokenOutgoingTLSTests(object):
+
+    def checkNoResponderHit(self):
+        self.assertNotIn('UDP Responder', self._responsesCounter)
+        self.assertNotIn('TCP Responder', self._responsesCounter)
+        self.assertEqual(self._responsesCounter['TLS Responder'], 0)
+
+    def testUDP(self):
+        """
+        Outgoing TLS (broken): UDP query is sent via TLS
+        """
+        name = 'udp.broken-outgoing-tls.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertEqual(receivedResponse, None)
+        self.checkNoResponderHit()
+
+    def testTCP(self):
+        """
+        Outgoing TLS (broken): TCP query is sent via TLS
+        """
+        name = 'tcp.broken-outgoing-tls.test.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse.answer.append(rrset)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertEqual(receivedResponse, None)
+        self.checkNoResponderHit()
+
+class TestOutgoingTLSOpenSSL(DNSDistTest, OutgoingTLSTests):
+    _tlsBackendPort = 10443
+    _config_params = ['_tlsBackendPort']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com'}
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching TLS responder..")
+        cls._TLSResponder = threading.Thread(name='TLS Responder', target=cls.TCPResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._TLSResponder.setDaemon(True)
+        cls._TLSResponder.start()
+
+class TestOutgoingTLSGnuTLS(DNSDistTest, OutgoingTLSTests):
+    _tlsBackendPort = 10444
+    _config_params = ['_tlsBackendPort']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='powerdns.com'}
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching TLS responder..")
+        cls._TLSResponder = threading.Thread(name='TLS Responder', target=cls.TCPResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._TLSResponder.setDaemon(True)
+        cls._TLSResponder.start()
+
+class TestOutgoingTLSOpenSSLWrongCertName(DNSDistTest, BrokenOutgoingTLSTests):
+    _tlsBackendPort = 10445
+    _config_params = ['_tlsBackendPort']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=true, caStore='ca.pem', subjectName='not-powerdns.com'}
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching TLS responder..")
+        cls._TLSResponder = threading.Thread(name='TLS Responder', target=cls.TCPResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._TLSResponder.setDaemon(True)
+        cls._TLSResponder.start()
+
+class TestOutgoingTLSGnuTLSWrongCertName(DNSDistTest, BrokenOutgoingTLSTests):
+    _tlsBackendPort = 10446
+    _config_params = ['_tlsBackendPort']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=true, caStore='ca.pem', subjectName='not-powerdns.com'}
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching TLS responder..")
+        cls._TLSResponder = threading.Thread(name='TLS Responder', target=cls.TCPResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._TLSResponder.setDaemon(True)
+        cls._TLSResponder.start()
+
+class TestOutgoingTLSOpenSSLWrongCertNameButNoCheck(DNSDistTest, OutgoingTLSTests):
+    _tlsBackendPort = 10447
+    _config_params = ['_tlsBackendPort']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", tls='openssl', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com'}
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching TLS responder..")
+        cls._TLSResponder = threading.Thread(name='TLS Responder', target=cls.TCPResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._TLSResponder.setDaemon(True)
+        cls._TLSResponder.start()
+
+class TestOutgoingTLSGnuTLSWrongCertNameButNoCheck(DNSDistTest, OutgoingTLSTests):
+    _tlsBackendPort = 10448
+    _config_params = ['_tlsBackendPort']
+    _config_template = """
+    newServer{address="127.0.0.1:%s", tls='gnutls', validateCertificates=false, caStore='ca.pem', subjectName='not-powerdns.com'}
+    """
+
+    @classmethod
+    def startResponders(cls):
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain('server.chain', 'server.key')
+
+        print("Launching TLS responder..")
+        cls._TLSResponder = threading.Thread(name='TLS Responder', target=cls.TCPResponder, args=[cls._tlsBackendPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, None, tlsContext])
+        cls._TLSResponder.setDaemon(True)
+        cls._TLSResponder.start()