]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix proxy protocol handling (and broken tests)
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 19 Oct 2021 10:33:33 +0000 (12:33 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 26 Oct 2021 15:07:19 +0000 (17:07 +0200)
pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
regression-tests.dnsdist/test_ProxyProtocol.py

index 33079802403f17bc7a65a4cc1fefca37fff00ba1..86d5b9c7b029db77509584eb0acb8c7cf7e1c18e 100644 (file)
@@ -723,8 +723,6 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
 
   auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
 
-  bool proxyProtocolPayloadAdded = false;
-
   if (ds->useProxyProtocol) {
     /* if we ever sent a TLV over a connection, we can never go back */
     if (!state->d_proxyProtocolPayloadHasTLV) {
@@ -732,11 +730,6 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
     }
 
     proxyProtocolPayload = getProxyProtocolPayload(dq);
-    if (state->d_proxyProtocolPayloadHasTLV && downstreamConnection->isFresh()) {
-      /* we will not be able to reuse an existing connection anyway so let's add the payload right now */
-      addProxyProtocol(state->d_buffer, proxyProtocolPayload);
-      proxyProtocolPayloadAdded = true;
-    }
   }
 
   if (dq.proxyProtocolValues) {
@@ -744,12 +737,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
   }
 
   TCPQuery query(std::move(state->d_buffer), std::move(ids));
-  if (proxyProtocolPayloadAdded) {
-    query.d_proxyProtocolPayloadAdded = true;
-  }
-  else {
-    query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
-  }
+  query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
 
   vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getName());
   std::shared_ptr<TCPQuerySender> incoming = state;
index b6bdb75b9557732bc9e8a98c403d92646571fac3..3d48e702b2272cb0e21bca786509f0a5b2393578 100644 (file)
@@ -152,12 +152,43 @@ static void editPayloadID(PacketBuffer& payload, uint16_t newId, size_t proxyPro
   memcpy(&payload.at(startOfHeaderOffset), &dh, sizeof(dh));
 }
 
+enum class QueryState : uint8_t {
+  hasSizePrepended,
+  noSize
+};
+
+enum class ConnectionState : uint8_t {
+  needProxy,
+  proxySent
+};
+
+static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState queryState, ConnectionState connectionState)
+{
+  if (connectionState == ConnectionState::needProxy) {
+    if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) {
+      query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end());
+      query.d_proxyProtocolPayloadAdded = true;
+    }
+  }
+  else if (connectionState == ConnectionState::proxySent) {
+    if (query.d_proxyProtocolPayloadAdded) {
+      if (query.d_buffer.size() < query.d_proxyProtocolPayload.size()) {
+        throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size()));
+      }
+      query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayload.size());
+      query.d_proxyProtocolPayloadAdded = false;
+    }
+  }
+
+  editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true);
+}
+
 IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
 {
   conn->d_currentQuery = std::move(conn->d_pendingQueries.front());
 
   uint16_t id = conn->d_highestStreamID;
-  editPayloadID(conn->d_currentQuery.d_query.d_buffer, id, conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded ? conn->d_currentQuery.d_query.d_proxyProtocolPayload.size() : 0, true);
+  prepareQueryForSending(conn->d_currentQuery.d_query, id, QueryState::hasSizePrepended, conn->needProxyProtocolPayload() ? ConnectionState::needProxy : ConnectionState::proxySent);
 
   conn->d_pendingQueries.pop_front();
   conn->d_state = State::sendingQueryToBackend;
@@ -318,9 +349,8 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
             if (conn->d_state == State::sendingQueryToBackend) {
               /* we need to edit this query so it has the correct ID */
               auto query = std::move(conn->d_currentQuery);
-
               uint16_t id = conn->d_highestStreamID;
-              editPayloadID(query.d_query.d_buffer, id, query.d_query.d_proxyProtocolPayloadAdded ? query.d_query.d_proxyProtocolPayload.size() : 0, true);
+              prepareQueryForSending(query.d_query, id, QueryState::hasSizePrepended, ConnectionState::needProxy);
               conn->d_currentQuery = std::move(query);
             }
 
@@ -359,11 +389,6 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
               iostate = queueNextQuery(conn);
             }
 
-            if (conn->needProxyProtocolPayload() && !conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !conn->d_currentQuery.d_query.d_proxyProtocolPayload.empty()) {
-              conn->d_currentQuery.d_query.d_buffer.insert(conn->d_currentQuery.d_query.d_buffer.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.end());
-              conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true;
-            }
-
             reconnected = true;
             connectionDied = false;
           }
@@ -422,6 +447,7 @@ void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t
 
 void TCPConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query)
 {
+  cerr<<"in "<<__PRETTY_FUNCTION__<<" for a query with a buffer of size "<<query.d_buffer.size()<<" and a proxy protocol payload size of "<<query.d_proxyProtocolPayload.size()<<" which has been added: "<<query.d_proxyProtocolPayloadAdded<<endl;
   if (!d_ioState) {
     d_ioState = make_unique<IOStateHandler>(*d_mplexer, d_handler->getDescriptor());
   }
@@ -434,14 +460,9 @@ void TCPConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender,
     d_currentPos = 0;
 
     uint16_t id = d_highestStreamID;
-    editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true);
 
     d_currentQuery = PendingRequest({sender, std::move(query)});
-
-    if (needProxyProtocolPayload() && !d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !d_currentQuery.d_query.d_proxyProtocolPayload.empty()) {
-      d_currentQuery.d_query.d_buffer.insert(d_currentQuery.d_query.d_buffer.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.end());
-      d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true;
-    }
+    prepareQueryForSending(d_currentQuery.d_query, id, QueryState::hasSizePrepended, needProxyProtocolPayload() ? ConnectionState::needProxy : ConnectionState::proxySent);
 
     struct timeval now;
     gettimeofday(&now, 0);
index 16b27fde4713c4962ea608772d7e5282781b6411..bf073f74ed639de2465a6c678705baaa251e8148 100644 (file)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 
+import copy
 import dns
 import socket
 import struct
@@ -110,7 +111,7 @@ def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
 
           toQueue.put([payload, data], True, 2.0)
 
-          response = fromQueue.get(True, 2.0)
+          response = copy.deepcopy(fromQueue.get(True, 2.0))
           if not response:
             conn.close()
             break
@@ -160,6 +161,7 @@ class TestProxyProtocol(ProxyProtocolTest):
     addAction("values-action.proxy.tests.powerdns.com.", SetProxyProtocolValuesAction({ ["1"]="dnsdist", ["255"]="proxy-protocol"}))
     """
     _config_params = ['_proxyResponderPort']
+    _verboseMode = True
 
     def testProxyUDP(self):
         """
@@ -553,7 +555,6 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
 
         receivedQuery = dns.message.from_wire(receivedDNSData)
         receivedQuery.id = query.id
-        receivedResponse.id = response.id
         self.assertEqual(receivedQuery, query)
         self.assertEqual(receivedResponse, response)
         self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
@@ -600,7 +601,6 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         destPort = 9999
         srcAddr = "2001:db8::8"
         srcPort = 8888
-        response = dns.message.make_response(query)
 
         tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
 
@@ -650,7 +650,6 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
 
           receivedQuery = dns.message.from_wire(receivedDNSData)
           receivedQuery.id = query.id
-          receivedResponse.id = response.id
           self.assertEqual(receivedQuery, query)
           self.assertEqual(receivedResponse, response)
           self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)