]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix the addition of the proxy protocol payload when reconnecting
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 18 Feb 2021 16:39:04 +0000 (17:39 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 2 Mar 2021 10:39:54 +0000 (11:39 +0100)
pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.hh
pdns/dnsdistdist/dnsdist-tcp-upstream.hh

index 53f23ecd352e7cc40fd7317c08bd42c67a9fc7a0..0258585445737441e2f4b2329a0db67fe99fb70a 100644 (file)
@@ -133,9 +133,16 @@ public:
     }
   }
 
-  static void clear()
+  static size_t clear()
   {
+    size_t count = 0;
+    for (const auto downstream : t_downstreamConnections) {
+      count += downstream.second.size();
+    }
+
     t_downstreamConnections.clear();
+
+    return count;
   }
 
 private:
@@ -175,9 +182,9 @@ IncomingTCPConnectionState::~IncomingTCPConnectionState()
   d_handler.close();
 }
 
-void IncomingTCPConnectionState::clearAllDownstreamConnections()
+size_t IncomingTCPConnectionState::clearAllDownstreamConnections()
 {
-  DownstreamConnectionsManager::clear();
+  return DownstreamConnectionsManager::clear();
 }
 
 std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now)
@@ -699,16 +706,17 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
     downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues));
   }
 
+  TCPQuery query(std::move(state->d_buffer), std::move(ids));
   if (proxyProtocolPayloadAdded) {
-    downstreamConnection->setProxyProtocolPayloadAdded(true);
+    query.d_proxyProtocolPayloadAdded = true;
   }
   else {
-    downstreamConnection->setProxyProtocolPayload(std::move(proxyProtocolPayload));
+    query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
   }
 
   ++state->d_currentQueriesCount;
-  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).getName(), state->d_proxiedRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName());
-  downstreamConnection->queueQuery(TCPQuery(std::move(state->d_buffer), std::move(ids)), downstreamConnection);
+  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).getName(), state->d_proxiedRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), query.d_buffer.size(), ds->getName());
+  downstreamConnection->queueQuery(std::move(query), downstreamConnection);
 }
 
 void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
@@ -935,6 +943,7 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionS
         ++state->d_ci.cs->tcpDiedReadingQuery;
       }
       else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+        /* unlikely to happen here, the exception should be handled in sendResponse() */
         ++state->d_ci.cs->tcpDiedSendingResponse;
       }
 
index 74f5fe46f7c3f0fbbef5c7945ad5a0e0183658e4..204717950ac2e888316bc6dad0a6de337bde8ff8 100644 (file)
@@ -61,6 +61,9 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptr<TCPConnectionToBackend
 
   DEBUGLOG("query sent to backend");
   /* request sent ! */
+  if (conn->d_currentQuery.d_proxyProtocolPayloadAdded) {
+    conn->d_proxyProtocolPayloadSent = true;
+  }
   conn->incQueries();
   conn->d_currentPos = 0;
 
@@ -211,9 +214,9 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
               iostate = queueNextQuery(conn);
             }
 
-            if (!conn->d_proxyProtocolPayloadAdded && !conn->d_proxyProtocolPayload.empty()) {
-              conn->d_currentQuery.d_buffer.insert(conn->d_currentQuery.d_buffer.begin(), conn->d_proxyProtocolPayload.begin(), conn->d_proxyProtocolPayload.end());
-              conn->d_proxyProtocolPayloadAdded = true;
+            if (conn->needProxyProtocolPayload() && !conn->d_currentQuery.d_proxyProtocolPayloadAdded && !conn->d_currentQuery.d_proxyProtocolPayload.empty()) {
+              conn->d_currentQuery.d_buffer.insert(conn->d_currentQuery.d_buffer.begin(), conn->d_currentQuery.d_proxyProtocolPayload.begin(), conn->d_currentQuery.d_proxyProtocolPayload.end());
+              conn->d_currentQuery.d_proxyProtocolPayloadAdded = true;
             }
 
             reconnected = true;
@@ -273,16 +276,15 @@ void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptr<TCPCon
     d_state = State::sendingQueryToBackend;
     d_currentPos = 0;
     d_currentQuery = std::move(query);
-    if (!d_proxyProtocolPayloadAdded && !d_proxyProtocolPayload.empty()) {
-      d_currentQuery.d_buffer.insert(d_currentQuery.d_buffer.begin(), d_proxyProtocolPayload.begin(), d_proxyProtocolPayload.end());
-      d_proxyProtocolPayloadAdded = true;
+    if (needProxyProtocolPayload() && !d_currentQuery.d_proxyProtocolPayloadAdded && !d_currentQuery.d_proxyProtocolPayload.empty()) {
+      d_currentQuery.d_buffer.insert(d_currentQuery.d_buffer.begin(), d_currentQuery.d_proxyProtocolPayload.begin(), d_currentQuery.d_proxyProtocolPayload.end());
+      d_currentQuery.d_proxyProtocolPayloadAdded = true;
     }
 
     struct timeval now;
     gettimeofday(&now, 0);
 
     handleIO(sharedSelf, now);
-    // d_ioState->update(IOState::NeedWrite, handleIOCallback, sharedSelf, getBackendWriteTTD(now));
   }
   else {
     DEBUGLOG("Adding new query to the queue because we are in state "<<(int)d_state);
@@ -301,6 +303,7 @@ bool TCPConnectionToBackend::reconnect()
   }
 
   d_fresh = true;
+  d_proxyProtocolPayloadSent = false;
 
   do {
     vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures);
@@ -504,16 +507,6 @@ uint16_t TCPConnectionToBackend::getQueryIdFromResponse()
   return ntohs(dh.id);
 }
 
-void TCPConnectionToBackend::setProxyProtocolPayload(std::string&& payload)
-{
-  d_proxyProtocolPayload = std::move(payload);
-}
-
-void TCPConnectionToBackend::setProxyProtocolPayloadAdded(bool added)
-{
-  d_proxyProtocolPayloadAdded = added;
-}
-
 void TCPConnectionToBackend::setProxyProtocolValuesSent(std::unique_ptr<std::vector<ProxyProtocolValue>>&& proxyProtocolValuesSent)
 {
   /* if we already have some values, we have already verified they match */
index a26b29ca06f3565b65d7cc61d0504b6d89032cf9..d582ef8fb05f926cf09cf89b6f79b698110dbe55 100644 (file)
@@ -18,6 +18,8 @@ struct TCPQuery
 
   IDState d_idstate;
   PacketBuffer d_buffer;
+  std::string d_proxyProtocolPayload;
+  bool d_proxyProtocolPayloadAdded{false};
 };
 
 class TCPConnectionToBackend;
@@ -165,8 +167,6 @@ public:
   void handleTimeout(const struct timeval& now, bool write);
   void release();
 
-  void setProxyProtocolPayload(std::string&& payload);
-  void setProxyProtocolPayloadAdded(bool added);
   void setProxyProtocolValuesSent(std::unique_ptr<std::vector<ProxyProtocolValue>>&& proxyProtocolValuesSent);
 
   std::string toString() const
@@ -191,6 +191,10 @@ private:
   uint16_t getQueryIdFromResponse();
   bool reconnect();
   void notifyAllQueriesFailed(const struct timeval& now, FailureReason reason);
+  bool needProxyProtocolPayload() const
+  {
+    return !d_proxyProtocolPayloadSent && (d_ds && d_ds->useProxyProtocol);
+  }
 
   boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
   {
@@ -232,7 +236,6 @@ private:
   std::unique_ptr<IOStateHandler> d_ioState{nullptr};
   std::shared_ptr<DownstreamState> d_ds{nullptr};
   std::shared_ptr<IncomingTCPConnectionState> d_clientConn;
-  std::string d_proxyProtocolPayload;
   TCPQuery d_currentQuery;
   struct timeval d_connectionStartTime;
   size_t d_currentPos{0};
@@ -244,5 +247,5 @@ private:
   bool d_enableFastOpen{false};
   bool d_connectionDied{false};
   bool d_usedForXFR{false};
-  bool d_proxyProtocolPayloadAdded{false};
+  bool d_proxyProtocolPayloadSent{false};
 };
index 49428ed055d6590466c0ad3468ef789e805ac0b2..54c3d88f2b37bcfa5036786775b5f941d231cedd 100644 (file)
@@ -151,7 +151,7 @@ public:
     return d_threadData.mplexer;
   }
 
-  static void clearAllDownstreamConnections();
+  static size_t clearAllDownstreamConnections();
 
   static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& conn, const struct timeval& now);
   static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);