]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix TLV reuse over TCP
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 24 Nov 2020 13:24:02 +0000 (14:24 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Jan 2021 09:22:00 +0000 (10:22 +0100)
pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.hh
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/proxy-protocol.hh

index cdb2b7d8d0ff744a8dbf6d3cfb07a7c5770fc546..0fa07c2ebe66a6a998612e11a210a2d02f4041de 100644 (file)
@@ -168,13 +168,11 @@ IncomingTCPConnectionState::~IncomingTCPConnectionState()
   }
 }
 
-std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
+std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now)
 {
   std::shared_ptr<TCPConnectionToBackend> downstream{nullptr};
 
-  if (!ds->useProxyProtocol || !d_proxyProtocolPayloadHasTLV) {
-    downstream = getActiveDownstreamConnection(ds);
-  }
+  downstream = getActiveDownstreamConnection(ds, tlvs);
 
   if (!downstream) {
     /* we don't have a connection to this backend active yet, let's ask one (it might not be a fresh one, though) */
@@ -354,7 +352,7 @@ void IncomingTCPConnectionState::resetForNewQuery()
   d_state = State::readingQuerySize;
 }
 
-std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds)
+std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs)
 {
   auto it = d_activeConnectionsToBackend.find(ds);
   if (it == d_activeConnectionsToBackend.end()) {
@@ -363,8 +361,9 @@ std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getActiveDow
   }
 
   for (auto& conn : it->second) {
-    if (conn->canAcceptNewQueries()) {
+    if (conn->canAcceptNewQueries() && conn->matchesTLVs(tlvs)) {
       DEBUGLOG("Got one active connection accepting more for "<<ds->getName());
+      conn->setReused();
       return conn;
     }
     DEBUGLOG("not accepting more for "<<ds->getName());
@@ -541,7 +540,9 @@ static IOState handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, c
   dq.dnsCryptQuery = std::move(dnsCryptQuery);
   dq.sni = state->d_handler.getServerNameIndication();
   if (state->d_proxyProtocolValues) {
-    dq.proxyProtocolValues = std::move(state->d_proxyProtocolValues);
+    /* we need to copy them, because the next queries received on that connection will
+       need to get the _unaltered_ values */
+    dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*state->d_proxyProtocolValues);
   }
 
   state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
@@ -582,6 +583,9 @@ static IOState handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, c
      especially alignment issues */
   state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
 
+  auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);
+  downstreamConnection->assignToClientConnection(state, state->d_isXFR);
+
   bool proxyProtocolPayloadAdded = false;
   std::string proxyProtocolPayload;
 
@@ -592,16 +596,16 @@ static IOState handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, c
     }
 
     proxyProtocolPayload = getProxyProtocolPayload(dq);
-
-    if (state->d_proxyProtocolPayloadHasTLV) {
+    if (state->d_proxyProtocolPayloadHasTLV && downstreamConnection->isFresh()) {
       /* we will not be able to reuse an existing connection anyway so let's add the payload right now */
       addProxyProtocol(state->d_buffer, proxyProtocolPayload);
       proxyProtocolPayloadAdded = true;
     }
   }
 
-  auto downstreamConnection = state->getDownstreamConnection(ds, now);
-  downstreamConnection->assignToClientConnection(state, state->d_isXFR);
+  if (dq.proxyProtocolValues) {
+    downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues));
+  }
 
   if (proxyProtocolPayloadAdded) {
     downstreamConnection->setProxyProtocolPayloadAdded(true);
index b16195fa1049ac9cca096ca450ea058fb53fa6d0..19a0563c90cbc13e2c403be5ceb2480fa1333ec4 100644 (file)
@@ -541,3 +541,25 @@ void TCPConnectionToBackend::setProxyProtocolPayloadAdded(bool added)
 {
   d_proxyProtocolPayloadAdded = added;
 }
+
+void TCPConnectionToBackend::setProxyProtocolValuesSent(std::unique_ptr<std::vector<ProxyProtocolValue>>&& proxyProtocolValuesSent)
+{
+  /* if we already have some values, we have already verified they match */
+  if (!d_proxyProtocolValuesSent) {
+    d_proxyProtocolValuesSent = std::move(proxyProtocolValuesSent);
+  }
+}
+
+bool TCPConnectionToBackend::matchesTLVs(const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs) const
+{
+  if (tlvs == nullptr && d_proxyProtocolValuesSent == nullptr) {
+    return true;
+  }
+  if (tlvs == nullptr && d_proxyProtocolValuesSent != nullptr) {
+    return false;
+  }
+  if (tlvs != nullptr && d_proxyProtocolValuesSent == nullptr) {
+    return false;
+  }
+  return *tlvs == *d_proxyProtocolValuesSent;
+}
index 9bb64909b8d1ee74d9167a0105063c5b47fcdfbd..2ce536a8db95538674174caba11c049fd781433c 100644 (file)
@@ -151,6 +151,8 @@ public:
     return true;
   }
 
+  bool matchesTLVs(const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs) const;
+
   bool matches(const std::shared_ptr<DownstreamState>& ds) const
   {
     if (!ds || !d_ds) {
@@ -165,6 +167,7 @@ public:
 
   void setProxyProtocolPayload(std::string&& payload);
   void setProxyProtocolPayloadAdded(bool added);
+  void setProxyProtocolValuesSent(std::unique_ptr<std::vector<ProxyProtocolValue>>&& proxyProtocolValuesSent);
 
 private:
   /* waitingForResponseFromBackend is a state where we have not yet started reading the size,
@@ -217,6 +220,7 @@ private:
   PacketBuffer d_responseBuffer;
   std::deque<TCPQuery> d_pendingQueries;
   std::unordered_map<uint16_t, TCPQuery> d_pendingResponses;
+  std::unique_ptr<std::vector<ProxyProtocolValue>> d_proxyProtocolValuesSent{nullptr};
   std::unique_ptr<Socket> d_socket{nullptr};
   std::unique_ptr<IOStateHandler> d_ioState{nullptr};
   std::shared_ptr<DownstreamState> d_ds{nullptr};
index 983d96f2d05f9d82bb0b66063f9e85a3b461e728..2b650a79500bef2526160be9803bac9a48af7fda 100644 (file)
@@ -139,8 +139,8 @@ public:
     return false;
   }
 
-  std::shared_ptr<TCPConnectionToBackend> getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds);
-  std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const struct timeval& now);
+  std::shared_ptr<TCPConnectionToBackend> getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs);
+  std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now);
   void registerActiveDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn);
 
   std::unique_ptr<FDMultiplexer>& getIOMPlexer() const
index 97f7ac777e6b8b9ecc3e3b6365cdad7cd716e91d..b2beb16d7a905756942bb3784e42f61225084bc1 100644 (file)
@@ -28,6 +28,11 @@ struct ProxyProtocolValue
 {
   std::string content;
   uint8_t type;
+
+  bool operator==(const ProxyProtocolValue& rhs) const
+  {
+    return type == rhs.type && content == rhs.content;
+  }
 };
 
 static const size_t s_proxyProtocolMinimumHeaderSize = 16;