]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add a test for DoH incoming proxy protocol inside of TLS
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 13 Jul 2023 14:15:48 +0000 (16:15 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Sep 2023 08:22:06 +0000 (10:22 +0200)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_ProxyProtocol.py

index e2787d577fd884128fc1426548f7d197703f52ce..c82f2b9ac9ddd08563762c21a92c0cdd720ae197 100644 (file)
@@ -657,7 +657,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
 
     @classmethod
     def recvTCPResponseOverConnection(cls, sock, useQueue=False, timeout=2.0):
-        print("reading data")
         message = None
         data = sock.recv(2)
         if data:
@@ -671,7 +670,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         print(useQueue)
         if useQueue and not cls._fromResponderQueue.empty():
             receivedQuery = cls._fromResponderQueue.get(True, timeout)
-            print("Got from queue")
             print(receivedQuery)
             return (receivedQuery, message)
         else:
@@ -707,7 +705,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         receivedQuery = None
         print(useQueue)
         if useQueue and not cls._fromResponderQueue.empty():
-            print("Got from queue")
             print(receivedQuery)
             receivedQuery = cls._fromResponderQueue.get(True, timeout)
         else:
@@ -991,13 +988,11 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         url = cls.getDOHGetURL(baseurl, query, rawQuery)
 
         if not conn:
-            print('creating a new connection')
             conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
             # this means "really do HTTP/2, not HTTP/1 with Upgrade headers"
             conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE)
 
         if useHTTPS:
-            print("disabling verify")
             conn.setopt(pycurl.SSL_VERIFYPEER, 1)
             conn.setopt(pycurl.SSL_VERIFYHOST, 2)
             if caFile:
@@ -1020,7 +1015,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         receivedQuery = None
         message = None
         cls._response_headers = ''
-        print('performing')
         data = conn.perform_rb()
         cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
         if cls._rcode == 200 and not rawResponse:
index 744d6a7c79e2db7b17645dfb230ddeb44c4a3a93..ee730eacb3dc46f3164313c45d4732fce393d1fb 100644 (file)
@@ -4,6 +4,7 @@ import copy
 import dns
 import selectors
 import socket
+import ssl
 import struct
 import sys
 import threading
@@ -144,16 +145,21 @@ tcpResponder.start()
 
 backgroundThreads = {}
 
-def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort):
+def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort, serverCtx=None, ca=None, sni=None):
     # this responder accepts TCP connections on the listening port,
     # and relay the raw content to a second TCP connection to the
     # forwarding port, after adding a Proxy Protocol v2 payload
     # containing the initial source IP and port, destination IP
     # and port.
     backgroundThreads[threading.get_native_id()] = True
+
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
     sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+    if serverCtx is not None:
+        sock = serverCtx.wrap_socket(sock, server_side=True)
+
     try:
         sock.bind(("127.0.0.1", listeningPort))
     except socket.error as e:
@@ -161,6 +167,7 @@ def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort):
         sys.exit(1)
     sock.settimeout(0.5)
     sock.listen(100)
+
     while True:
         try:
             (incoming, _) = sock.accept()
@@ -177,6 +184,15 @@ def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort):
         outgoing = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         outgoing.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         outgoing.settimeout(2.0)
+        if sni:
+            if hasattr(ssl, 'create_default_context'):
+                sslctx = ssl.create_default_context(cafile=ca)
+                if hasattr(sslctx, 'set_alpn_protocols'):
+                    sslctx.set_alpn_protocols(['h2'])
+                outgoing = sslctx.wrap_socket(outgoing, server_hostname=sni)
+            else:
+                outgoing = ssl.wrap_socket(outgoing, ca_certs=ca, cert_reqs=ssl.CERT_REQUIRED)
+
         outgoing.connect(('127.0.0.1', forwardingPort))
 
         outgoing.send(payload)
@@ -473,7 +489,8 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     """
 
     _config_template = """
-    addDOHLocal("127.0.0.1:%s", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
+    addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
+    addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false})
     setProxyProtocolACL( { "127.0.0.1/32" } )
     newServer{address="127.0.0.1:%d", useProxyProtocol=true}
 
@@ -514,9 +531,9 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     _serverCert = 'server.chain'
     _serverName = 'tls.tests.dnsdist.org'
     _caCert = 'ca.pem'
-    _dohServerPort = 8443
-    _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort))
-    _config_params = ['_dohServerPort', '_serverCert', '_serverKey', '_proxyResponderPort']
+    _dohServerPPOutsidePort = 8443
+    _dohServerPPInsidePort = 9443
+    _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort']
 
     def testNoHeader(self):
         """
@@ -740,11 +757,11 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
           self.assertEqual(receivedResponse, response)
           self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
 
-    def testProxyDoHSeveralQueriesOverConnection(self):
+    def testProxyDoHSeveralQueriesOverConnectionPPOutside(self):
         """
-        Incoming Proxy Protocol: Several queries over the same connection (DoH)
+        Incoming Proxy Protocol: Several queries over the same connection (DoH, PP outside TLS)
         """
-        name = 'several-queries.proxy-protocol-incoming.tests.powerdns.com.'
+        name = 'several-queries.doh-outside.proxy-protocol-incoming.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         response = dns.message.make_response(query)
 
@@ -753,10 +770,63 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         wire = query.to_wire()
 
         reverseProxyPort = 13053
-        reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPort])
+        reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPOutsidePort])
+        reverseProxy.start()
+
+        receivedResponse = None
+        conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
+
+        reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+        for idx in range(5):
+          receivedResponse = None
+          toProxyQueue.put(response, True, 2.0)
+          (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+          (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          self.assertTrue(receivedResponse)
+
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          receivedQuery.id = query.id
+          receivedResponse.id = response.id
+          self.assertEqual(receivedQuery, query)
+          self.assertEqual(receivedResponse, response)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+    def testProxyDoHSeveralQueriesOverConnectionPPInside(self):
+        """
+        Incoming Proxy Protocol: Several queries over the same connection (DoH, PP inside TLS)
+        """
+        name = 'several-queries.doh-inside.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        toProxyQueue.put(response, True, 2.0)
+
+        wire = query.to_wire()
+
+        reverseProxyPort = 14053
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain(self._serverCert, self._serverKey)
+        tlsContext.set_alpn_protocols(['h2'])
+        reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPInsidePort, tlsContext, self._caCert, self._serverName])
         reverseProxy.start()
 
         receivedResponse = None
+        time.sleep(1)
         conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
 
         reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
@@ -786,8 +856,6 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
           receivedQuery.id = query.id
           receivedResponse.id = response.id
           self.assertEqual(receivedQuery, query)
-          print(receivedResponse)
-          print(response)
           self.assertEqual(receivedResponse, response)
           self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)