auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
- bool proxyProtocolPayloadAdded = false;
-
if (ds->useProxyProtocol) {
/* if we ever sent a TLV over a connection, we can never go back */
if (!state->d_proxyProtocolPayloadHasTLV) {
}
proxyProtocolPayload = getProxyProtocolPayload(dq);
- if (state->d_proxyProtocolPayloadHasTLV && downstreamConnection->isFresh()) {
- /* we will not be able to reuse an existing connection anyway so let's add the payload right now */
- addProxyProtocol(state->d_buffer, proxyProtocolPayload);
- proxyProtocolPayloadAdded = true;
- }
}
if (dq.proxyProtocolValues) {
}
TCPQuery query(std::move(state->d_buffer), std::move(ids));
- if (proxyProtocolPayloadAdded) {
- query.d_proxyProtocolPayloadAdded = true;
- }
- else {
- query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
- }
+ query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getName());
std::shared_ptr<TCPQuerySender> incoming = state;
memcpy(&payload.at(startOfHeaderOffset), &dh, sizeof(dh));
}
+enum class QueryState : uint8_t {
+ hasSizePrepended,
+ noSize
+};
+
+enum class ConnectionState : uint8_t {
+ needProxy,
+ proxySent
+};
+
+static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState queryState, ConnectionState connectionState)
+{
+ if (connectionState == ConnectionState::needProxy) {
+ if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) {
+ query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end());
+ query.d_proxyProtocolPayloadAdded = true;
+ }
+ }
+ else if (connectionState == ConnectionState::proxySent) {
+ if (query.d_proxyProtocolPayloadAdded) {
+ if (query.d_buffer.size() < query.d_proxyProtocolPayload.size()) {
+ throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size()));
+ }
+ query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayload.size());
+ query.d_proxyProtocolPayloadAdded = false;
+ }
+ }
+
+ editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true);
+}
+
IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
{
conn->d_currentQuery = std::move(conn->d_pendingQueries.front());
uint16_t id = conn->d_highestStreamID;
- editPayloadID(conn->d_currentQuery.d_query.d_buffer, id, conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded ? conn->d_currentQuery.d_query.d_proxyProtocolPayload.size() : 0, true);
+ prepareQueryForSending(conn->d_currentQuery.d_query, id, QueryState::hasSizePrepended, conn->needProxyProtocolPayload() ? ConnectionState::needProxy : ConnectionState::proxySent);
conn->d_pendingQueries.pop_front();
conn->d_state = State::sendingQueryToBackend;
if (conn->d_state == State::sendingQueryToBackend) {
/* we need to edit this query so it has the correct ID */
auto query = std::move(conn->d_currentQuery);
-
uint16_t id = conn->d_highestStreamID;
- editPayloadID(query.d_query.d_buffer, id, query.d_query.d_proxyProtocolPayloadAdded ? query.d_query.d_proxyProtocolPayload.size() : 0, true);
+ prepareQueryForSending(query.d_query, id, QueryState::hasSizePrepended, ConnectionState::needProxy);
conn->d_currentQuery = std::move(query);
}
iostate = queueNextQuery(conn);
}
- if (conn->needProxyProtocolPayload() && !conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !conn->d_currentQuery.d_query.d_proxyProtocolPayload.empty()) {
- conn->d_currentQuery.d_query.d_buffer.insert(conn->d_currentQuery.d_query.d_buffer.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.end());
- conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true;
- }
-
reconnected = true;
connectionDied = false;
}
void TCPConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query)
{
+ cerr<<"in "<<__PRETTY_FUNCTION__<<" for a query with a buffer of size "<<query.d_buffer.size()<<" and a proxy protocol payload size of "<<query.d_proxyProtocolPayload.size()<<" which has been added: "<<query.d_proxyProtocolPayloadAdded<<endl;
if (!d_ioState) {
d_ioState = make_unique<IOStateHandler>(*d_mplexer, d_handler->getDescriptor());
}
d_currentPos = 0;
uint16_t id = d_highestStreamID;
- editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true);
d_currentQuery = PendingRequest({sender, std::move(query)});
-
- if (needProxyProtocolPayload() && !d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !d_currentQuery.d_query.d_proxyProtocolPayload.empty()) {
- d_currentQuery.d_query.d_buffer.insert(d_currentQuery.d_query.d_buffer.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.end());
- d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true;
- }
+ prepareQueryForSending(d_currentQuery.d_query, id, QueryState::hasSizePrepended, needProxyProtocolPayload() ? ConnectionState::needProxy : ConnectionState::proxySent);
struct timeval now;
gettimeofday(&now, 0);
#!/usr/bin/env python
+import copy
import dns
import socket
import struct
toQueue.put([payload, data], True, 2.0)
- response = fromQueue.get(True, 2.0)
+ response = copy.deepcopy(fromQueue.get(True, 2.0))
if not response:
conn.close()
break
addAction("values-action.proxy.tests.powerdns.com.", SetProxyProtocolValuesAction({ ["1"]="dnsdist", ["255"]="proxy-protocol"}))
"""
_config_params = ['_proxyResponderPort']
+ _verboseMode = True
def testProxyUDP(self):
"""
receivedQuery = dns.message.from_wire(receivedDNSData)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEqual(receivedQuery, query)
self.assertEqual(receivedResponse, response)
self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
destPort = 9999
srcAddr = "2001:db8::8"
srcPort = 8888
- response = dns.message.make_response(query)
tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
receivedQuery = dns.message.from_wire(receivedDNSData)
receivedQuery.id = query.id
- receivedResponse.id = response.id
self.assertEqual(receivedQuery, query)
self.assertEqual(receivedResponse, response)
self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)