]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix handling of proxy protocol payload outside of TLS for DoT 14639/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Sep 2024 12:20:48 +0000 (14:20 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 9 Sep 2024 13:09:35 +0000 (15:09 +0200)
After reading the proxy protocol payload from the I/O buffer
we were clearing the buffer but failed to properly reset the
position, leading to an exception when trying to read the DNS
payload after processing the TLS handshake:

```
Got an exception while handling (reading) TCP query from 127.0.0.1:59426: Calling tryRead() with a too small buffer (2) for a read of 18446744073709551566 bytes starting at 52
```

The huge value comes from the fact that the position (52 here)
is larger than the size of the buffer (2 at this point to read
the size of the incoming DNS payload), leading to an unsigned
underflow. The code is properly detecting that the value makes
no sense in this context, but the connection is then dropped
because we cannot recover.

It turns out we had a end-to-end test for the "proxy protocol
outside of TLS" case but only over incoming DoH, and the DoH
case avoids this specific issue because the buffer is always
properly resized, and the position updated.

(cherry picked from commit 4931fb28f7bc6e8905d3298003dead7c32f4d090)

pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-nghttp2-in.cc
regression-tests.dnsdist/test_ProxyProtocol.py

index 3db77ae279e2ab894d09c5fcba647ed864287f79..89cf81c991f505797270020966068e79576524c4 100644 (file)
@@ -908,6 +908,9 @@ IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::hand
           d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
         }
 
+        d_currentPos = 0;
+        d_proxyProtocolNeed = 0;
+        d_buffer.clear();
         return ProxyProtocolResult::Done;
       }
     }
@@ -1090,15 +1093,14 @@ void IncomingTCPConnectionState::handleIO()
       if (!d_lastIOBlocked && d_state == State::readingProxyProtocolHeader) {
         auto status = handleProxyProtocolPayload();
         if (status == ProxyProtocolResult::Done) {
+          d_buffer.resize(sizeof(uint16_t));
+
           if (isProxyPayloadOutsideTLS()) {
             d_state = State::doingHandshake;
             iostate = handleHandshake(now);
           }
           else {
             d_state = State::readingQuerySize;
-            d_buffer.resize(sizeof(uint16_t));
-            d_currentPos = 0;
-            d_proxyProtocolNeed = 0;
           }
         }
         else if (status == ProxyProtocolResult::Error) {
index 8458bc77bfc088b38a28abac458ddb37b18685b9..6dd09de9229feb800b3c074a9463c028ab686901 100644 (file)
@@ -400,9 +400,6 @@ void IncomingHTTP2Connection::handleIO()
           }
         }
         else {
-          d_currentPos = 0;
-          d_proxyProtocolNeed = 0;
-          d_buffer.clear();
           d_state = State::waitingForQuery;
           handleConnectionReady();
         }
index 2ed60e08bc9f4351b7a59da384030f4b3f7691a6..78677b3a7149f7e9a4775e6264b78f36947717ba 100644 (file)
@@ -142,7 +142,6 @@ class TestProxyProtocol(ProxyProtocolTest):
     addAction("values-action.proxy.tests.powerdns.com.", SetProxyProtocolValuesAction({ ["1"]="dnsdist", ["255"]="proxy-protocol"}))
     """
     _config_params = ['_proxyResponderPort']
-    _verboseMode = True
 
     def testProxyUDP(self):
         """
@@ -379,6 +378,8 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     _config_template = """
     addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
     addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false})
+    addTLSLocal("127.0.0.1:%d", "%s", "%s", {proxyProtocolOutsideTLS=true})
+    addTLSLocal("127.0.0.1:%d", "%s", "%s", {proxyProtocolOutsideTLS=false})
     setProxyProtocolACL( { "127.0.0.1/32" } )
     newServer{address="127.0.0.1:%d", useProxyProtocol=true, proxyProtocolAdvertiseTLS=true}
 
@@ -421,7 +422,9 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     _caCert = 'ca.pem'
     _dohServerPPOutsidePort = pickAvailablePort()
     _dohServerPPInsidePort = pickAvailablePort()
-    _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort']
+    _dotServerPPOutsidePort = pickAvailablePort()
+    _dotServerPPInsidePort = pickAvailablePort()
+    _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_dotServerPPOutsidePort', '_serverCert', '_serverKey', '_dotServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort']
 
     def testNoHeader(self):
         """
@@ -666,7 +669,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
 
         reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
-        (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, conn=conn)
         (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
         self.assertTrue(receivedProxyPayload)
         self.assertTrue(receivedDNSData)
@@ -682,7 +685,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         for idx in range(5):
           receivedResponse = None
           toProxyQueue.put(response, True, 2.0)
-          (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+          (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, conn=conn)
           (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
           self.assertTrue(receivedProxyPayload)
           self.assertTrue(receivedDNSData)
@@ -719,7 +722,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
 
         reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
-        (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, conn=conn)
         (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
         self.assertTrue(receivedProxyPayload)
         self.assertTrue(receivedDNSData)
@@ -735,7 +738,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         for idx in range(5):
           receivedResponse = None
           toProxyQueue.put(response, True, 2.0)
-          (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+          (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, conn=conn)
           (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
           self.assertTrue(receivedProxyPayload)
           self.assertTrue(receivedDNSData)
@@ -748,6 +751,108 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
           self.assertEqual(receivedResponse, response)
           self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
 
+    def testProxyDoTSeveralQueriesOverConnectionPPOutside(self):
+        """
+        Incoming Proxy Protocol: Several queries over the same connection (DoT, PP outside TLS)
+        """
+        name = 'several-queries.dot-outside.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        toProxyQueue.put(response, True, 2.0)
+
+        wire = query.to_wire()
+
+        reverseProxyPort = pickAvailablePort()
+        reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dotServerPPOutsidePort])
+        reverseProxy.start()
+        time.sleep(1)
+
+        receivedResponse = None
+        conn = self.openTLSConnection(reverseProxyPort, self._serverName, self._caCert, timeout=2.0)
+        self.sendTCPQueryOverConnection(conn, query, response=response)
+        receivedResponse = self.recvTCPResponseOverConnection(conn)
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+        for idx in range(5):
+          receivedResponse = None
+          toProxyQueue.put(response, True, 2.0)
+          self.sendTCPQueryOverConnection(conn, query, response=response)
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+          (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          self.assertTrue(receivedResponse)
+
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          receivedQuery.id = query.id
+          receivedResponse.id = response.id
+          self.assertEqual(receivedQuery, query)
+          self.assertEqual(receivedResponse, response)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+    def testProxyDoTSeveralQueriesOverConnectionPPInside(self):
+        """
+        Incoming Proxy Protocol: Several queries over the same connection (DoT, PP inside TLS)
+        """
+        name = 'several-queries.dot-inside.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        toProxyQueue.put(response, True, 2.0)
+
+        wire = query.to_wire()
+
+        reverseProxyPort = pickAvailablePort()
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain(self._serverCert, self._serverKey)
+        tlsContext.set_alpn_protocols(['dot'])
+        reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dotServerPPInsidePort, tlsContext, self._caCert, self._serverName])
+        reverseProxy.start()
+
+        receivedResponse = None
+        time.sleep(1)
+        conn = self.openTLSConnection(reverseProxyPort, self._serverName, self._caCert, timeout=2.0)
+
+        self.sendTCPQueryOverConnection(conn, query, response=response)
+        receivedResponse = self.recvTCPResponseOverConnection(conn)
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+
+        for idx in range(5):
+          receivedResponse = None
+          toProxyQueue.put(response, True, 2.0)
+          self.sendTCPQueryOverConnection(conn, query, response=response)
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+          (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          self.assertTrue(receivedResponse)
+
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          receivedQuery.id = query.id
+          receivedResponse.id = response.id
+          self.assertEqual(receivedQuery, query)
+          self.assertEqual(receivedResponse, response)
+
     @classmethod
     def tearDownClass(cls):
         cls._sock.close()
@@ -768,7 +873,6 @@ class TestProxyProtocolNotExpected(DNSDistTest):
     """
     # NORMAL responder, does not expect a proxy protocol payload!
     _config_params = ['_testServerPort']
-    _verboseMode = True
 
     def testNoHeader(self):
         """
@@ -910,7 +1014,6 @@ class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
     setACL( { "::1/128", "127.0.0.0/8" } )
     """
     _config_params = ['_proxyResponderPort', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey']
-    _verboseMode = True
 
     def testTruncation(self):
         """