]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add a test for DoH incoming proxy protocol outside of TLS
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 13 Jul 2023 12:45:38 +0000 (14:45 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 7 Sep 2023 08:22:06 +0000 (10:22 +0200)
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_ProxyProtocol.py

index 75068bea027d1a637482a8dc8846b21bf0677d77..e2787d577fd884128fc1426548f7d197703f52ce 100644 (file)
@@ -987,7 +987,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         return conn
 
     @classmethod
-    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, useProxyProtocol=False, conn=None):
+    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None):
         url = cls.getDOHGetURL(baseurl, query, rawQuery)
 
         if not conn:
@@ -1003,11 +1003,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
             if caFile:
                 conn.setopt(pycurl.CAINFO, caFile)
 
-        if useProxyProtocol:
-            print('enabling PP')
-            # 274 is CURLOPT_HAPROXYPROTOCOL
-            conn.setopt(274, 1)
-
         response_headers = BytesIO()
         #conn.setopt(pycurl.VERBOSE, True)
         conn.setopt(pycurl.URL, url)
@@ -1088,8 +1083,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         cls._response_headers = response_headers.getvalue()
         return (receivedQuery, message)
 
-    def sendDOHQueryWrapper(self, query, response, useQueue=True, useProxyProtocol=False):
-        return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, useProxyProtocol=useProxyProtocol)
+    def sendDOHQueryWrapper(self, query, response, useQueue=True):
+        return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
 
     def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True):
         return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue)
index 2ff4a2d741003bfb2a0a6031a35c62d1db188971..744d6a7c79e2db7b17645dfb230ddeb44c4a3a93 100644 (file)
@@ -2,6 +2,7 @@
 
 import copy
 import dns
+import selectors
 import socket
 import struct
 import sys
@@ -141,6 +142,80 @@ tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=Prox
 tcpResponder.daemon = True
 tcpResponder.start()
 
+backgroundThreads = {}
+
+def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort):
+    # this responder accepts TCP connections on the listening port,
+    # and relay the raw content to a second TCP connection to the
+    # forwarding port, after adding a Proxy Protocol v2 payload
+    # containing the initial source IP and port, destination IP
+    # and port.
+    backgroundThreads[threading.get_native_id()] = True
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+    try:
+        sock.bind(("127.0.0.1", listeningPort))
+    except socket.error as e:
+        print("Error binding in the Mock TCP reverse proxy: %s" % str(e))
+        sys.exit(1)
+    sock.settimeout(0.5)
+    sock.listen(100)
+    while True:
+        try:
+            (incoming, _) = sock.accept()
+        except socket.timeout:
+            if backgroundThreads.get(threading.get_native_id(), False) == False:
+                del backgroundThreads[threading.get_native_id()]
+                break
+            else:
+              continue
+
+        incoming.settimeout(5.0)
+        payload = ProxyProtocol.getPayload(False, True, False, '127.0.0.1', '127.0.0.1', incoming.getpeername()[1], listeningPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+
+        outgoing = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        outgoing.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+        outgoing.settimeout(2.0)
+        outgoing.connect(('127.0.0.1', forwardingPort))
+
+        outgoing.send(payload)
+
+        sel = selectors.DefaultSelector()
+        def readFromClient(conn):
+            data = conn.recv(512)
+            if not data or len(data) == 0:
+              return False
+            outgoing.send(data)
+            return True
+
+        def readFromBackend(conn):
+            data = conn.recv(512)
+            if not data or len(data) == 0:
+              return False
+            incoming.send(data)
+            return True
+
+        sel.register(incoming, selectors.EVENT_READ, readFromClient)
+        sel.register(outgoing, selectors.EVENT_READ, readFromBackend)
+        done = False
+        while not done:
+          try:
+            events = sel.select()
+            for key, mask in events:
+              if not (key.data)(key.fileobj):
+                done = True
+                break
+          except socket.timeout:
+            break
+          except:
+            break
+
+        incoming.close()
+        outgoing.close()
+
+    sock.close()
+
 class ProxyProtocolTest(DNSDistTest):
     _proxyResponderPort = proxyResponderPort
     _config_params = ['_proxyResponderPort']
@@ -398,6 +473,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     """
 
     _config_template = """
+    addDOHLocal("127.0.0.1:%s", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
     setProxyProtocolACL( { "127.0.0.1/32" } )
     newServer{address="127.0.0.1:%d", useProxyProtocol=true}
 
@@ -434,8 +510,13 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     -- override all existing values
     addAction("override.proxy-protocol-incoming.tests.powerdns.com.", SetProxyProtocolValuesAction({["50"]="overridden"}))
     """
-    _config_params = ['_proxyResponderPort']
-    _verboseMode = True
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _dohServerPort = 8443
+    _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort))
+    _config_params = ['_dohServerPort', '_serverCert', '_serverKey', '_proxyResponderPort']
 
     def testNoHeader(self):
         """
@@ -445,9 +526,12 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         name = 'no-header.incoming-proxy-protocol.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
 
-        for method in ("sendUDPQuery", "sendTCPQuery"):
+        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOHQueryWrapper"):
           sender = getattr(self, method)
-          (_, receivedResponse) = sender(query, response=None)
+          try:
+            (_, receivedResponse) = sender(query, response=None)
+          except Exception:
+            receivedResponse = None
           self.assertEqual(receivedResponse, None)
 
     def testIncomingProxyDest(self):
@@ -656,6 +740,66 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
           self.assertEqual(receivedResponse, response)
           self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
 
+    def testProxyDoHSeveralQueriesOverConnection(self):
+        """
+        Incoming Proxy Protocol: Several queries over the same connection (DoH)
+        """
+        name = 'several-queries.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        toProxyQueue.put(response, True, 2.0)
+
+        wire = query.to_wire()
+
+        reverseProxyPort = 13053
+        reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPort])
+        reverseProxy.start()
+
+        receivedResponse = None
+        conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
+
+        reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+        for idx in range(5):
+          receivedResponse = None
+          toProxyQueue.put(response, True, 2.0)
+          (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+          (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          self.assertTrue(receivedResponse)
+
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          receivedQuery.id = query.id
+          receivedResponse.id = response.id
+          self.assertEqual(receivedQuery, query)
+          print(receivedResponse)
+          print(response)
+          self.assertEqual(receivedResponse, response)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._sock.close()
+        for backgroundThread in cls._backgroundThreads:
+            cls._backgroundThreads[backgroundThread] = False
+        for backgroundThread in backgroundThreads:
+            backgroundThreads[backgroundThread] = False
+        cls.killProcess(cls._dnsdist)
+
 class TestProxyProtocolNotExpected(DNSDistTest):
     """
     dnsdist is configured to expect a Proxy Protocol header on incoming queries but not from 127.0.0.1