]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Don't reuse Proxy Protocol-enabled TCP connections to backends
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 2 Mar 2020 15:46:46 +0000 (16:46 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 17 Mar 2020 13:12:55 +0000 (14:12 +0100)
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/doh.cc
regression-tests.dnsdist/test_ProxyProtocol.py

index 76817540e0e37cf22b3e8b43b69f554c5cc14e62..7eb3255745c14ec74776b346bccbc1c17450d59f 100644 (file)
@@ -173,6 +173,19 @@ public:
     return d_enableFastOpen;
   }
 
+  bool canBeReused() const
+  {
+    /* we can't reuse a connection where a proxy protocol payload has been sent,
+       since:
+       - it cannot be reused for a different client
+       - we might have different TLV values for each query
+    */
+    if (d_ds && d_ds->useProxyProtocol) {
+      return false;
+    }
+    return true;
+  }
+
 private:
   std::unique_ptr<Socket> d_socket{nullptr};
   std::shared_ptr<DownstreamState> d_ds{nullptr};
@@ -208,6 +221,11 @@ static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&
     return;
   }
 
+  if (!conn->canBeReused()) {
+    conn.reset();
+    return;
+  }
+
   const auto& remote = conn->getRemote();
   const auto& it = t_downstreamConnections.find(remote);
   if (it != t_downstreamConnections.end()) {
@@ -917,7 +935,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, stru
   dq.dh = reinterpret_cast<dnsheader*>(&state->d_buffer.at(0));
   dq.size = state->d_buffer.size();
 
-  if (dq.addProxyProtocol && state->d_ds->useProxyProtocol) {
+  if (state->d_ds->useProxyProtocol) {
     addProxyProtocol(dq);
   }
 
@@ -1092,6 +1110,7 @@ static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& stat
   }
 
   if (connectionDied) {
+    state->d_downstreamConnection.reset();
     sendQueryToBackend(state, now);
   }
 }
index 364fc81625a69032407feb0c6d5832fa0b3fea86..53637eb73f9b18fb324c45e50cf15fd2cf672f93 100644 (file)
@@ -1368,7 +1368,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
 
     dh->id = idOffset;
 
-    if (dq.addProxyProtocol && ss->useProxyProtocol) {
+    if (ss->useProxyProtocol) {
       addProxyProtocol(dq);
     }
 
index 14e0fd287bcc0b8e9e696513808921575bfa9bf4..3848149fc8d0c774a03c280ec69c905c4bb67f69 100644 (file)
@@ -109,7 +109,6 @@ struct DNSQuestion
   bool ecsOverride;
   bool useECS{true};
   bool addXPF{true};
-  bool addProxyProtocol{true};
   bool ecsSet{false};
   bool ecsAdded{false};
   bool ednsAdded{false};
index 96a77e88c5bdc07dd2334ad2db77e00eac0f32b9..6abf4b304e25e51326bb78fdc48e2d55b790fbab 100644 (file)
@@ -503,7 +503,7 @@ static int processDOHQuery(DOHUnit* du)
 
     dh->id = idOffset;
 
-    if (dq.addProxyProtocol && ss->useProxyProtocol) {
+    if (ss->useProxyProtocol) {
       addProxyProtocol(dq);
     }
 
index 2d4e0069ec58e35e01e99ae27774334489de0c81..660346516d2e181dc8fbde03492352a870474dcd 100644 (file)
@@ -91,26 +91,34 @@ def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
             continue
 
         payload = header + proxyContent
+        while True:
+          try:
+            data = conn.recv(2)
+          except socket.timeout:
+            data = None
 
-        data = conn.recv(2)
-        (datalen,) = struct.unpack("!H", data)
+          if not data:
+            conn.close()
+            break
 
-        data = conn.recv(datalen)
+          (datalen,) = struct.unpack("!H", data)
+          data = conn.recv(datalen)
 
-        toQueue.put([payload, data], True, 2.0)
+          toQueue.put([payload, data], True, 2.0)
 
-        response = fromQueue.get(True, 2.0)
-        if not response:
+          response = fromQueue.get(True, 2.0)
+          if not response:
             conn.close()
-            continue
+            break
 
-        # computing the correct ID for the response
-        request = dns.message.from_wire(data)
-        response.id = request.id
+          # computing the correct ID for the response
+          request = dns.message.from_wire(data)
+          response.id = request.id
+
+          wire = response.to_wire()
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
 
-        wire = response.to_wire()
-        conn.send(struct.pack("!H", len(wire)))
-        conn.send(wire)
         conn.close()
 
     sock.close()
@@ -366,3 +374,35 @@ class TestProxyProtocol(ProxyProtocolTest):
       self.assertEquals(receivedQuery, query)
       self.assertEquals(receivedResponse, response)
       self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ])
+
+    def testProxyTCPSeveralQueriesOnSameConnection(self):
+      """
+        Proxy Protocol: Several queries on the same TCP connection
+      """
+      name = 'several-queries-same-conn.proxy.tests.powerdns.com.'
+      query = dns.message.make_query(name, 'A', 'IN')
+      response = dns.message.make_response(query)
+
+      conn = self.openTCPConnection(2.0)
+      data = query.to_wire()
+
+      for idx in range(10):
+        toProxyQueue.put(response, True, 2.0)
+        self.sendTCPQueryOverConnection(conn, data, rawQuery=True)
+        receivedResponse = None
+        try:
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+        except socket.timeout:
+          print('timeout')
+
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [])