git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: TCP refactoring using an event-based logic
author Remi Gacogne <remi.gacogne@powerdns.com>
Thu, 28 Feb 2019 14:39:40 +0000 (15:39 +0100)
committer Remi Gacogne <remi.gacogne@powerdns.com>
Thu, 4 Apr 2019 09:54:02 +0000 (11:54 +0200)
12 files changed:
pdns/dnsdist-ecs.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-idstate.cc [new file with mode: 0644]
pdns/dnsdistdist/tcpiohandler.cc
pdns/iputils.cc
pdns/iputils.hh
pdns/mplexer.hh
pdns/sstuff.hh
pdns/tcpiohandler.hh

index 9390a55df858d1164d0e832a3066e803c379da2e..5e8974d6983c9d9fdcd47aac8c78ee1b1273cba5 100644 (file)
@@ -257,10 +257,10 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload
   dh.d_class = htons(udpPayloadSize);
   static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
   memcpy(&dh.d_ttl, &edns0, sizeof edns0);
-  dh.d_clen = htons((uint16_t) optRData.length());
+  dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
   res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
-  res.assign((const char *) &name, sizeof name);
-  res.append((const char *) &dh, sizeof dh);
+  res.assign(reinterpret_cast<const char *>(&name), sizeof name);
+  res.append(reinterpret_cast<const char *>(&dh), sizeof(dh));
   res.append(optRData.c_str(), optRData.length());
 }
 
index bf66072fff965425ed492e24bf444722f041820c..b8e8d591d6190ce58282e35522a64866b6d76a7f 100644 (file)
@@ -35,6 +35,8 @@
 #include <atomic>
 #include <netinet/tcp.h>
 
+#include "sstuff.hh"
+
 using std::thread;
 using std::atomic;
 
@@ -53,42 +55,89 @@ using std::atomic;
    Let's start naively.
 */
 
-static int setupTCPDownstream(shared_ptr<DownstreamState> ds, uint16_t& downstreamFailures)
+static thread_local map<ComboAddress, std::deque<std::unique_ptr<Socket>>> t_downstreamSockets;
+static std::mutex tcpClientsCountMutex;
+static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
+uint64_t g_maxTCPQueuedConnections{1000};
+size_t g_maxTCPQueriesPerConn{0};
+size_t g_maxTCPConnectionDuration{0};
+size_t g_maxTCPConnectionsPerClient{0};
+bool g_useTCPSinglePipe{false};
+std::atomic<uint16_t> g_downstreamTCPCleanupInterval{60};
+
+static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState> ds, uint16_t& downstreamFailures, int timeout)
 {
+  std::unique_ptr<Socket> result;
+
   do {
     vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
-    int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
+    result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
     try {
       if (!IsAnyAddress(ds->sourceAddr)) {
-        SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
+        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
 #ifdef IP_BIND_ADDRESS_NO_PORT
         if (ds->ipBindAddrNoPort) {
-          SSetsockopt(sock, SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
         }
 #endif
-        SBind(sock, ds->sourceAddr);
+        result->bind(ds->sourceAddr, false);
       }
-      setNonBlocking(sock);
+      result->setNonBlocking();
 #ifdef MSG_FASTOPEN
       if (!ds->tcpFastOpen) {
-        SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout);
+        SConnectWithTimeout(result->getHandle(), ds->remote, timeout);
       }
 #else
-      SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout);
+      SConnectWithTimeout(result->getHandle(), ds->remote, timeout);
 #endif /* MSG_FASTOPEN */
-      return sock;
+      return result;
     }
     catch(const std::runtime_error& e) {
-      /* don't leak our file descriptor if SConnect() (for example) throws */
+      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
       downstreamFailures++;
-      close(sock);
       if (downstreamFailures > ds->retries) {
         throw;
       }
     }
   } while(downstreamFailures <= ds->retries);
 
-  return -1;
+  return nullptr;
+}
+
+static std::unique_ptr<Socket> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, bool& isFresh)
+{
+  std::unique_ptr<Socket> result;
+
+  const auto& it = t_downstreamSockets.find(ds->remote);
+  if (it != t_downstreamSockets.end()) {
+    auto& list = it->second;
+    if (!list.empty()) {
+      result = std::move(list.front());
+      list.pop_front();
+      isFresh = false;
+      return result;
+    }
+  }
+
+  isFresh = true;
+  return setupTCPDownstream(ds, downstreamFailures, 0);
+}
+
+static void releaseDownstreamConnection(std::shared_ptr<DownstreamState>& ds, std::unique_ptr<Socket>&& socket)
+{
+  const auto& it = t_downstreamSockets.find(ds->remote);
+  if (it != t_downstreamSockets.end()) {
+    auto& list = it->second;
+    if (list.size() >= 20) {
+      /* too many connections queued already */
+      socket.reset();
+      return;
+    }
+    list.push_back(std::move(socket));
+  }
+  else {
+    t_downstreamSockets[ds->remote].push_back(std::move(socket));
+  }
 }
 
 struct ConnectionInfo
@@ -96,6 +145,14 @@ struct ConnectionInfo
   ConnectionInfo(): cs(nullptr), fd(-1)
   {
   }
+  ConnectionInfo(ConnectionInfo&& rhs)
+  {
+    remote = rhs.remote;
+    cs = rhs.cs;
+    rhs.cs = nullptr;
+    fd = rhs.fd;
+    rhs.fd = -1;
+  }
 
   ConnectionInfo(const ConnectionInfo& rhs) = delete;
   ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
@@ -123,15 +180,6 @@ struct ConnectionInfo
   int fd{-1};
 };
 
-uint64_t g_maxTCPQueuedConnections{1000};
-size_t g_maxTCPQueriesPerConn{0};
-size_t g_maxTCPConnectionDuration{0};
-size_t g_maxTCPConnectionsPerClient{0};
-static std::mutex tcpClientsCountMutex;
-static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
-bool g_useTCPSinglePipe{false};
-std::atomic<uint16_t> g_downstreamTCPCleanupInterval{60};
-
 void tcpClientThread(int pipefd);
 
 static void decrementTCPClientCount(const ComboAddress& client)
@@ -201,392 +249,814 @@ void TCPClientCollection::addTCPClientThread()
   ++d_numthreads;
 }
 
-static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
-try
+static void cleanupClosedTCPConnections()
 {
-  uint16_t raw;
-  size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
-  if(ret != sizeof raw)
-    return false;
-  *len = ntohs(raw);
-  return true;
-}
-catch(...) {
-  return false;
-}
+  for(auto dsIt = t_downstreamSockets.begin(); dsIt != t_downstreamSockets.end(); ) {
+    for (auto socketIt = dsIt->second.begin(); socketIt != dsIt->second.end(); ) {
+      if (*socketIt && isTCPSocketUsable((*socketIt)->getHandle())) {
+        ++socketIt;
+      }
+      else {
+        socketIt = dsIt->second.erase(socketIt);
+      }
+    }
 
-static bool getNonBlockingMsgLenFromClient(TCPIOHandler& handler, uint16_t* len)
-try
-{
-  uint16_t raw;
-  size_t ret = handler.read(&raw, sizeof raw, g_tcpRecvTimeout);
-  if(ret != sizeof raw)
-    return false;
-  *len = ntohs(raw);
-  return true;
-}
-catch(...) {
-  return false;
+    if (!dsIt->second.empty()) {
+      ++dsIt;
+    }
+    else {
+      dsIt = t_downstreamSockets.erase(dsIt);
+    }
+  }
 }
 
-static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
+/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
+   Updates pos every time a successful read occurs,
+   throws an std::runtime_error in case of IO error,
+   returns Done when toRead bytes have been read, NeedRead or NeedWrite if the IO operation
+   would block.
+*/
+// XXX could probably be implemented as a TCPIOHandler
+IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
 {
-  if (maxConnectionDuration) {
-    time_t curtime = time(nullptr);
-    unsigned int elapsed = 0;
-    if (curtime > start) { // To prevent issues when time goes backward
-      elapsed = curtime - start;
+  size_t got = 0;
+  do {
+    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
+    if (res == 0) {
+      throw runtime_error("EOF while reading message");
     }
-    if (elapsed >= maxConnectionDuration) {
-      return true;
+    if (res < 0) {
+      if (errno == EAGAIN || errno == EWOULDBLOCK) {
+        return IOState::NeedRead;
+      }
+      else {
+        throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
+      }
     }
-    remainingTime = maxConnectionDuration - elapsed;
+
+    pos += static_cast<size_t>(res);
+    got += static_cast<size_t>(res);
   }
-  return false;
+  while (got < toRead);
+
+  return IOState::Done;
 }
 
-static void cleanupClosedTCPConnections(std::map<ComboAddress,int>& sockets)
+std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
+
+class TCPClientThreadData
+{
+public:
+  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+  {
+  }
+
+  LocalHolders holders;
+  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
+  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+};
+
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+
+class IncomingTCPConnectionState
 {
-  for(auto it = sockets.begin(); it != sockets.end(); ) {
-    if (isTCPSocketUsable(it->second)) {
-      ++it;
+public:
+  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, time_t now): d_buffer(4096), d_responseBuffer(4096), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now), d_connectionStartTime(now)
+  {
+    d_ids.origDest.reset();
+    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
+    socklen_t socklen = d_ids.origDest.getSocklen();
+    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
+      d_ids.origDest = d_ci.cs->local;
+    }
+  }
+
+  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
+  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
+
+  ~IncomingTCPConnectionState()
+  {
+    decrementTCPClientCount(d_ci.remote);
+
+    if (d_ds != nullptr) {
+      if (d_outstanding) {
+        --d_ds->outstanding;
+      }
+
+      if (d_downstreamSocket) {
+        try {
+          if (d_lastIOState == IOState::NeedRead) {
+            cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamSocket->getHandle()<<endl;
+            d_threadData.mplexer->removeReadFD(d_downstreamSocket->getHandle());
+          }
+          else if (d_lastIOState == IOState::NeedWrite) {
+            cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamSocket->getHandle()<<endl;
+            d_threadData.mplexer->removeWriteFD(d_downstreamSocket->getHandle());
+          }
+        }
+        catch(const FDMultiplexerException& e) {
+          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
+        }
+      }
+    }
+
+    try {
+      if (d_lastIOState == IOState::NeedRead) {
+        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
+        d_threadData.mplexer->removeReadFD(d_ci.fd);
+      }
+      else if (d_lastIOState == IOState::NeedWrite) {
+        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
+        d_threadData.mplexer->removeWriteFD(d_ci.fd);
+      }
+    }
+    catch(const FDMultiplexerException& e) {
+      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+    }
+  }
+
+  void resetForNewQuery()
+  {
+    d_buffer.resize(sizeof(uint16_t));
+    d_currentPos = 0;
+    d_querySize = 0;
+    d_responseSize = 0;
+    d_downstreamFailures = 0;
+    d_state = State::readingQuerySize;
+    d_lastIOState = IOState::Done;
+  }
+
+  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
+  {
+    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
+      return boost::none;
+    }
+
+    if (g_maxTCPConnectionDuration > 0) {
+      auto elapsed = now.tv_sec - d_connectionStartTime;
+      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
+        return now;
+      }
+      auto remaining = g_maxTCPConnectionDuration - elapsed;
+      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
+        now.tv_sec += remaining;
+        return now;
+      }
+    }
+
+    now.tv_sec += g_tcpRecvTimeout;
+    return now;
+  }
+
+  boost::optional<struct timeval> getBackendReadTTD() const
+  {
+    if (d_ds == nullptr) {
+      throw std::runtime_error("getBackendReadTTD() without any backend selected");
+    }
+    if (d_ds->tcpRecvTimeout == 0) {
+      return boost::none;
+    }
+
+    struct timeval res;
+    gettimeofday(&res, 0);
+
+    res.tv_sec += d_ds->tcpRecvTimeout;
+
+    return res;
+  }
+
+  boost::optional<struct timeval> getClientWriteTTD(boost::optional<struct timeval> now=boost::none) const
+  {
+    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
+      return boost::none;
+    }
+
+    struct timeval res;
+    if (now) {
+      res = *now;
     }
     else {
-      close(it->second);
-      it = sockets.erase(it);
+      gettimeofday(&res, 0);
+    }
+
+    if (g_maxTCPConnectionDuration > 0) {
+      auto elapsed = res.tv_sec - d_connectionStartTime;
+      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
+        return res;
+      }
+      auto remaining = g_maxTCPConnectionDuration - elapsed;
+      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
+        res.tv_sec += remaining;
+        return res;
+      }
+    }
+
+    res.tv_sec += g_tcpSendTimeout;
+    return res;
+  }
+
+  boost::optional<struct timeval> getBackendWriteTTD() const
+  {
+    if (d_ds == nullptr) {
+      throw std::runtime_error("getBackendReadTTD() called without any backend selected");
+    }
+    if (d_ds->tcpSendTimeout == 0) {
+      return boost::none;
     }
+
+    struct timeval res;
+    gettimeofday(&res, 0);
+
+    res.tv_sec += d_ds->tcpSendTimeout;
+
+    return res;
   }
+
+  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval now)
+  {
+    if (maxConnectionDuration) {
+      time_t curtime = now.tv_sec;
+      unsigned int elapsed = 0;
+      if (curtime > d_connectionStartTime) { // To prevent issues when time goes backward
+        elapsed = curtime - d_connectionStartTime;
+      }
+      if (elapsed >= maxConnectionDuration) {
+        return true;
+      }
+      d_remainingTime = maxConnectionDuration - elapsed;
+    }
+
+    return false;
+  }
+
+  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };
+
+  std::vector<uint8_t> d_buffer;
+  std::vector<uint8_t> d_responseBuffer;
+  TCPClientThreadData& d_threadData;
+  IDState d_ids;
+  ConnectionInfo d_ci;
+  TCPIOHandler d_handler;
+  std::unique_ptr<Socket> d_downstreamSocket{nullptr};
+  std::shared_ptr<DownstreamState> d_ds{nullptr};
+  size_t d_currentPos{0};
+  size_t d_queriesCount{0};
+  time_t d_connectionStartTime;
+  unsigned int d_remainingTime{0};
+  uint16_t d_querySize{0};
+  uint16_t d_responseSize{0};
+  uint16_t d_downstreamFailures{0};
+  State d_state{State::doingHandshake};
+  IOState d_lastIOState{IOState::Done};
+  bool d_freshDownstreamConnection{false};
+  bool d_readingFirstQuery{true};
+  bool d_outstanding{false};
+  bool d_firstResponsePacket{true};
+  bool d_isXFR{false};
+  bool d_xfrStarted{false};
+};
+
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
+
+static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state)
+{
+  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);
+
+  if (state->d_isXFR && state->d_downstreamSocket) {
+    /* we need to resume reading from the backend! */
+    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+    state->d_currentPos = 0;
+    //cerr<<__func__<<": add read client FD "<<state->d_ci.fd<<endl;
+    handleNewIOState(state, IOState::NeedRead, state->d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendReadTTD());
+    return;
+  }
+
+  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
+    return;
+  }
+
+  struct timeval now;
+  gettimeofday(&now, 0);
+  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+    return;
+  }
+
+  state->resetForNewQuery();
+  //cerr<<__func__<<": add read client FD "<<state->d_ci.fd<<endl;
+  handleNewIOState(state, IOState::NeedRead, state->d_ci.fd, handleIOCallback, state->getClientReadTTD(now));
 }
 
-std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
+static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
+{
+  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
+  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
+  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+     that could occur if we had to deal with the size during the processing,
+     especially alignment issues */
+  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);
 
-void tcpClientThread(int pipefd)
+  state->d_currentPos = 0;
+
+  auto iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
+  if (iostate == IOState::Done) {
+
+    handleResponseSent(state);
+    return;
+  }
+  else {
+    //cerr<<__func__<<": adding client write FD "<<state->d_ci.fd<<endl;
+    handleNewIOState(state, IOState::NeedWrite, state->d_ci.fd, handleIOCallback, state->getClientWriteTTD());
+  }
+}
+
+static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
 {
-  /* we get launched with a pipe on which we receive file descriptors from clients that we own
-     from that point on */
+  if (state->d_responseSize < sizeof(dnsheader)) {
+    return;
+  }
 
-  setThreadName("dnsdist/tcpClie");
+  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
+  unsigned int consumed;
+  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
+    return;
+  }
+  state->d_firstResponsePacket = false;
+
+  if (state->d_outstanding) {
+    --state->d_ds->outstanding;
+    state->d_outstanding = false;
+  }
+
+  auto dh = reinterpret_cast<struct dnsheader*>(response);
+  uint16_t addRoom = 0;
+  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
+  if (dr.dnsCryptQuery) {
+    addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+  }
 
-  bool outstanding = false;
-  time_t lastTCPCleanup = time(nullptr);
-  
-  LocalHolders holders;
-  auto localRespRulactions = g_resprulactions.getLocal();
-  /* when the answer is encrypted in place, we need to get a copy
-     of the original header before encryption to fill the ring buffer */
   dnsheader cleartextDH;
+  memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
 
-  map<ComboAddress,int> sockets;
-  for(;;) {
-    ConnectionInfo* citmp, ci;
+  std::vector<uint8_t> rewrittenResponse;
+  size_t responseSize = state->d_responseBuffer.size();
+  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
+    return;
+  }
 
-    try {
-      readn2(pipefd, &citmp, sizeof(citmp));
-    }
-    catch(const std::runtime_error& e) {
-      throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
-    }
+  if (!rewrittenResponse.empty()) {
+    /* responseSize has been updated as well but we don't really care since it will match
+       the capacity of rewrittenResponse anyway */
+    state->d_responseBuffer = std::move(rewrittenResponse);
+    state->d_responseSize = state->d_responseBuffer.size();
+  } else {
+    /* the size might have been updated (shrunk), if we removed the whole OPT RR for example */
+    state->d_responseBuffer.resize(state->d_responseSize);
+  }
+
+  if (state->d_isXFR && !state->d_xfrStarted) {
+    /* don't bother parsing the content of the response for now */
+    state->d_xfrStarted = true;
+  }
 
-    g_tcpclientthreads->decrementQueuedCount();
-    ci=std::move(*citmp);
-    delete citmp;
+  sendResponse(state);
+
+  ++g_stats.responses;
+  struct timespec answertime;
+  gettime(&answertime);
+  double udiff = state->d_ids.sentTime.udiff();
+  g_rings.insertResponse(answertime, state->d_ci.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), cleartextDH, state->d_ds->remote);
+}
+
+static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state)
+{
+  auto ds = state->d_ds;
+  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
+  state->d_currentPos = 0;
+  state->d_firstResponsePacket = true;
+  state->d_downstreamSocket.reset();
+
+  if (state->d_xfrStarted) {
+    /* sorry, but we are not going to resume an XFR if we have already sent some packets
+       to the client */
+    return;
+  }
 
-    uint16_t qlen, rlen;
-    vector<uint8_t> rewrittenResponse;
-    shared_ptr<DownstreamState> ds;
-    size_t queriesCount = 0;
-    time_t connectionStartTime = time(nullptr);
-    std::vector<char> queryBuffer;
-    std::vector<char> answerBuffer;
+  while (state->d_downstreamFailures < state->d_ds->retries)
+  {
+    state->d_downstreamSocket = getConnectionToDownstream(ds, state->d_downstreamFailures, state->d_freshDownstreamConnection);
 
-    ComboAddress dest;
-    dest.reset();
-    dest.sin4.sin_family = ci.remote.sin4.sin_family;
-    socklen_t socklen = dest.getSocklen();
-    if (getsockname(ci.fd, (sockaddr*)&dest, &socklen)) {
-      dest = ci.cs->local;
+    if (!state->d_downstreamSocket) {
+      vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
+      return;
     }
 
-    try {
-      TCPIOHandler handler(ci.fd, g_tcpRecvTimeout, ci.cs->tlsFrontend ? ci.cs->tlsFrontend->getContext() : nullptr, connectionStartTime);
+    //cerr<<__func__<<": add write backend FD "<<state->d_downstreamSocket->getHandle()<<endl;
+    handleNewIOState(state, IOState::NeedWrite, state->d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendWriteTTD());
+    return;
+  }
 
-      for(;;) {
-        unsigned int remainingTime = 0;
-        ds = nullptr;
-        outstanding = false;
+  vinfolog("Downstream connection to %s failed %u times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
+}
 
-        if(!getNonBlockingMsgLenFromClient(handler, &qlen)) {
-          break;
-        }
+static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state)
+{
+  if (state->d_querySize < sizeof(dnsheader)) {
+    ++g_stats.nonCompliantQueries;
+    return;
+  }
 
-        queriesCount++;
+  state->d_readingFirstQuery = false;
+  ++state->d_queriesCount;
+  ++state->d_ci.cs->queries;
+  ++g_stats.queries;
+
+  /* we need an accurate ("real") value for the response and
+     to store into the IDS, but not for insertion into the
+     rings for example */
+  struct timespec now;
+  struct timespec queryRealTime;
+  gettime(&now);
+  gettime(&queryRealTime, true);
+
+  auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
+  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
+  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
+  if (dnsCryptResponse) {
+    state->d_responseBuffer = std::move(*dnsCryptResponse);
+    state->d_responseSize = state->d_responseBuffer.size();
+    sendResponse(state);
+    return;
+  }
 
-        if (qlen < sizeof(dnsheader)) {
-          ++g_stats.nonCompliantQueries;
-          break;
-        }
+  const auto& dh = reinterpret_cast<dnsheader*>(query);
+  if (!checkQueryHeaders(dh)) {
+    return;
+  }
 
-        ci.cs->queries++;
-        ++g_stats.queries;
+  uint16_t qtype, qclass;
+  unsigned int consumed = 0;
+  DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
+  dq.dnsCryptQuery = std::move(dnsCryptQuery);
 
-        if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
-          vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
-          break;
-        }
+  state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
+  if (state->d_isXFR) {
+    dq.skipCache = true;
+  }
 
-        if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) {
-          vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort());
-          break;
-        }
+  state->d_ds.reset();
+  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);
 
-        /* allocate a bit more memory to be able to spoof the content,
-           or to add ECS without allocating a new buffer */
-        queryBuffer.resize((static_cast<size_t>(qlen) + 512) < 4096 ? (static_cast<size_t>(qlen) + 512) : 4096);
-
-        char* query = &queryBuffer[0];
-        handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
-
-        /* we need an accurate ("real") value for the response and
-           to store into the IDS, but not for insertion into the
-           rings for example */
-        struct timespec now;
-        struct timespec queryRealTime;
-        gettime(&now);
-        gettime(&queryRealTime, true);
-
-        std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
-        auto dnsCryptResponse = checkDNSCryptQuery(*ci.cs, query, qlen, dnsCryptQuery, queryRealTime.tv_sec, true);
-        if (dnsCryptResponse) {
-          handler.writeSizeAndMsg(reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), g_tcpSendTimeout);
-          continue;
-        }
+  if (result == ProcessQueryResult::Drop) {
+    return;
+  }
 
-        struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
-        if (!checkQueryHeaders(dh)) {
-          break;
-        }
+  if (result == ProcessQueryResult::SendAnswer) {
+    state->d_buffer.resize(dq.len);
+    state->d_responseBuffer = std::move(state->d_buffer);
+    state->d_responseSize = state->d_responseBuffer.size();
+    sendResponse(state);
+    return;
+  }
 
-        uint16_t qtype, qclass;
-        unsigned int consumed = 0;
-        DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-        DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
-        dq.dnsCryptQuery = std::move(dnsCryptQuery);
+  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
+    return;
+  }
 
-        std::shared_ptr<DownstreamState> ds{nullptr};
-        auto result = processQuery(dq, *ci.cs, holders, ds);
+  state->d_buffer.resize(dq.len);
+  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));
 
-        if (result == ProcessQueryResult::Drop) {
-          break;
-        }
+  const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
+  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+     that could occur if we had to deal with the size during the processing,
+     especially alignment issues */
+  state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
+  sendQueryToBackend(state);
+}
 
-        if (result == ProcessQueryResult::SendAnswer) {
-          handler.writeSizeAndMsg(reinterpret_cast<char*>(dq.dh), dq.len, g_tcpSendTimeout);
-          continue;
-        }
+static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
+{
+  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;
 
-        if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
-          break;
-        }
+  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
+    state->d_threadData.mplexer->removeReadFD(fd);
+    //cerr<<__func__<<": remove read FD "<<fd<<endl;
+    state->d_lastIOState = IOState::Done;
+  }
+  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
+    state->d_threadData.mplexer->removeWriteFD(fd);
+    //cerr<<__func__<<": remove write FD "<<fd<<endl;
+    state->d_lastIOState = IOState::Done;
+  }
 
-       int dsock = -1;
-       uint16_t downstreamFailures=0;
-#ifdef MSG_FASTOPEN
-       bool freshConn = true;
-#endif /* MSG_FASTOPEN */
-       if(sockets.count(ds->remote) == 0) {
-         dsock=setupTCPDownstream(ds, downstreamFailures);
-         sockets[ds->remote]=dsock;
-       }
-       else {
-         dsock=sockets[ds->remote];
-#ifdef MSG_FASTOPEN
-         freshConn = false;
-#endif /* MSG_FASTOPEN */
-        }
+  if (iostate == IOState::NeedRead) {
+    if (state->d_lastIOState == IOState::NeedRead) {
+      if (ttd) {
+        /* let's update the TTD ! */
+        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
+      }
+      return;
+    }
 
-        ds->outstanding++;
-        outstanding = true;
+    state->d_lastIOState = IOState::NeedRead;
+    //cerr<<__func__<<": add read FD "<<fd<<endl;
+    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
+  }
+  else if (iostate == IOState::NeedWrite) {
+    if (state->d_lastIOState == IOState::NeedWrite) {
+      return;
+    }
 
-      retry:; 
-        if (dsock < 0) {
-          sockets.erase(ds->remote);
-          break;
-        }
+    state->d_lastIOState = IOState::NeedWrite;
+    //cerr<<__func__<<": add write FD "<<fd<<endl;
+    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
+  }
+  else if (iostate == IOState::Done) {
+    state->d_lastIOState = IOState::Done;
+  }
+}
 
-        if (ds->retries > 0 && downstreamFailures > ds->retries) {
-          vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstreamFailures);
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          break;
-        }
+/* Multiplexer callback for the backend (downstream) socket of a client query.
+
+   Depending on state->d_state it:
+   - sends the (remaining part of the) query held in d_buffer to the backend,
+   - reads the two-byte response length into d_responseBuffer,
+   - reads the response itself and passes the state to handleResponse().
+
+   Whenever an operation cannot complete right away, the descriptor is
+   (re-)registered via handleNewIOState() with the backend read or write TTD.
+   Any std::exception raised while talking to the backend bumps
+   d_downstreamFailures, unregisters the descriptor and calls
+   sendQueryToBackend() again.
+
+   Throws std::runtime_error if there is no downstream socket or if fd is not
+   the downstream socket's descriptor. */
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (state->d_downstreamSocket == nullptr) {
+    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+  }
+  if (fd != state->d_downstreamSocket->getHandle()) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamSocket->getHandle()));
+  }
 
-        try {
-          int socketFlags = 0;
-#ifdef MSG_FASTOPEN
-          if (ds->tcpFastOpen && freshConn) {
-            socketFlags |= MSG_FASTOPEN;
-          }
-#endif /* MSG_FASTOPEN */
-          sendSizeAndMsgWithTimeout(dsock, dq.len, query, ds->tcpSendTimeout, &ds->remote, &ds->sourceAddr, ds->sourceItf, 0, socketFlags);
-        }
-        catch(const runtime_error& e) {
-          vinfolog("Downstream connection to %s died on us (%s), getting a new one!", ds->getName(), e.what());
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          downstreamFailures++;
-          dsock=setupTCPDownstream(ds, downstreamFailures);
-          sockets[ds->remote]=dsock;
-#ifdef MSG_FASTOPEN
-          freshConn=true;
-#endif /* MSG_FASTOPEN */
-          goto retry;
-        }
+  IOState iostate = IOState::Done;
+  bool connectionDied = false;
 
-        bool xfrStarted = false;
-        bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
-        if (isXFR) {
-          dq.skipCache = true;
-        }
-        bool firstPacket=true;
-      getpacket:;
-
-        if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
-         vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-          downstreamFailures++;
-          dsock=setupTCPDownstream(ds, downstreamFailures);
-          sockets[ds->remote]=dsock;
+  try {
+    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
+      int socketFlags = 0;
 #ifdef MSG_FASTOPEN
-          freshConn=true;
+      if (state->d_ds->tcpFastOpen && state->d_freshDownstreamConnection) {
+        socketFlags |= MSG_FASTOPEN;
+      }
 #endif /* MSG_FASTOPEN */
-          if(xfrStarted) {
-            break;
-          }
-          goto retry;
-        }
 
-        size_t responseSize = rlen;
-        uint16_t addRoom = 0;
-        if (dq.dnsCryptQuery && (UINT16_MAX - rlen) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) {
-          addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+      /* only the part of the buffer that has not been sent yet is handed over */
+      const size_t remaining = state->d_buffer.size() - state->d_currentPos;
+      size_t sent = sendMsgWithTimeout(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), remaining, 0, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, 0, socketFlags);
+      /* compare against what was left to send, not against the full buffer size:
+         after a partial write d_currentPos is no longer 0 */
+      if (sent == remaining) {
+        /* request sent ! */
+        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+        state->d_currentPos = 0;
+        iostate = IOState::NeedRead;
+        if (!state->d_isXFR) {
+          /* don't bother with the outstanding count for XFR queries */
+          ++state->d_ds->outstanding;
+          state->d_outstanding = true;
         }
+      }
+      else {
+        state->d_currentPos += sent;
+        iostate = IOState::NeedWrite;
+        /* disable fast open on partial write */
+        state->d_freshDownstreamConnection = false;
+      }
+    }
 
-        responseSize += addRoom;
-        answerBuffer.resize(responseSize);
-        char* response = answerBuffer.data();
-        readn2WithTimeout(dsock, response, rlen, ds->tcpRecvTimeout);
-        uint16_t responseLen = rlen;
-        if (outstanding) {
-          /* might be false for {A,I}XFR */
-          --ds->outstanding;
-          outstanding = false;
-        }
+    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
+      // The response goes into a separate buffer (d_responseBuffer) because we might
+      // need to re-send the query, still held in d_buffer, if the backend dies on us.
+      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
+      // should very likely be a TCPIOHandler d_downstreamHandler
+      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
+        /* the length prefix is two bytes, most significant byte first */
+        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
+        /* for DNSCrypt queries, reserve extra room for padding and MAC when it fits in 16 bits */
+        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
+        state->d_currentPos = 0;
+      }
+    }
 
-        if (rlen < sizeof(dnsheader)) {
-          break;
-        }
+    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
+      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);
 
-        consumed = 0;
-        if (firstPacket && !responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote, consumed)) {
-          break;
+        if (state->d_isXFR) {
+          /* Don't reuse the TCP connection after an {A,I}XFR */
+          /* but don't reset it either, we will need to read more messages */
         }
-        firstPacket=false;
-
-        dh = reinterpret_cast<struct dnsheader*>(response);
-        DNSResponse dr(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
-        dr.origFlags = dq.origFlags;
-        dr.ecsAdded = dq.ecsAdded;
-        dr.ednsAdded = dq.ednsAdded;
-        dr.useZeroScope = dq.useZeroScope;
-        dr.packetCache = std::move(dq.packetCache);
-        dr.delayMsec = dq.delayMsec;
-        dr.skipCache = dq.skipCache;
-        dr.cacheKey = dq.cacheKey;
-        dr.cacheKeyNoECS = dq.cacheKeyNoECS;
-        dr.dnssecOK = dq.dnssecOK;
-        dr.tempFailureTTL = dq.tempFailureTTL;
-        dr.qTag = std::move(dq.qTag);
-        dr.subnet = std::move(dq.subnet);
-#ifdef HAVE_PROTOBUF
-        dr.uniqueId = std::move(dq.uniqueId);
-#endif
-        if (dq.dnsCryptQuery) {
-          dr.dnsCryptQuery = std::move(dq.dnsCryptQuery);
+        else {
+          releaseDownstreamConnection(state->d_ds, std::move(state->d_downstreamSocket));
         }
+        fd = -1;
 
-        memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
-        if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
-          break;
-        }
+        handleResponse(state);
+        return;
+      }
+    }
 
-        if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) {
-          break;
-        }
+    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
+        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
+        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
+      vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
+    }
+  }
+  catch(const std::exception& e) {
+    /* most likely an EOF because the other end closed the connection,
+       but it might also be a real IO error or something else.
+       Let's just drop the connection
+    */
+    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
+    ++state->d_downstreamFailures;
+    if (state->d_outstanding && state->d_ds != nullptr) {
+      --state->d_ds->outstanding;
+      /* the query is no longer accounted for: clear the flag so that another
+         failure before it is sent again does not decrement the counter twice */
+      state->d_outstanding = false;
+    }
+    /* remove this FD from the IO multiplexer (done by handleNewIOState() below) */
+    iostate = IOState::Done;
+    connectionDied = true;
+  }
 
-        if (isXFR) {
-          if (dh->rcode == 0 && dh->ancount != 0) {
-            if (xfrStarted == false) {
-              xfrStarted = true;
-              if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
-                goto getpacket;
-              }
-            }
-            else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
-              goto getpacket;
-            }
-          }
-          /* Don't reuse the TCP connection after an {A,I}XFR */
-          close(dsock);
-          dsock=-1;
-          sockets.erase(ds->remote);
-        }
+  if (iostate == IOState::Done) {
+    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
+  }
+  else {
+    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD() : state->getBackendWriteTTD());
+  }
+
+  if (connectionDied) {
+    /* try again, the failure counter has been incremented above */
+    sendQueryToBackend(state);
+  }
+}
+
+/* Multiplexer callback for the client side of an incoming TCP connection.
+
+   Driven by state->d_state, it performs in turn the handshake, the read of
+   the two-byte query length, the read of the query (then handleQuery()) and
+   the write of the response (then handleResponseSent()). When an operation
+   cannot complete, the descriptor is (re-)registered via handleNewIOState()
+   with the client read or write TTD. Any std::exception ends up with the
+   descriptor being unregistered (IOState::Done).
+
+   Throws std::runtime_error if fd is not the client descriptor of the state. */
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (fd != state->d_ci.fd) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
+  }
+
+  IOState iostate = IOState::Done;
+
+  struct timeval now;
+  gettimeofday(&now, 0);
+  /* enforce the maximum connection duration before doing any IO */
+  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+    handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+    return;
+  }
 
-        ++g_stats.responses;
-        switch (dr.dh->rcode) {
-        case RCode::NXDomain:
-           ++g_stats.frontendNXDomain;
-           break;
-        case RCode::ServFail:
-          ++g_stats.frontendServFail;
-          break;
-        case RCode::NoError:
-          ++g_stats.frontendNoError;
-          break;
+  try {
+    /* the steps below are plain 'if's, not 'else if's: when one step completes,
+       the next one is attempted right away in the same call */
+    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
+      iostate = state->d_handler.tryHandshake();
+      if (iostate == IOState::Done) {
+        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+      }
+    }
+
+    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
+      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
+      if (iostate == IOState::Done) {
+        state->d_state = IncomingTCPConnectionState::State::readingQuery;
+        /* the length prefix is two bytes, most significant byte first */
+        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
+        if (state->d_querySize < sizeof(dnsheader)) {
+          /* go away */
+          handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+          return;
         }
-        struct timespec answertime;
-        gettime(&answertime);
-        unsigned int udiff = 1000000.0*DiffTime(now,answertime);
-        g_rings.insertResponse(answertime, ci.remote, qname, dq.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(responseLen), cleartextDH, ds->remote);
 
-        rewrittenResponse.clear();
+        /* allocate a bit more memory to be able to spoof the content,
+           or to add ECS without allocating a new buffer */
+        state->d_buffer.resize(state->d_querySize + 512);
+        state->d_currentPos = 0;
+      }
+    }
+
+    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
+      if (iostate == IOState::Done) {
+        /* full query received: stop watching the client socket, then process it */
+        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+        handleQuery(state);
+        return;
       }
     }
-    catch(const std::exception& e) {
-      vinfolog("Got exception while handling TCP query: %s", e.what());
+
+    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+      iostate = state->d_handler.tryWrite(state->d_buffer, state->d_currentPos, state->d_buffer.size());
+      if (iostate == IOState::Done) {
+        handleResponseSent(state);
+        return;
+      }
+    }
+
+    /* none of the states handled above: just log it */
+    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
+        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
+        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
+        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
+      vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
     }
-    catch(...) {
+  }
+  catch(const std::exception& e) {
+    /* most likely an EOF because the other end closed the connection,
+       but it might also be a real IO error or something else.
+       Let's just drop the connection
+    */
+    /* the exception is only reported as such when we were writing or still
+       reading the first query; otherwise it is logged as a plain close */
+    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
+      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
     }
+    else {
+      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
+    }
+    /* remove this FD from the IO multiplexer */
+    iostate = IOState::Done;
+  }
+
+  if (iostate == IOState::Done) {
+    handleNewIOState(state, iostate, fd, handleIOCallback);
+  }
+  else {
+    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
+  }
+}
+
+/* Multiplexer callback for the pipe on which new client connections are
+   handed to this thread. One ConnectionInfo pointer is read from the pipe,
+   its content is moved into a new IncomingTCPConnectionState and the pointed-to
+   object is deleted. The client descriptor is then registered for reading
+   with handleIOCallback as its callback. */
+static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
+{
+  auto workerData = boost::any_cast<TCPClientThreadData*>(param);
+
+  /* what travels over the pipe is the pointer itself, not the object */
+  ConnectionInfo* queued{nullptr};
+
+  try {
+    readn2(pipefd, &queued, sizeof(queued));
+  }
+  catch(const std::runtime_error& e) {
+    const std::string mode(isNonBlocking(pipefd) ? "non-blocking" : "blocking");
+    throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + mode + " mode: " + e.what());
+  }
+
+  g_tcpclientthreads->decrementQueuedCount();
+
+  /* take over the content, then release the allocation */
+  auto connInfo = std::move(*queued);
+  delete queued;
+  queued = nullptr;
+
+  struct timeval now;
+  gettimeofday(&now, 0);
+  auto state = std::make_shared<IncomingTCPConnectionState>(std::move(connInfo), *workerData, now.tv_sec);
 
-    vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
+  /* the whole allowed connection duration is still available at this point */
+  state->d_remainingTime = g_maxTCPConnectionDuration;
 
-    if (ds && outstanding) {
-      outstanding = false;
-      --ds->outstanding;
+  /* no immediate read attempt: wait for the multiplexer to report the socket as readable */
+  handleNewIOState(state, IOState::NeedRead, state->d_ci.fd, handleIOCallback, state->getClientReadTTD(now));
+}
+
+/* Entry point of a TCP worker thread. It registers the pipe on which new
+   connections arrive, then loops forever: run the multiplexer, periodically
+   clean up closed downstream connections, and unregister the descriptors
+   whose read or write deadline has expired. */
+void tcpClientThread(int pipefd)
+{
+  /* we get launched with a pipe on which we receive file descriptors from clients that we own
+     from that point on */
+
+  setThreadName("dnsdist/tcpClie");
+
+  TCPClientThreadData data;
+
+  data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
+  time_t lastTCPCleanup = time(nullptr);
+  struct timeval now;
+  gettimeofday(&now, 0);
+
+  for (;;) {
+    data.mplexer->run(&now);
+
+    /* a cleanup interval of 0 disables the periodic cleanup */
+    if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
+      cleanupClosedTCPConnections();
+      lastTCPCleanup = now.tv_sec;
+    }
+
+    /* descriptors whose read deadline has passed */
+    auto expiredReadConns = data.mplexer->getTimeouts(now, false);
+    for(const auto& conn : expiredReadConns) {
+      auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
+      /* the same state object is attached to both the client and the backend descriptor */
+      if (conn.first == state->d_ci.fd) {
+        vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+      }
+      else if (state->d_ds) {
+        vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
+      }
+      data.mplexer->removeReadFD(conn.first);
+      state->d_lastIOState = IOState::Done;
     }
-    decrementTCPClientCount(ci.remote);
 
-    if (g_downstreamTCPCleanupInterval > 0 && (connectionStartTime > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
-      cleanupClosedTCPConnections(sockets);
-      lastTCPCleanup = time(nullptr);
+    /* descriptors whose write deadline has passed */
+    auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
+    for(const auto& conn : expiredWriteConns) {
+      auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
+      if (conn.first == state->d_ci.fd) {
+        vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+      }
+      else if (state->d_ds) {
+        vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
+      }
+      data.mplexer->removeWriteFD(conn.first);
+      state->d_lastIOState = IOState::Done;
     }
   }
 }
 
-/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and 
+/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
    they will hand off to worker threads & spawn more of them if required
 */
 void tcpAcceptorThread(void* p)
@@ -596,7 +1066,7 @@ void tcpAcceptorThread(void* p)
   bool tcpClientCountIncremented = false;
   ComboAddress remote;
   remote.sin4.sin_family = cs->local.sin4.sin_family;
-  
+
   g_tcpclientthreads->addTCPClientThread();
 
   auto acl = g_ACL.getLocal();
index 0252b95debdf59d543e91b24f4f65ac2482d0e0c..acb4647786b4687d91c191a316bc86f35e4f72bb 100644 (file)
@@ -574,26 +574,9 @@ try {
         dh->id = ids->origID;
 
         uint16_t addRoom = 0;
-        DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, consumed, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start);
-        dr.origFlags = ids->origFlags;
-        dr.ecsAdded = ids->ecsAdded;
-        dr.ednsAdded = ids->ednsAdded;
-        dr.useZeroScope = ids->useZeroScope;
-        dr.packetCache = std::move(ids->packetCache);
-        dr.delayMsec = ids->delayMsec;
-        dr.skipCache = ids->skipCache;
-        dr.cacheKey = ids->cacheKey;
-        dr.cacheKeyNoECS = ids->cacheKeyNoECS;
-        dr.dnssecOK = ids->dnssecOK;
-        dr.tempFailureTTL = ids->tempFailureTTL;
-        dr.qTag = std::move(ids->qTag);
-        dr.subnet = std::move(ids->subnet);
-#ifdef HAVE_PROTOBUF
-        dr.uniqueId = std::move(ids->uniqueId);
-#endif
-        if (ids->dnsCryptQuery) {
+        DNSResponse dr = makeDNSResponseFromIDState(*ids, dh, sizeof(packet), responseLen, false);
+        if (dr.dnsCryptQuery) {
           addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
-          dr.dnsCryptQuery = std::move(ids->dnsCryptQuery);
         }
 
         memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
@@ -1577,24 +1560,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
 
     ids->cs = &cs;
     ids->origID = dh->id;
-    ids->origRemote = remote;
-    ids->sentTime.set(queryRealTime);
-    ids->qname = std::move(qname);
-    ids->qtype = dq.qtype;
-    ids->qclass = dq.qclass;
-    ids->delayMsec = dq.delayMsec;
-    ids->tempFailureTTL = dq.tempFailureTTL;
-    ids->origFlags = dq.origFlags;
-    ids->cacheKey = dq.cacheKey;
-    ids->cacheKeyNoECS = dq.cacheKeyNoECS;
-    ids->subnet = dq.subnet;
-    ids->skipCache = dq.skipCache;
-    ids->packetCache = dq.packetCache;
-    ids->ednsAdded = dq.ednsAdded;
-    ids->ecsAdded = dq.ecsAdded;
-    ids->useZeroScope = dq.useZeroScope;
-    ids->qTag = dq.qTag;
-    ids->dnssecOK = dq.dnssecOK;
+    setIDStateFromDNSQuestion(*ids, dq, std::move(qname));
 
     /* If we couldn't harvest the real dest addr, still
        write down the listening addr since it will be useful
@@ -1611,12 +1577,6 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       ids->destHarvested = false;
     }
 
-    ids->dnsCryptQuery = std::move(dq.dnsCryptQuery);
-
-#ifdef HAVE_PROTOBUF
-    ids->uniqueId = std::move(dq.uniqueId);
-#endif
-
     dh->id = idOffset;
 
     int fd = pickBackendSocketForSending(ss);
index 6ed7faeab63d498086fc9ec0f32855ea800a82fc..7757e692880a5a4d8e33942643b350b14879f225 100644 (file)
@@ -1066,3 +1066,5 @@ static const size_t s_udpIncomingBufferSize{1500};
 enum class ProcessQueryResult { Drop, SendAnswer, PassToBackend };
 ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend);
 
+/* Build the DNSResponse matching the query described by 'ids'. Some members of
+   'ids' (packet cache, tags, subnet, DNSCrypt query, ...) are moved out of it. */
+DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP);
+/* Save into 'ids' what will be needed from 'dq' to process the response later.
+   'qname' is moved into 'ids', as are the DNSCrypt query and unique ID of 'dq'. */
+void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname);
index b0d56c303cd0c85bcd341f6232c552edc78b36ab..8c9e90b5cd55a5cfc1031664bd0f324b3f205584 100644 (file)
@@ -99,6 +99,7 @@ dnsdist_SOURCES = \
        dnsdist-dnscrypt.cc \
        dnsdist-dynblocks.hh \
        dnsdist-ecs.cc dnsdist-ecs.hh \
+       dnsdist-idstate.cc \
        dnsdist-lua.hh dnsdist-lua.cc \
        dnsdist-lua-actions.cc \
        dnsdist-lua-bindings.cc \
diff --git a/pdns/dnsdistdist/dnsdist-idstate.cc b/pdns/dnsdistdist/dnsdist-idstate.cc
new file mode 100644 (file)
index 0000000..169ba64
--- /dev/null
@@ -0,0 +1,58 @@
+
+#include "dnsdist.hh"
+
+/* Build a DNSResponse for the response held in 'dh' from the query
+   information saved in 'ids'.
+
+   The response refers to ids.qname, ids.origDest, ids.origRemote and
+   ids.sentTime.d_start by address. The packet cache, tags, subnet, unique ID
+   (with protobuf support) and DNSCrypt query are moved out of 'ids'; the
+   remaining fields are copied. */
+DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP)
+{
+  DNSResponse response(&ids.qname, ids.qtype, ids.qclass, ids.qname.wirelength(), &ids.origDest, &ids.origRemote, dh, bufferSize, responseLen, isTCP, &ids.sentTime.d_start);
+
+  /* fields that are copied */
+  response.origFlags = ids.origFlags;
+  response.ecsAdded = ids.ecsAdded;
+  response.ednsAdded = ids.ednsAdded;
+  response.useZeroScope = ids.useZeroScope;
+  response.delayMsec = ids.delayMsec;
+  response.skipCache = ids.skipCache;
+  response.cacheKey = ids.cacheKey;
+  response.cacheKeyNoECS = ids.cacheKeyNoECS;
+  response.dnssecOK = ids.dnssecOK;
+  response.tempFailureTTL = ids.tempFailureTTL;
+
+  /* fields whose ownership is transferred to the response */
+  response.packetCache = std::move(ids.packetCache);
+  response.qTag = std::move(ids.qTag);
+  response.subnet = std::move(ids.subnet);
+#ifdef HAVE_PROTOBUF
+  response.uniqueId = std::move(ids.uniqueId);
+#endif
+  if (ids.dnsCryptQuery) {
+    response.dnsCryptQuery = std::move(ids.dnsCryptQuery);
+  }
+
+  return response;
+}
+
+/* Record in 'ids' the information from 'dq' that is needed once the response
+   comes back. dq.remote, dq.local and dq.queryTime are dereferenced.
+   'qname', dq.dnsCryptQuery and (with protobuf support) dq.uniqueId are moved
+   into 'ids'; everything else is copied. */
+void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname)
+{
+  /* addresses and timing */
+  ids.origRemote = *dq.remote;
+  ids.origDest = *dq.local;
+  ids.sentTime.set(*dq.queryTime);
+
+  /* the question itself */
+  ids.qname = std::move(qname);
+  ids.qtype = dq.qtype;
+  ids.qclass = dq.qclass;
+  ids.origFlags = dq.origFlags;
+  ids.dnssecOK = dq.dnssecOK;
+
+  /* cache-related settings */
+  ids.cacheKey = dq.cacheKey;
+  ids.cacheKeyNoECS = dq.cacheKeyNoECS;
+  ids.skipCache = dq.skipCache;
+  ids.packetCache = dq.packetCache;
+  ids.tempFailureTTL = dq.tempFailureTTL;
+
+  /* EDNS / ECS handling */
+  ids.subnet = dq.subnet;
+  ids.ednsAdded = dq.ednsAdded;
+  ids.ecsAdded = dq.ecsAdded;
+  ids.useZeroScope = dq.useZeroScope;
+
+  /* miscellaneous */
+  ids.delayMsec = dq.delayMsec;
+  ids.qTag = dq.qTag;
+
+  /* these two are taken away from the DNSQuestion */
+  ids.dnsCryptQuery = std::move(dq.dnsCryptQuery);
+#ifdef HAVE_PROTOBUF
+  ids.uniqueId = std::move(dq.uniqueId);
+#endif
+}
index 2be4a4c62fe4675b13468dc0271b5f696c69f9bb..9d44f0dba7c9169f76840dc46fe6a8a9ba86ee1f 100644 (file)
@@ -232,7 +232,7 @@ private:
 class OpenSSLTLSConnection: public TLSConnection
 {
 public:
-  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free))
+  OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout)
   {
     d_socket = socket;
 
@@ -247,12 +247,59 @@ public:
     if (!SSL_set_fd(d_conn.get(), d_socket)) {
       throw std::runtime_error("Error assigning socket");
     }
+  }
+
+  /* Translate the result of a failed OpenSSL IO call into the IO state the
+     caller should wait for. Throws std::runtime_error for anything other
+     than SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE. */
+  IOState convertIORequestToIOState(int res) const
+  {
+    const int error = SSL_get_error(d_conn.get(), res);
+    switch (error) {
+    case SSL_ERROR_WANT_READ:
+      return IOState::NeedRead;
+    case SSL_ERROR_WANT_WRITE:
+      return IOState::NeedWrite;
+    default:
+      throw std::runtime_error("Error while processing TLS connection:" + std::to_string(error));
+    }
+  }
+
+  /* Blocking counterpart of convertIORequestToIOState(): wait, for up to
+     'timeout', until the socket is ready for the operation OpenSSL asked for.
+     Throws std::runtime_error if the wait does not succeed. */
+  void handleIORequest(int res, unsigned int timeout)
+  {
+    switch (convertIORequestToIOState(res)) {
+    case IOState::NeedRead:
+      if (waitForData(d_socket, timeout) <= 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+      break;
+    case IOState::NeedWrite:
+      if (waitForRWData(d_socket, false, timeout, 0) <= 0) {
+        throw std::runtime_error("Error waiting to write to TLS connection");
+      }
+      break;
+    default:
+      break;
+    }
+  }
+
+  /* Single, non-blocking attempt at the TLS handshake. Returns
+     IOState::Done once SSL_accept() reports success, or the state to wait
+     for when it returns a negative value. Throws std::runtime_error for any
+     other return value. */
+  IOState tryHandshake()
+  {
+    const int res = SSL_accept(d_conn.get());
+    if (res == 1) {
+      return IOState::Done;
+    }
+    if (res >= 0) {
+      throw std::runtime_error("Error accepting TLS connection");
+    }
+    return convertIORequestToIOState(res);
+  }
 
+  void doHandshake()
+  {
     int res = 0;
     do {
       res = SSL_accept(d_conn.get());
       if (res < 0) {
-        handleIORequest(res, timeout);
+        handleIORequest(res, d_timeout);
       }
     }
     while (res < 0);
@@ -262,24 +309,40 @@ public:
     }
   }
 
-  void handleIORequest(int res, unsigned int timeout)
+  /* Non-blocking write: send buffer[pos, toWrite) over the TLS connection,
+     advancing 'pos' by the number of bytes SSL_write() accepted. Returns
+     IOState::Done once 'pos' reaches 'toWrite', otherwise the state reported
+     by convertIORequestToIOState(). Throws std::runtime_error when
+     SSL_write() returns 0. */
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
   {
-    int error = SSL_get_error(d_conn.get(), res);
-    if (error == SSL_ERROR_WANT_READ) {
-      res = waitForData(d_socket, timeout);
-      if (res <= 0) {
-        throw std::runtime_error("Error reading from TLS connection");
+    do {
+      int res = SSL_write(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), static_cast<int>(toWrite - pos));
+      if (res == 0) {
+        throw std::runtime_error("Error writing to TLS connection");
       }
-    }
-    else if (error == SSL_ERROR_WANT_WRITE) {
-      res = waitForRWData(d_socket, false, timeout, 0);
-      if (res <= 0) {
-        throw std::runtime_error("Error waiting to write to TLS connection");
+      else if (res < 0) {
+        /* either we need to wait for the socket, or this throws */
+        return convertIORequestToIOState(res);
+      }
+      else {
+        pos += static_cast<size_t>(res);
       }
     }
-    else {
-      throw std::runtime_error("Error writing to TLS connection");
+    while (pos < toWrite);
+    return IOState::Done;
+  }
+
+  /* Non-blocking read: fill buffer[pos, toRead) from the TLS connection,
+     advancing 'pos' by the number of bytes SSL_read() returned. Returns
+     IOState::Done once 'pos' reaches 'toRead', otherwise the state reported
+     by convertIORequestToIOState(). Throws std::runtime_error when
+     SSL_read() returns 0. */
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
+  {
+    do {
+      int res = SSL_read(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), static_cast<int>(toRead - pos));
+      if (res == 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+      else if (res < 0) {
+        /* either we need to wait for the socket, or this throws */
+        return convertIORequestToIOState(res);
+      }
+      else {
+        pos += static_cast<size_t>(res);
+      }
     }
+    while (pos < toRead);
+    return IOState::Done;
   }
 
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
@@ -300,7 +363,7 @@ public:
         handleIORequest(res, readTimeout);
       }
       else {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
 
       if (totalTimeout) {
@@ -330,7 +393,7 @@ public:
         handleIORequest(res, writeTimeout);
       }
       else {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
     }
     while (got < bufferSize);
@@ -346,6 +409,7 @@ public:
 
 private:
   std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
+  unsigned int d_timeout;
 };
 
 class OpenSSLTLSIOCtx: public TLSCtx
@@ -650,7 +714,7 @@ public:
 
   GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr<GnuTLSTicketsKey>& ticketsKey, bool enableTickets): d_conn(std::unique_ptr<gnutls_session_int, void(*)(gnutls_session_t)>(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey)
   {
-    unsigned int sslOptions = GNUTLS_SERVER;
+    unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK;
 #ifdef GNUTLS_NO_SIGNAL
     sslOptions |= GNUTLS_NO_SIGNAL;
 #endif
@@ -685,12 +749,86 @@ public:
     /* timeouts are in milliseconds */
     gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000);
     gnutls_record_set_timeout(d_conn.get(), timeout * 1000);
+  }
 
+  /* Perform the TLS handshake, calling gnutls_handshake() again for as long
+     as it returns GNUTLS_E_INTERRUPTED. Throws std::runtime_error on a fatal
+     error or on GNUTLS_E_WARNING_ALERT_RECEIVED.
+     NOTE(review): any other non-fatal return value (GNUTLS_E_AGAIN for
+     example) ends the loop without an error being reported - confirm that
+     callers are fine with that now that the session uses GNUTLS_NONBLOCK. */
+  void doHandshake()
+  {
     int ret = 0;
     do {
       ret = gnutls_handshake(d_conn.get());
+      if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+        throw std::runtime_error("Error accepting a new connection");
+      }
+    }
+    while (ret < 0 && ret == GNUTLS_E_INTERRUPTED);
+  }
+
+  /* Single, non-blocking attempt at the TLS handshake.
+
+     Returns IOState::Done when the handshake is complete. When GnuTLS
+     reports GNUTLS_E_AGAIN, returns IOState::NeedRead or IOState::NeedWrite
+     depending on the direction in which the handshake was interrupted.
+     The call is repeated on GNUTLS_E_INTERRUPTED. Throws std::runtime_error
+     on a fatal error, on GNUTLS_E_WARNING_ALERT_RECEIVED, or on any other
+     unexpected return value. */
+  IOState tryHandshake()
+  {
+    int ret = 0;
+
+    do {
+      ret = gnutls_handshake(d_conn.get());
+      if (ret == GNUTLS_E_SUCCESS) {
+        return IOState::Done;
+      }
+      else if (ret == GNUTLS_E_AGAIN) {
+        /* the handshake involves both receiving and sending, so ask GnuTLS
+           which one it was waiting for: 0 means reading, 1 means writing */
+        return gnutls_record_get_direction(d_conn.get()) == 0 ? IOState::NeedRead : IOState::NeedWrite;
+      }
+      else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) {
+        throw std::runtime_error("Error accepting a new connection");
+      }
+    } while (ret == GNUTLS_E_INTERRUPTED);
+
+    throw std::runtime_error("Error accepting a new connection");
+  }
+
+  /* Non-blocking write: send buffer[pos, toWrite) over the TLS session,
+     advancing 'pos' by what gnutls_record_send() accepted. Returns
+     IOState::Done once 'pos' reaches 'toWrite' and IOState::NeedWrite on
+     GNUTLS_E_AGAIN. Throws std::runtime_error when 0 is returned or on a
+     fatal error; other non-fatal errors are logged and the loop goes on. */
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) override
+  {
+    do {
+      const ssize_t written = gnutls_record_send(d_conn.get(), reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos);
+      if (written > 0) {
+        pos += static_cast<size_t>(written);
+        continue;
+      }
+
+      if (written == 0) {
+        throw std::runtime_error("Error writing to TLS connection");
+      }
+
+      /* negative value: an error code */
+      if (gnutls_error_is_fatal(written)) {
+        throw std::runtime_error("Error writing to TLS connection");
+      }
+      if (written == GNUTLS_E_AGAIN) {
+        return IOState::NeedWrite;
+      }
+      warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(written));
+    }
+    while (pos < toWrite);
+    return IOState::Done;
+  }
+
+  /* Non-blocking read: fill buffer[pos, toRead) from the TLS session,
+     advancing 'pos' by what gnutls_record_recv() returned. Returns
+     IOState::Done once 'pos' reaches 'toRead' and IOState::NeedRead on
+     GNUTLS_E_AGAIN. Throws std::runtime_error when 0 is returned or on a
+     fatal error; other non-fatal errors are logged and the loop goes on. */
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) override
+  {
+    do {
+      ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast<char *>(&buffer.at(pos)), toRead - pos);
+      if (res == 0) {
+        throw std::runtime_error("Error reading from TLS connection");
+      }
+      else if (res > 0) {
+        pos += static_cast<size_t>(res);
+      }
+      else if (res < 0) {
+        if (gnutls_error_is_fatal(res)) {
+          throw std::runtime_error("Error reading from TLS connection");
+        }
+        else if (res == GNUTLS_E_AGAIN) {
+          return IOState::NeedRead;
+        }
+        /* this is the read path: say so in the log */
+        warnlog("Warning, non-fatal error while reading from TLS connection: %s", gnutls_strerror(res));
+      }
     }
-    while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
+    while (pos < toRead);
+    return IOState::Done;
   }
 
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override
@@ -708,7 +846,7 @@ public:
         throw std::runtime_error("Error reading from TLS connection");
       }
       else if (res > 0) {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
       else if (res < 0) {
         if (gnutls_error_is_fatal(res)) {
@@ -750,7 +888,7 @@ public:
         throw std::runtime_error("Error writing to TLS connection");
       }
       else if (res > 0) {
-        got += (size_t) res;
+        got += static_cast<size_t>(res);
       }
       else if (res < 0) {
         if (gnutls_error_is_fatal(res)) {
index 0d7c342b008ba6890d84766756db5d7080ea68b2..88fd698131b2d0aaeeda4442f4a11c4763122527 100644 (file)
@@ -269,40 +269,112 @@ void ComboAddress::truncate(unsigned int bits) noexcept
   *place &= (~((1<<bitsleft)-1));
 }
 
-ssize_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int timeout, ComboAddress& dest, const ComboAddress& local, unsigned int localItf)
+size_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags)
 {
+  int remainingTime = totalTimeout;
+  time_t start = 0;
+  if (totalTimeout) {
+    start = time(nullptr);
+  }
+
   struct msghdr msgh;
   struct iovec iov;
   char cbuf[256];
+
+  /* Set up iov and msgh structures. */
+  memset(&msgh, 0, sizeof(struct msghdr));
+  msgh.msg_control = nullptr;
+  msgh.msg_controllen = 0;
+  if (dest) {
+    msgh.msg_name = reinterpret_cast<void*>(const_cast<ComboAddress*>(dest));
+    msgh.msg_namelen = dest->getSocklen();
+  }
+  else {
+    msgh.msg_name = nullptr;
+    msgh.msg_namelen = 0;
+  }
+
+  msgh.msg_flags = 0;
+
+  if (localItf != 0 && local) {
+    addCMsgSrcAddr(&msgh, cbuf, local, localItf);
+  }
+
+  if (localItf != 0 && local) {
+    addCMsgSrcAddr(&msgh, cbuf, local, localItf);
+  }
+
+  iov.iov_base = reinterpret_cast<void*>(const_cast<char*>(buffer));
+  iov.iov_len = len;
+  msgh.msg_iov = &iov;
+  msgh.msg_iovlen = 1;
+  msgh.msg_flags = 0;
+
+  size_t sent = 0;
   bool firstTry = true;
-  fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), const_cast<char*>(buffer), len, &dest);
-  addCMsgSrcAddr(&msgh, cbuf, &local, localItf);
 
   do {
-    ssize_t written = sendmsg(fd, &msgh, 0);
 
-    if (written > 0)
-      return written;
+#ifdef MSG_FASTOPEN
+    if (flags & MSG_FASTOPEN && firstTry == false) {
+      flags &= ~MSG_FASTOPEN;
+    }
+#endif /* MSG_FASTOPEN */
 
-    if (errno == EAGAIN) {
-      if (firstTry) {
-        int res = waitForRWData(fd, false, timeout, 0);
-        if (res > 0) {
-          /* there is room available */
-          firstTry = false;
+    ssize_t res = sendmsg(fd, &msgh, flags);
+
+    if (res > 0) {
+      size_t written = static_cast<size_t>(res);
+      sent += written;
+
+      if (sent == len) {
+        return sent;
+      }
+
+      /* partial write */
+      iov.iov_len -= written;
+      iov.iov_base = reinterpret_cast<void*>(reinterpret_cast<char*>(iov.iov_base) + written);
+      written = 0;
+    }
+    else if (res == -1) {
+      if (errno == EINTR) {
+        continue;
+      }
+      else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINPROGRESS) {
+        /* EINPROGRESS might happen with non blocking socket,
+           especially with TCP Fast Open */
+        if (totalTimeout <= 0 && idleTimeout <= 0) {
+          return sent;
+        }
+
+        if (firstTry) {
+          int res = waitForRWData(fd, false, (totalTimeout == 0 || idleTimeout <= remainingTime) ? idleTimeout : remainingTime, 0);
+          if (res > 0) {
+            /* there is room available */
+            firstTry = false;
+          }
+          else if (res == 0) {
+            throw runtime_error("Timeout while waiting to write data");
+          } else {
+            throw runtime_error("Error while waiting for room to write data");
+          }
         }
-        else if (res == 0) {
+        else {
           throw runtime_error("Timeout while waiting to write data");
-        } else {
-          throw runtime_error("Error while waiting for room to write data");
         }
       }
       else {
-        throw runtime_error("Timeout while waiting to write data");
+        unixDie("failed in sendMsgWithTimeout");
       }
     }
-    else {
-      unixDie("failed in write2WithTimeout");
+    if (totalTimeout) {
+      time_t now = time(nullptr);
+      int elapsed = now - start;
+      if (elapsed >= remainingTime) {
+        throw runtime_error("Timeout while sending data");
+      }
+      start = now;
+      remainingTime -= elapsed;
     }
   }
   while (firstTry);
index 490e45ac436665556dd77f7e545187cb810ca9a2..498e03c9243df8334ce40ccec24e0741320387f9 100644 (file)
@@ -1062,7 +1062,7 @@ bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destinat
 bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv);
 void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, char* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr);
 ssize_t sendfromto(int sock, const char* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to);
-ssize_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int timeout, ComboAddress& dest, const ComboAddress& local, unsigned int localItf);
+size_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags);
 bool sendSizeAndMsgWithTimeout(int sock, uint16_t bufferLen, const char* buffer, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags);
 /* requires a non-blocking, connected TCP socket */
 bool isTCPSocketUsable(int sock);
index 7de3a8a459ff0440bb18a0661b2f9c1d2b4cd801..b42e90092847f8dcb35bc80032ca939b68a17652 100644 (file)
@@ -51,9 +51,9 @@ class FDMultiplexer
 {
 public:
   typedef boost::any funcparam_t;
+  typedef boost::function< void(int, funcparam_t&) > callbackfunc_t;
 protected:
 
-  typedef boost::function< void(int, funcparam_t&) > callbackfunc_t;
   struct Callback
   {
     callbackfunc_t d_callback;
index b8066b87826ad4e9ac4468ff4e1936b6d07ef9d8..922315a5e21bc3aeeb2c1c4bd1c321b0d29030b1 100644 (file)
@@ -60,10 +60,17 @@ public:
     setCloseOnExec(d_socket);
   }
 
+  Socket(Socket&& rhs) noexcept: d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket) /* move constructor: takes over rhs's buffer and descriptor; noexcept assumes d_buffer's move cannot throw */
+  {
+    rhs.d_socket = -1; /* rhs no longer owns the descriptor; the destructor skips closesocket() for -1 */
+  }
+
   ~Socket()
   {
     try {
-      closesocket(d_socket);
+      if (d_socket != -1) {
+        closesocket(d_socket);
+      }
     }
     catch(const PDNSException& e) {
     }
@@ -124,10 +131,10 @@ public:
   }
 
   //! Bind the socket to a specified endpoint
-  void bind(const ComboAddress &local)
+  void bind(const ComboAddress &local, bool reuseaddr=true)
   {
     int tmp=1;
-    if(setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp)<0)
+    if(reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp)<0)
       throw NetworkError(string("Setsockopt failed: ")+strerror(errno));
 
     if(::bind(d_socket, reinterpret_cast<const struct sockaddr *>(&local), local.getSocklen())<0)
index 0d5bfa514eea42333c21f7fea3c86341bc356068..30e19537f8f098f3a325733e2facaf5e7e550a09 100644 (file)
@@ -4,12 +4,18 @@
 
 #include "misc.hh"
 
+enum class IOState { Done, NeedRead, NeedWrite }; /* result of a try*() IO call: Done when the operation completed, NeedRead / NeedWrite when it would block */
+
 class TLSConnection
 {
 public:
   virtual ~TLSConnection() { }
+  virtual void doHandshake() = 0;
+  virtual IOState tryHandshake() = 0;
   virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0;
   virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0;
+  virtual IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite) = 0;
+  virtual IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead) = 0;
   virtual void close() = 0;
 
 protected:
@@ -153,12 +159,14 @@ private:
 class TCPIOHandler
 {
 public:
+
   TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
   {
     if (ctx) {
       d_conn = ctx->getConnection(d_socket, timeout, now);
     }
   }
+
   ~TCPIOHandler()
   {
     if (d_conn) {
@@ -168,6 +176,15 @@ public:
       shutdown(d_socket, SHUT_RDWR);
     }
   }
+
+  IOState tryHandshake() /* forwards to the TLS connection's tryHandshake() when d_conn is set */
+  {
+    if (d_conn) {
+      return d_conn->tryHandshake();
+    }
+    return IOState::Done; /* d_conn unset (no TLS context was given to the constructor): report Done right away */
+  }
+
   size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0)
   {
     if (d_conn) {
@@ -176,6 +193,77 @@ public:
       return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout);
     }
   }
+
+  /* Tries to read into the buffer, starting at position pos, until pos reaches the offset toRead.
+     Updates pos every time a successful read occurs,
+     throws an std::runtime_error in case of IO error or if pos / toRead do not fit the buffer,
+     returns Done once pos == toRead, NeedRead or NeedWrite if the IO operation
+     would block; a call that returned NeedRead is resumed with the same toRead and the updated pos.
+  */
+  IOState tryRead(std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
+  {
+    if (d_conn) {
+      return d_conn->tryRead(buffer, pos, toRead);
+    }
+
+    if (pos >= toRead || buffer.size() < toRead) { throw std::runtime_error("Invalid position or buffer size while reading message"); }
+    do {
+      ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - pos); /* same arithmetic as the TLS path: toRead is an end offset, so only the missing part is requested */
+      if (res == 0) {
+        throw runtime_error("EOF while reading message");
+      }
+      if (res < 0) {
+        if (errno == EAGAIN || errno == EWOULDBLOCK) {
+          return IOState::NeedRead;
+        }
+        else {
+          throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
+        }
+      }
+
+      pos += static_cast<size_t>(res);
+      /* pos is the only progress counter: a per-call counter would restart at 0 on resume and over-read */
+    }
+    while (pos < toRead);
+
+    return IOState::Done;
+  }
+
+  /* Tries to write the buffer content, starting at position pos, until pos reaches the offset toWrite.
+     Updates pos every time a successful write occurs,
+     throws an std::runtime_error in case of IO error or if pos / toWrite do not fit the buffer,
+     returns Done once pos == toWrite, NeedRead or NeedWrite if the IO operation
+     would block; a call that returned NeedWrite is resumed with the same toWrite and the updated pos.
+  */
+  IOState tryWrite(std::vector<uint8_t>& buffer, size_t& pos, size_t toWrite)
+  {
+    if (d_conn) {
+      return d_conn->tryWrite(buffer, pos, toWrite);
+    }
+
+    if (pos >= toWrite || buffer.size() < toWrite) { throw std::runtime_error("Invalid position or buffer size while writing message"); }
+    do {
+      ssize_t res = ::write(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toWrite - pos); /* same arithmetic as the TLS path: toWrite is an end offset, so only the unsent part is submitted */
+      if (res == 0) {
+        throw runtime_error("EOF while sending message");
+      }
+      if (res < 0) {
+        if (errno == EAGAIN || errno == EWOULDBLOCK) {
+          return IOState::NeedWrite;
+        }
+        else {
+          throw std::runtime_error(std::string("Error while writing message: ") + strerror(errno));
+        }
+      }
+
+      pos += static_cast<size_t>(res);
+      /* pos is the only progress counter: a per-call counter would restart at 0 on resume and over-send */
+    }
+    while (pos < toWrite);
+
+    return IOState::Done;
+  }
+
   size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout)
   {
     if (d_conn) {