]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Get rid of TCPCrossProtocolQuerySender
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 2 Dec 2022 14:57:17 +0000 (15:57 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 6 Dec 2022 16:18:55 +0000 (17:18 +0100)
We need this construct to deal with cross-protocol queries, like
queries received over TCP or DoT, but forwarded over DoH, because
the thread dealing with the client and the one dealing with the
backend will not be the same in that case, and we do not want to
have different threads touching the same TCP connections.
So we pass the query and response to the correct thread via pipes.
Until now we were allocating an additional object, TCPCrossProtocolQuerySender,
to deal with that case, but I noticed that the existing IncomingTCPConnectionState
object already does everything we need, except that it needs to
know that the response is a cross-protocol one in order to pass it
via the pipe instead of treating it in a different way. This can be
done by looking if the current thread ID differs from the one that
created this object: if it does, we are dealing with a cross-protocol
response and should pass it via the pipe, and if it does not we
can deal with it directly.
This change saves the need to allocate a new object wrapped in a
shared pointer for each cross-protocol query, which is quite nice.

pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/dnsdist-tcp.hh

index c9f4120b7720c7f0ab04896e7b11e0d6f6d1c49d..d2c4ba79379705e1ca7efd78ad87abb82bd09dcb 100644 (file)
@@ -491,6 +491,11 @@ void IncomingTCPConnectionState::updateIO(std::shared_ptr<IncomingTCPConnectionS
 /* called from the backend code when a new response has been received */
 void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response)
 {
+  if (std::this_thread::get_id() != d_mainThreadID) {
+    handleCrossProtocolResponse(now, std::move(response));
+    return;
+  }
+
   std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
 
   if (response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
@@ -566,66 +571,11 @@ struct TCPCrossProtocolResponse
   struct timeval d_now;
 };
 
-class TCPCrossProtocolQuerySender : public TCPQuerySender
-{
-public:
-  TCPCrossProtocolQuerySender(std::shared_ptr<IncomingTCPConnectionState>& state): d_state(state)
-  {
-  }
-
-  bool active() const override
-  {
-    return d_state->active();
-  }
-
-  const ClientState* getClientState() const override
-  {
-    return d_state->getClientState();
-  }
-
-  void handleResponse(const struct timeval& now, TCPResponse&& response) override
-  {
-    if (d_state->d_threadData.crossProtocolResponsesPipe == -1) {
-      throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
-    }
-
-    auto ptr = new TCPCrossProtocolResponse(std::move(response), d_state, now);
-    static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
-    ssize_t sent = write(d_state->d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr));
-    if (sent != sizeof(ptr)) {
-      if (errno == EAGAIN || errno == EWOULDBLOCK) {
-        ++g_stats.tcpCrossProtocolResponsePipeFull;
-        vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
-      }
-      else {
-        vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
-      }
-      delete ptr;
-    }
-  }
-
-  void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
-  {
-    handleResponse(now, std::move(response));
-  }
-
-  void notifyIOError(IDState&& query, const struct timeval& now) override
-  {
-    TCPResponse response(PacketBuffer(), std::move(query), nullptr);
-    handleResponse(now, std::move(response));
-  }
-
-private:
-  std::shared_ptr<IncomingTCPConnectionState> d_state;
-};
-
 class TCPCrossProtocolQuery : public CrossProtocolQuery
 {
 public:
-  TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState>& ds, std::shared_ptr<TCPCrossProtocolQuerySender>& sender): d_sender(sender)
+  TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
   {
-    query = InternalQuery(std::move(buffer), std::move(ids));
-    downstream = ds;
     proxyProtocolPayloadSize = 0;
   }
 
@@ -639,9 +589,31 @@ public:
   }
 
 private:
-  std::shared_ptr<TCPCrossProtocolQuerySender> d_sender;
+  std::shared_ptr<IncomingTCPConnectionState> d_sender;
 };
 
+void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
+{
+  if (d_threadData.crossProtocolResponsesPipe == -1) {
+    throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
+  }
+
+  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
+  auto ptr = new TCPCrossProtocolResponse(std::move(response), state, now);
+  static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
+  ssize_t sent = write(d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr));
+  if (sent != sizeof(ptr)) {
+    if (errno == EAGAIN || errno == EWOULDBLOCK) {
+      ++g_stats.tcpCrossProtocolResponsePipeFull;
+      vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
+    }
+    else {
+      vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
+    }
+    delete ptr;
+  }
+}
+
 static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
 {
   if (state->d_querySize < sizeof(dnsheader)) {
@@ -784,8 +756,7 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, cons
       proxyProtocolPayload = getProxyProtocolPayload(dq);
     }
 
-    auto incoming = std::make_shared<TCPCrossProtocolQuerySender>(state);
-    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, incoming);
+    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state);
     cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
 
     ds->passCrossProtocolQuery(std::move(cpq));
index 6fa76ab6a1a2415f49a9c6260ffaa5c280146932..c128a56e813e1d66ba1b46af86907004b89a3ea2 100644 (file)
@@ -19,7 +19,7 @@ public:
 class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
 {
 public:
-  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData)
+  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_mainThreadID(std::this_thread::get_id())
   {
     d_origDest.reset();
     d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
@@ -125,6 +125,8 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
   void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override;
   void notifyIOError(IDState&& query, const struct timeval& now) override;
 
+  void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response);
+
   void terminateClientConnection();
   void queueQuery(TCPQuery&& query);
 
@@ -170,6 +172,7 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
   size_t d_proxyProtocolNeed{0};
   size_t d_queriesCount{0};
   size_t d_currentQueriesCount{0};
+  std::thread::id d_mainThreadID;
   uint16_t d_querySize{0};
   State d_state{State::doingHandshake};
   bool d_isXFR{false};
index b7b78582d17b1e3aa538c66090d3a5258c94a68c..6fb1b827a209aab1fd988da657b5f196d7b303aa 100644 (file)
@@ -159,6 +159,11 @@ struct CrossProtocolQuery
   {
   }
 
+  CrossProtocolQuery(InternalQuery&& query_, std::shared_ptr<DownstreamState>& downstream_) :
+    query(std::move(query_)), downstream(downstream_)
+  {
+  }
+
   CrossProtocolQuery(CrossProtocolQuery&& rhs) = delete;
   virtual ~CrossProtocolQuery()
   {