From: Remi Gacogne Date: Tue, 19 Oct 2021 10:33:33 +0000 (+0200) Subject: dnsdist: Fix proxy protocol handling (and broken tests) X-Git-Tag: rec-4.6.0-beta1~28^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c48a3e337a7d40c9ec4d5ab35dabf3e97e5f721f;p=thirdparty%2Fpdns.git dnsdist: Fix proxy protocol handling (and broken tests) --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 3307980240..86d5b9c7b0 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -723,8 +723,6 @@ static void handleQuery(std::shared_ptr& state, cons auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now); - bool proxyProtocolPayloadAdded = false; - if (ds->useProxyProtocol) { /* if we ever sent a TLV over a connection, we can never go back */ if (!state->d_proxyProtocolPayloadHasTLV) { @@ -732,11 +730,6 @@ static void handleQuery(std::shared_ptr& state, cons } proxyProtocolPayload = getProxyProtocolPayload(dq); - if (state->d_proxyProtocolPayloadHasTLV && downstreamConnection->isFresh()) { - /* we will not be able to reuse an existing connection anyway so let's add the payload right now */ - addProxyProtocol(state->d_buffer, proxyProtocolPayload); - proxyProtocolPayloadAdded = true; - } } if (dq.proxyProtocolValues) { @@ -744,12 +737,7 @@ static void handleQuery(std::shared_ptr& state, cons } TCPQuery query(std::move(state->d_buffer), std::move(ids)); - if (proxyProtocolPayloadAdded) { - query.d_proxyProtocolPayloadAdded = true; - } - else { - query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); - } + query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getName()); std::shared_ptr incoming = state; diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index b6bdb75b95..3d48e702b2 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -152,12 +152,43 @@ static void editPayloadID(PacketBuffer& payload, uint16_t newId, size_t proxyPro memcpy(&payload.at(startOfHeaderOffset), &dh, sizeof(dh)); } +enum class QueryState : uint8_t { + hasSizePrepended, + noSize +}; + +enum class ConnectionState : uint8_t { + needProxy, + proxySent +}; + +static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState queryState, ConnectionState connectionState) +{ + if (connectionState == ConnectionState::needProxy) { + if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) { + query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end()); + query.d_proxyProtocolPayloadAdded = true; + } + } + else if (connectionState == ConnectionState::proxySent) { + if (query.d_proxyProtocolPayloadAdded) { + if (query.d_buffer.size() < query.d_proxyProtocolPayload.size()) { + throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size())); + } + query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayload.size()); + query.d_proxyProtocolPayloadAdded = false; + } + } + + editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true); +} + IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr& conn) { conn->d_currentQuery = std::move(conn->d_pendingQueries.front()); uint16_t id = conn->d_highestStreamID; - editPayloadID(conn->d_currentQuery.d_query.d_buffer, id, conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded ? conn->d_currentQuery.d_query.d_proxyProtocolPayload.size() : 0, true); + prepareQueryForSending(conn->d_currentQuery.d_query, id, QueryState::hasSizePrepended, conn->needProxyProtocolPayload() ? ConnectionState::needProxy : ConnectionState::proxySent); conn->d_pendingQueries.pop_front(); conn->d_state = State::sendingQueryToBackend; @@ -318,9 +349,8 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c if (conn->d_state == State::sendingQueryToBackend) { /* we need to edit this query so it has the correct ID */ auto query = std::move(conn->d_currentQuery); - uint16_t id = conn->d_highestStreamID; - editPayloadID(query.d_query.d_buffer, id, query.d_query.d_proxyProtocolPayloadAdded ? query.d_query.d_proxyProtocolPayload.size() : 0, true); + prepareQueryForSending(query.d_query, id, QueryState::hasSizePrepended, ConnectionState::needProxy); conn->d_currentQuery = std::move(query); } @@ -359,11 +389,6 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c iostate = queueNextQuery(conn); } - if (conn->needProxyProtocolPayload() && !conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !conn->d_currentQuery.d_query.d_proxyProtocolPayload.empty()) { - conn->d_currentQuery.d_query.d_buffer.insert(conn->d_currentQuery.d_query.d_buffer.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.end()); - conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true; - } - reconnected = true; connectionDied = false; } @@ -422,6 +447,7 @@ void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t void TCPConnectionToBackend::queueQuery(std::shared_ptr& sender, TCPQuery&& query) { + cerr<<"in "<<__PRETTY_FUNCTION__<<" for a query with a buffer of size "<(*d_mplexer, d_handler->getDescriptor()); } @@ -434,14 +460,9 @@ void TCPConnectionToBackend::queueQuery(std::shared_ptr& sender, d_currentPos = 0; uint16_t id = d_highestStreamID; - editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true); d_currentQuery = PendingRequest({sender, std::move(query)}); - - if (needProxyProtocolPayload() && !d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !d_currentQuery.d_query.d_proxyProtocolPayload.empty()) { - d_currentQuery.d_query.d_buffer.insert(d_currentQuery.d_query.d_buffer.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.end()); - d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true; - } + prepareQueryForSending(d_currentQuery.d_query, id, QueryState::hasSizePrepended, needProxyProtocolPayload() ? ConnectionState::needProxy : ConnectionState::proxySent); struct timeval now; gettimeofday(&now, 0); diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py index 16b27fde47..bf073f74ed 100644 --- a/regression-tests.dnsdist/test_ProxyProtocol.py +++ b/regression-tests.dnsdist/test_ProxyProtocol.py @@ -1,5 +1,6 @@ #!/usr/bin/env python +import copy import dns import socket import struct @@ -110,7 +111,7 @@ def ProxyProtocolTCPResponder(port, fromQueue, toQueue): toQueue.put([payload, data], True, 2.0) - response = fromQueue.get(True, 2.0) + response = copy.deepcopy(fromQueue.get(True, 2.0)) if not response: conn.close() break @@ -160,6 +161,7 @@ class TestProxyProtocol(ProxyProtocolTest): addAction("values-action.proxy.tests.powerdns.com.", SetProxyProtocolValuesAction({ ["1"]="dnsdist", ["255"]="proxy-protocol"})) """ _config_params = ['_proxyResponderPort'] + _verboseMode = True def testProxyUDP(self): """ @@ -553,7 +555,6 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): receivedQuery = dns.message.from_wire(receivedDNSData) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEqual(receivedQuery, query) self.assertEqual(receivedResponse, response) self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort) @@ -600,7 +601,6 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): destPort = 9999 srcAddr = "2001:db8::8" srcPort = 8888 - response = dns.message.make_response(query) tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) @@ -650,7 +650,6 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): receivedQuery = dns.message.from_wire(receivedDNSData) receivedQuery.id = query.id - receivedResponse.id = response.id self.assertEqual(receivedQuery, query) self.assertEqual(receivedResponse, response) self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)