From 80d74de3cd3e7b4faaa9c34fadf25d7ce95a996d Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Mon, 2 Mar 2020 16:46:46 +0100 Subject: [PATCH] dnsdist: Don't reuse Proxy Protocol-enabled TCP connections to backends --- pdns/dnsdist-tcp.cc | 21 +++++- pdns/dnsdist.cc | 2 +- pdns/dnsdist.hh | 1 - pdns/dnsdistdist/doh.cc | 2 +- .../test_ProxyProtocol.py | 66 +++++++++++++++---- 5 files changed, 75 insertions(+), 17 deletions(-) diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 76817540e0..7eb3255745 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -173,6 +173,19 @@ public: return d_enableFastOpen; } + bool canBeReused() const + { + /* we can't reuse a connection where a proxy protocol payload has been sent, + since: + - it cannot be reused for a different client + - we might have different TLV values for each query + */ + if (d_ds && d_ds->useProxyProtocol) { + return false; + } + return true; + } + private: std::unique_ptr d_socket{nullptr}; std::shared_ptr d_ds{nullptr}; @@ -208,6 +221,11 @@ static void releaseDownstreamConnection(std::unique_ptr& return; } + if (!conn->canBeReused()) { + conn.reset(); + return; + } + const auto& remote = conn->getRemote(); const auto& it = t_downstreamConnections.find(remote); if (it != t_downstreamConnections.end()) { @@ -917,7 +935,7 @@ static void handleQuery(std::shared_ptr& state, stru dq.dh = reinterpret_cast(&state->d_buffer.at(0)); dq.size = state->d_buffer.size(); - if (dq.addProxyProtocol && state->d_ds->useProxyProtocol) { + if (state->d_ds->useProxyProtocol) { addProxyProtocol(dq); } @@ -1092,6 +1110,7 @@ static void handleDownstreamIO(std::shared_ptr& stat } if (connectionDied) { + state->d_downstreamConnection.reset(); sendQueryToBackend(state, now); } } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 364fc81625..53637eb73f 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -1368,7 +1368,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct dh->id = idOffset; - if (dq.addProxyProtocol && ss->useProxyProtocol) { + if (ss->useProxyProtocol) { addProxyProtocol(dq); } diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 14e0fd287b..3848149fc8 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -109,7 +109,6 @@ struct DNSQuestion bool ecsOverride; bool useECS{true}; bool addXPF{true}; - bool addProxyProtocol{true}; bool ecsSet{false}; bool ecsAdded{false}; bool ednsAdded{false}; diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index 96a77e88c5..6abf4b304e 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -503,7 +503,7 @@ static int processDOHQuery(DOHUnit* du) dh->id = idOffset; - if (dq.addProxyProtocol && ss->useProxyProtocol) { + if (ss->useProxyProtocol) { addProxyProtocol(dq); } diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py index 2d4e0069ec..660346516d 100644 --- a/regression-tests.dnsdist/test_ProxyProtocol.py +++ b/regression-tests.dnsdist/test_ProxyProtocol.py @@ -91,26 +91,34 @@ def ProxyProtocolTCPResponder(port, fromQueue, toQueue): continue payload = header + proxyContent + while True: + try: + data = conn.recv(2) + except socket.timeout: + data = None - data = conn.recv(2) - (datalen,) = struct.unpack("!H", data) + if not data: + conn.close() + break - data = conn.recv(datalen) + (datalen,) = struct.unpack("!H", data) + data = conn.recv(datalen) - toQueue.put([payload, data], True, 2.0) + toQueue.put([payload, data], True, 2.0) - response = fromQueue.get(True, 2.0) - if not response: + response = fromQueue.get(True, 2.0) + if not response: conn.close() - continue + break - # computing the correct ID for the response - request = dns.message.from_wire(data) - response.id = request.id + # computing the correct ID for the response + request = dns.message.from_wire(data) + response.id = request.id + + wire = response.to_wire() + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) - wire = response.to_wire() - conn.send(struct.pack("!H", len(wire))) - conn.send(wire) conn.close() sock.close() @@ -366,3 +374,35 @@ class TestProxyProtocol(ProxyProtocolTest): self.assertEquals(receivedQuery, query) self.assertEquals(receivedResponse, response) self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ]) + + def testProxyTCPSeveralQueriesOnSameConnection(self): + """ + Proxy Protocol: Several queries on the same TCP connection + """ + name = 'several-queries-same-conn.proxy.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + conn = self.openTCPConnection(2.0) + data = query.to_wire() + + for idx in range(10): + toProxyQueue.put(response, True, 2.0) + self.sendTCPQueryOverConnection(conn, data, rawQuery=True) + receivedResponse = None + try: + receivedResponse = self.recvTCPResponseOverConnection(conn) + except socket.timeout: + print('timeout') + + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEquals(receivedQuery, query) + self.assertEquals(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, []) -- 2.39.2