From: Remi Gacogne Date: Tue, 24 Nov 2020 13:24:02 +0000 (+0100) Subject: dnsdist: Fix TLV reuse over TCP X-Git-Tag: rec-4.5.0-alpha1~19^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c4f2293bc93811ba510925d641e63c2f3ef59f69;p=thirdparty%2Fpdns.git dnsdist: Fix TLV reuse over TCP --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index cdb2b7d8d0..0fa07c2ebe 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -168,13 +168,11 @@ IncomingTCPConnectionState::~IncomingTCPConnectionState() } } -std::shared_ptr IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr& ds, const struct timeval& now) +std::shared_ptr IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr& ds, const std::unique_ptr>& tlvs, const struct timeval& now) { std::shared_ptr downstream{nullptr}; - if (!ds->useProxyProtocol || !d_proxyProtocolPayloadHasTLV) { - downstream = getActiveDownstreamConnection(ds); - } + downstream = getActiveDownstreamConnection(ds, tlvs); if (!downstream) { /* we don't have a connection to this backend active yet, let's ask one (it might not be a fresh one, though) */ @@ -354,7 +352,7 @@ void IncomingTCPConnectionState::resetForNewQuery() d_state = State::readingQuerySize; } -std::shared_ptr IncomingTCPConnectionState::getActiveDownstreamConnection(const std::shared_ptr& ds) +std::shared_ptr IncomingTCPConnectionState::getActiveDownstreamConnection(const std::shared_ptr& ds, const std::unique_ptr>& tlvs) { auto it = d_activeConnectionsToBackend.find(ds); if (it == d_activeConnectionsToBackend.end()) { @@ -363,8 +361,9 @@ std::shared_ptr IncomingTCPConnectionState::getActiveDow } for (auto& conn : it->second) { - if (conn->canAcceptNewQueries()) { + if (conn->canAcceptNewQueries() && conn->matchesTLVs(tlvs)) { DEBUGLOG("Got one active connection accepting more for "<getName()); + conn->setReused(); return conn; } DEBUGLOG("not accepting more for "<getName()); @@ -541,7 +540,9 @@ static IOState handleQuery(std::shared_ptr& state, c dq.dnsCryptQuery = std::move(dnsCryptQuery); dq.sni = state->d_handler.getServerNameIndication(); if (state->d_proxyProtocolValues) { - dq.proxyProtocolValues = std::move(state->d_proxyProtocolValues); + /* we need to copy them, because the next queries received on that connection will + need to get the _unaltered_ values */ + dq.proxyProtocolValues = make_unique>(*state->d_proxyProtocolValues); } state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR); @@ -582,6 +583,9 @@ static IOState handleQuery(std::shared_ptr& state, c especially alignment issues */ state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2); + auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now); + downstreamConnection->assignToClientConnection(state, state->d_isXFR); + bool proxyProtocolPayloadAdded = false; std::string proxyProtocolPayload; @@ -592,16 +596,16 @@ static IOState handleQuery(std::shared_ptr& state, c } proxyProtocolPayload = getProxyProtocolPayload(dq); - - if (state->d_proxyProtocolPayloadHasTLV) { + if (state->d_proxyProtocolPayloadHasTLV && downstreamConnection->isFresh()) { /* we will not be able to reuse an existing connection anyway so let's add the payload right now */ addProxyProtocol(state->d_buffer, proxyProtocolPayload); proxyProtocolPayloadAdded = true; } } - auto downstreamConnection = state->getDownstreamConnection(ds, now); - downstreamConnection->assignToClientConnection(state, state->d_isXFR); + if (dq.proxyProtocolValues) { + downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues)); + } if (proxyProtocolPayloadAdded) { downstreamConnection->setProxyProtocolPayloadAdded(true); diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index b16195fa10..19a0563c90 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -541,3 +541,25 @@ void TCPConnectionToBackend::setProxyProtocolPayloadAdded(bool added) { d_proxyProtocolPayloadAdded = added; } + +void TCPConnectionToBackend::setProxyProtocolValuesSent(std::unique_ptr>&& proxyProtocolValuesSent) +{ + /* if we already have some values, we have already verified they match */ + if (!d_proxyProtocolValuesSent) { + d_proxyProtocolValuesSent = std::move(proxyProtocolValuesSent); + } +} + +bool TCPConnectionToBackend::matchesTLVs(const std::unique_ptr>& tlvs) const +{ + if (tlvs == nullptr && d_proxyProtocolValuesSent == nullptr) { + return true; + } + if (tlvs == nullptr && d_proxyProtocolValuesSent != nullptr) { + return false; + } + if (tlvs != nullptr && d_proxyProtocolValuesSent == nullptr) { + return false; + } + return *tlvs == *d_proxyProtocolValuesSent; +} diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh index 9bb64909b8..2ce536a8db 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh @@ -151,6 +151,8 @@ public: return true; } + bool matchesTLVs(const std::unique_ptr>& tlvs) const; + bool matches(const std::shared_ptr& ds) const { if (!ds || !d_ds) { @@ -165,6 +167,7 @@ public: void setProxyProtocolPayload(std::string&& payload); void setProxyProtocolPayloadAdded(bool added); + void setProxyProtocolValuesSent(std::unique_ptr>&& proxyProtocolValuesSent); private: /* waitingForResponseFromBackend is a state where we have not yet started reading the size, @@ -217,6 +220,7 @@ private: PacketBuffer d_responseBuffer; std::deque d_pendingQueries; std::unordered_map d_pendingResponses; + std::unique_ptr> d_proxyProtocolValuesSent{nullptr}; std::unique_ptr d_socket{nullptr}; std::unique_ptr d_ioState{nullptr}; std::shared_ptr d_ds{nullptr}; diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 983d96f2d0..2b650a7950 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -139,8 +139,8 @@ public: return false; } - std::shared_ptr getActiveDownstreamConnection(const std::shared_ptr& ds); - std::shared_ptr getDownstreamConnection(std::shared_ptr& ds, const struct timeval& now); + std::shared_ptr getActiveDownstreamConnection(const std::shared_ptr& ds, const std::unique_ptr>& tlvs); + std::shared_ptr getDownstreamConnection(std::shared_ptr& ds, const std::unique_ptr>& tlvs, const struct timeval& now); void registerActiveDownstreamConnection(std::shared_ptr& conn); std::unique_ptr& getIOMPlexer() const diff --git a/pdns/proxy-protocol.hh b/pdns/proxy-protocol.hh index 97f7ac777e..b2beb16d7a 100644 --- a/pdns/proxy-protocol.hh +++ b/pdns/proxy-protocol.hh @@ -28,6 +28,11 @@ struct ProxyProtocolValue { std::string content; uint8_t type; + + bool operator==(const ProxyProtocolValue& rhs) const + { + return type == rhs.type && content == rhs.content; + } }; static const size_t s_proxyProtocolMinimumHeaderSize = 16;