]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Actually try to read before checking if the socket is readable
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 3 Apr 2019 15:35:41 +0000 (17:35 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 4 Apr 2019 09:54:06 +0000 (11:54 +0200)
We need to because the TLS layer might already have data waiting
for us, while there might not be anything left on the OS-level
buffer associated to the socket.
If we don't ask the TLS layer, we might wait indefinitely for
something to arrive while the client has already sent everything,
and it's just waiting for us because the TLS record has been read.

pdns/dnsdist-tcp.cc

index ebb12c8c31850dc6deb9ce36a1139abfbeaf4de5..0e81173a7bc2f83a0a68572edf5701922bdd218f 100644 (file)
@@ -332,7 +332,7 @@ static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param
 class IncomingTCPConnectionState
 {
 public:
-  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, time_t now): d_buffer(4096), d_responseBuffer(4096), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now), d_connectionStartTime(now)
+  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(4096), d_responseBuffer(4096), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
   {
     d_ids.origDest.reset();
     d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
@@ -405,7 +405,7 @@ public:
     }
 
     if (g_maxTCPConnectionDuration > 0) {
-      auto elapsed = now.tv_sec - d_connectionStartTime;
+      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
       if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
         return now;
       }
@@ -452,7 +452,7 @@ public:
     }
 
     if (g_maxTCPConnectionDuration > 0) {
-      auto elapsed = res.tv_sec - d_connectionStartTime;
+      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
       if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
         return res;
       }
@@ -489,8 +489,8 @@ public:
     if (maxConnectionDuration) {
       time_t curtime = now.tv_sec;
       unsigned int elapsed = 0;
-      if (curtime > d_connectionStartTime) { // To prevent issues when time goes backward
-        elapsed = curtime - d_connectionStartTime;
+      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
+        elapsed = curtime - d_connectionStartTime.tv_sec;
       }
       if (elapsed >= maxConnectionDuration) {
         return true;
@@ -511,9 +511,9 @@ public:
   TCPIOHandler d_handler;
   std::unique_ptr<Socket> d_downstreamSocket{nullptr};
   std::shared_ptr<DownstreamState> d_ds{nullptr};
+  struct timeval d_connectionStartTime;
   size_t d_currentPos{0};
   size_t d_queriesCount{0};
-  time_t d_connectionStartTime;
   unsigned int d_remainingTime{0};
   uint16_t d_querySize{0};
   uint16_t d_responseSize{0};
@@ -530,6 +530,7 @@ public:
 
 static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
 static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
+static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
 
 static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state)
 {
@@ -540,6 +541,8 @@ static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& stat
     state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
     state->d_currentPos = 0;
     //cerr<<__func__<<": add read client FD "<<state->d_ci.fd<<endl;
+    // XXX: if we ever do TLS toward the backend, we need to try to read right away
+    // because the TLS layer might have more bits already waiting for us
     handleNewIOState(state, IOState::NeedRead, state->d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendReadTTD());
     return;
   }
@@ -557,8 +560,8 @@ static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& stat
   }
 
   state->resetForNewQuery();
-  //cerr<<__func__<<": add read client FD "<<state->d_ci.fd<<endl;
-  handleNewIOState(state, IOState::NeedRead, state->d_ci.fd, handleIOCallback, state->getClientReadTTD(now));
+
+  handleIO(state, now);
 }
 
 static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
@@ -814,6 +817,8 @@ static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param
 
   IOState iostate = IOState::Done;
   bool connectionDied = false;
+  struct timeval now;
+  gettimeofday(&now, 0);
 
   try {
     if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
@@ -921,17 +926,11 @@ static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param
   }
 }
 
-static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
 {
-  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
-  if (fd != state->d_ci.fd) {
-    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
-  }
-
+  int fd = state->d_ci.fd;
   IOState iostate = IOState::Done;
 
-  struct timeval now;
-  gettimeofday(&now, 0);
   if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
     vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
     handleNewIOState(state, IOState::Done, fd, handleIOCallback);
@@ -974,7 +973,7 @@ static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
     }
 
     if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
-      iostate = state->d_handler.tryWrite(state->d_buffer, state->d_currentPos, state->d_buffer.size());
+      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
       if (iostate == IOState::Done) {
         handleResponseSent(state);
         return;
@@ -1020,6 +1019,18 @@ static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
   }
 }
 
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (fd != state->d_ci.fd) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
+  }
+  struct timeval now;
+  gettimeofday(&now, 0);
+
+  handleIO(state, now);
+}
+
 static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
 {
   auto threadData = boost::any_cast<TCPClientThreadData*>(param);
@@ -1047,13 +1058,12 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param
 
   struct timeval now;
   gettimeofday(&now, 0);
-  auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now.tv_sec);
+  auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now);
 
   /* let's update the remaining time */
   state->d_remainingTime = g_maxTCPConnectionDuration;
 
-  /* we could try reading right away, but let's not for now */
-  handleNewIOState(state, IOState::NeedRead, state->d_ci.fd, handleIOCallback, state->getClientReadTTD(now));
+  handleIO(state, now);
 }
 
 void tcpClientThread(int pipefd)