]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regression tests for the new TCP/TLS DoS mitigation options
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 31 Mar 2025 14:19:31 +0000 (16:19 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 31 Mar 2025 14:19:31 +0000 (16:19 +0200)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_TCPLimits.py

index a315d8a196b59f92189931b8edfbf07489f7b331..e191dcf81aa9bb59d2dbeea444b2edb31e732cb1 100644 (file)
@@ -660,7 +660,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return sock
 
     @classmethod
-    def openTLSConnection(cls, port, serverName, caCert=None, timeout=2.0, alpn=[]):
+    def openTLSConnection(cls, port, serverName, caCert=None, timeout=2.0, alpn=[], sslctx=None, session=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         if timeout:
@@ -668,10 +668,11 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
         # 2.7.9+
         if hasattr(ssl, 'create_default_context'):
-            sslctx = ssl.create_default_context(cafile=caCert)
-            if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
-              sslctx.set_alpn_protocols(alpn)
-            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName)
+            if not sslctx:
+                sslctx = ssl.create_default_context(cafile=caCert)
+                if len(alpn)> 0 and hasattr(sslctx, 'set_alpn_protocols'):
+                    sslctx.set_alpn_protocols(alpn)
+            sslsock = sslctx.wrap_socket(sock, server_hostname=serverName, session=session)
         else:
             sslsock = ssl.wrap_socket(sock, ca_certs=caCert, cert_reqs=ssl.CERT_REQUIRED)
 
index 4567b246610749c873089244df41d03b90a4c19d..a549a273ed9771842429f180c65e9c12e3449592 100644 (file)
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+import ssl
 import struct
 import time
 import dns
@@ -30,7 +31,6 @@ class TestTCPLimits(DNSDistTest):
     setTCPConnectionsOverloadThreshold(0)
     """
     _config_params = ['_testServerPort', '_tcpIdleTimeout', '_maxTCPQueriesPerConn', '_maxTCPConnsPerClient', '_maxTCPConnDuration']
-    _verboseMode = True
 
     def testTCPQueriesPerConn(self):
         """
@@ -132,6 +132,220 @@ class TestTCPLimits(DNSDistTest):
 
         conn.close()
 
+class TestTCPLimitsReadIO(DNSDistTest):
+
+    # separate test suite because we get banned for a few seconds
+    _testServerPort = pickAvailablePort()
+    _answerUnexpected = True
+
+    _tcpIdleTimeout = 2
+    _maxTCPReadIOsPerQuery = 10
+    _banDuration = 2
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+    setTCPRecvTimeout(%d)
+    setMaxTCPReadIOsPerQuery(%d)
+    setBanDurationForExceedingMaxReadIOsPerQuery(%d)
+    -- disable "near limits" otherwise our tests are broken because connections are forcibly closed
+    setTCPConnectionsOverloadThreshold(0)
+    """
+    _config_params = ['_testServerPort', '_tcpIdleTimeout', '_maxTCPReadIOsPerQuery', '_banDuration']
+
+    def testTCPMaxReadIOsPerQuery(self):
+        """
+        TCP Limits: Maximum number of IO read events per query
+        """
+        name = 'maxreadios.tcp.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        payload = query.to_wire()
+        self.assertGreater(len(payload), self._maxTCPReadIOsPerQuery)
+
+        conn = self.openTCPConnection()
+
+        count = 0
+        failed = False
+        while count < len(payload):
+            try:
+                conn.send(payload[count].to_bytes())
+                count = count + 1
+            except Exception as e:
+                failed = True
+                break
+
+        if not failed:
+            try:
+                response = self.recvTCPResponseOverConnection(conn)
+            except:
+                failed = True
+
+        conn.close()
+        self.assertTrue(failed)
+
+        # and we should be banned now
+        failed = False
+        try:
+            conn = self.openTCPConnection()
+            response = self.recvTCPResponseOverConnection(conn)
+            if response is None:
+              failed = True
+        except Exception as e:
+            failed = True
+        finally:
+            conn.close()
+
+        self.assertTrue(failed)
+
+class TestTCPLimitsConnectionRate(DNSDistTest):
+
+    # separate test suite because we get banned for a few seconds
+    _testServerPort = pickAvailablePort()
+    _answerUnexpected = True
+    _maxConnectionRate = 10
+    _tcpIdleTimeout = 2
+    _banDuration = 2
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+    setTCPRecvTimeout(%d)
+    setMaxTCPConnectionRatePerClient(%d)
+    setBanDurationForExceedingTCPTLSRate(%d)
+    -- disable "near limits" otherwise our tests are broken because connections are forcibly closed
+    setTCPConnectionsOverloadThreshold(0)
+    """
+    _config_params = ['_testServerPort', '_tcpIdleTimeout', '_maxConnectionRate', '_banDuration']
+    _verboseMode = True
+
+    def testTCPConnectionRate(self):
+        """
+        TCP Limits: Maximum connection rate
+        """
+        name = 'maxconnectionrate.tcp.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        # _maxConnectionRate connections in a row
+        for idx in range(self._maxConnectionRate):
+            (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response=response)
+            receivedQuery.id = query.id
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, response)
+        # the next one should be past the max rate
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertEqual(receivedQuery, None)
+        self.assertEqual(receivedResponse, None)
+
+class TestTCPLimitsTLSNewSessionRate(DNSDistTest):
+    # separate test suite because we get banned for a few seconds
+    _testServerPort = pickAvailablePort()
+    _tlsServerPort = pickAvailablePort()
+    _answerUnexpected = True
+    _maxNewTLSSessionRate = 10
+    _tcpIdleTimeout = 2
+    _banDuration = 2
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _tlsServerPort = pickAvailablePort()
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+    setTCPRecvTimeout(%d)
+    setMaxTLSNewSessionRatePerClient(%d)
+    setBanDurationForExceedingTCPTLSRate(%d)
+    addTLSLocal("127.0.0.1:%d", "%s", "%s")
+
+    -- disable "near limits" otherwise our tests are broken because connections are forcibly closed
+    setTCPConnectionsOverloadThreshold(0)
+    """
+    _config_params = ['_testServerPort', '_tcpIdleTimeout', '_maxNewTLSSessionRate', '_banDuration', '_tlsServerPort', '_serverCert', '_serverKey']
+    _verboseMode = True
+
+    def testTLSNewSessionRate(self):
+        """
+        TCP Limits: Maximum TLS new session rate
+        """
+        name = 'maxtlsnewsessionrate.tcp.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        # _maxNewTLSSessionRate connections in a row, plus one because
+        # the session is only accounted for once the handshake has been completed
+        for idx in range(self._maxNewTLSSessionRate + 1):
+            (receivedQuery, receivedResponse) = self.sendDOTQueryWrapper(query, response=response)
+            receivedQuery.id = query.id
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, response)
+
+        try:
+            # the next one should be past the max rate
+            self.sendDOTQueryWrapper(query, response=None, useQueue=False)
+            self.assertTrue(False)
+        except ConnectionResetError:
+          pass
+
+class TestTCPLimitsTLSResumedSessionRate(DNSDistTest):
+    # separate test suite because we get banned for a few seconds
+    _testServerPort = pickAvailablePort()
+    _tlsServerPort = pickAvailablePort()
+    _answerUnexpected = True
+    _maxNewTLSSessionRate = 1
+    _maxResumedTLSSessionRate = 10
+    _tcpIdleTimeout = 2
+    _banDuration = 2
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _tlsServerPort = pickAvailablePort()
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+    setTCPRecvTimeout(%d)
+    setMaxTLSNewSessionRatePerClient(%d)
+    setMaxTLSResumedSessionRatePerClient(%d)
+    setBanDurationForExceedingTCPTLSRate(%d)
+    addTLSLocal("127.0.0.1:%d", "%s", "%s")
+
+    -- disable "near limits" otherwise our tests are broken because connections are forcibly closed
+    setTCPConnectionsOverloadThreshold(0)
+    """
+    _config_params = ['_testServerPort', '_tcpIdleTimeout', '_maxNewTLSSessionRate', '_maxResumedTLSSessionRate', '_banDuration', '_tlsServerPort', '_serverCert', '_serverKey']
+    _verboseMode = True
+
+    def testTLSResumedSessionRate(self):
+        """
+        TCP Limits: Maximum TLS resumed session rate
+        """
+        name = 'maxtlsresumedsessionrate.tcp.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        session = None
+        sslctx = ssl.create_default_context(cafile=self._caCert)
+
+        # _maxResumedTLSSessionRate connections in a row, plus two because
+        # - the first one is a new TLS session
+        # - the session is only accounted for once the handshake has been completed
+        for idx in range(self._maxResumedTLSSessionRate + 2):
+            conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert, timeout=1, sslctx=sslctx, session=session)
+            self.sendTCPQueryOverConnection(conn, query, response=response, timeout=1)
+            (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, useQueue=True, timeout=1)
+            receivedQuery.id = query.id
+            self.assertEqual(receivedQuery, query)
+            self.assertEqual(receivedResponse, response)
+            if idx == 0:
+                self.assertFalse(conn.session_reused)
+                session = conn.session
+            else:
+                self.assertTrue(conn.session_reused)
+
+        try:
+            # the next one should be past the max rate
+            conn = self.openTLSConnection(self._tlsServerPort, self._serverName, self._caCert, timeout=1, sslctx=sslctx, session=session)
+            self.sendTCPQueryOverConnection(conn, query, response=response, timeout=1)
+            self.recvTCPResponseOverConnection(conn, useQueue=True, timeout=1)
+            self.assertTrue(False)
+        except ConnectionResetError:
+          pass
+
 class TestTCPFrontendLimits(DNSDistTest):
 
     # this test suite uses a different responder port
@@ -149,7 +363,6 @@ class TestTCPFrontendLimits(DNSDistTest):
     setTCPConnectionsOverloadThreshold(0)
     """
     _config_params = ['_testServerPort', '_dnsDistListeningAddr', '_dnsDistPort', '_maxTCPConnsPerFrontend']
-    _verboseMode = True
 
     def testTCPConnsPerFrontend(self):
         """