git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: TCP out-of-order implementation
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 5 Jun 2020 15:58:31 +0000 (17:58 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 10 Nov 2020 08:45:48 +0000 (09:45 +0100)
13 files changed:
pdns/dnsdist-tcp.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-backend.cc
pdns/dnsdistdist/dnsdist-proxy-protocol.cc
pdns/dnsdistdist/dnsdist-proxy-protocol.hh
pdns/dnsdistdist/dnsdist-tcp-downstream.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-tcp-downstream.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-tcp-upstream.hh [new file with mode: 0644]
pdns/dnsdistdist/doh.cc
pdns/dnsdistdist/tcpiohandler-mplexer.hh [new file with mode: 0644]
regression-tests.dnsdist/test_TCPShort.py
regression-tests.dnsdist/test_Tags.py

index 048339929a62afd425dd5a8348e031b56434305c..3533bccda5e78145249de55c325324acdbba065b 100644 (file)
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
+
+#include <thread>
+#include <netinet/tcp.h>
+#include <queue>
+
 #include "dnsdist.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-proxy-protocol.hh"
 #include "dnsdist-rings.hh"
+#include "dnsdist-tcp-downstream.hh"
+#include "dnsdist-tcp-upstream.hh"
 #include "dnsdist-xpf.hh"
-
 #include "dnsparser.hh"
-#include "ednsoptions.hh"
 #include "dolog.hh"
-#include "lock.hh"
+#include "ednsoptions.hh"
 #include "gettime.hh"
+#include "lock.hh"
+#include "sstuff.hh"
 #include "tcpiohandler.hh"
+#include "tcpiohandler-mplexer.hh"
 #include "threadname.hh"
-#include <thread>
-#include <atomic>
-#include <netinet/tcp.h>
-
-#include "sstuff.hh"
-
-using std::thread;
-using std::atomic;
 
 /* TCP: the grand design.
    We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
@@ -58,7 +58,7 @@ using std::atomic;
 
 static std::mutex tcpClientsCountMutex;
 static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
-static const size_t g_maxCachedConnectionsPerDownstream = 20;
+
 uint64_t g_maxTCPQueuedConnections{1000};
 size_t g_maxTCPQueriesPerConn{0};
 size_t g_maxTCPConnectionDuration{0};
@@ -66,243 +66,174 @@ size_t g_maxTCPConnectionsPerClient{0};
 uint16_t g_downstreamTCPCleanupInterval{60};
 bool g_useTCPSinglePipe{false};
 
-static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
+class DownstreamConnectionsManager
 {
-  std::unique_ptr<Socket> result;
+public:
 
-  do {
-    vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
-    try {
-      result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
-      if (!IsAnyAddress(ds->sourceAddr)) {
-        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
-#ifdef IP_BIND_ADDRESS_NO_PORT
-        if (ds->ipBindAddrNoPort) {
-          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
-        }
-#endif
-#ifdef SO_BINDTODEVICE
-        if (!ds->sourceItfName.empty()) {
-          int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, ds->sourceItfName.c_str(), ds->sourceItfName.length());
-          if (res != 0) {
-            vinfolog("Error setting up the interface on backend TCP socket '%s': %s", ds->getNameWithAddr(), stringerror());
-          }
-        }
-#endif
-        result->bind(ds->sourceAddr, false);
-      }
-      result->setNonBlocking();
-#ifdef MSG_FASTOPEN
-      if (!ds->tcpFastOpen) {
-        SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
-      }
-#else
-      SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
-#endif /* MSG_FASTOPEN */
-      return result;
-    }
-    catch(const std::runtime_error& e) {
-      vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
-      downstreamFailures++;
-      if (downstreamFailures > ds->retries) {
-        throw;
+  static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
+  {
+    std::unique_ptr<TCPConnectionToBackend> result;
+
+    const auto& it = t_downstreamConnections.find(ds->remote);
+    if (it != t_downstreamConnections.end()) {
+      auto& list = it->second;
+      if (!list.empty()) {
+        result = std::move(list.front());
+        list.pop_front();
+        result->setReused();
+        return result;
       }
     }
-  } while(downstreamFailures <= ds->retries);
 
-  return nullptr;
-}
-
-class TCPConnectionToBackend
-{
-public:
-  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen)
-  {
-    d_socket = setupTCPDownstream(d_ds, downstreamFailures);
-    ++d_ds->tcpCurrentConnections;
+    return make_unique<TCPConnectionToBackend>(ds, now);
   }
 
-  ~TCPConnectionToBackend()
+  static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
   {
-    if (d_ds && d_socket) {
-      --d_ds->tcpCurrentConnections;
-      struct timeval now;
-      gettimeofday(&now, nullptr);
-
-      auto diff = now - d_connectionStartTime;
-      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
+    if (conn == nullptr) {
+      return;
     }
-  }
 
-  int getHandle() const
-  {
-    if (!d_socket) {
-      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
+    if (!conn->canBeReused()) {
+      conn.reset();
+      return;
     }
 
-    return d_socket->getHandle();
-  }
-
-  const ComboAddress& getRemote() const
-  {
-    return d_ds->remote;
-  }
-
-  bool isFresh() const
-  {
-    return d_fresh;
-  }
-
-  void incQueries()
-  {
-    ++d_queries;
-  }
-
-  void setReused()
-  {
-    d_fresh = false;
-  }
-
-  void disableFastOpen()
-  {
-    d_enableFastOpen = false;
-  }
-
-  bool isFastOpenEnabled()
-  {
-    return d_enableFastOpen;
-  }
-
-  bool canBeReused() const
-  {
-    /* we can't reuse a connection where a proxy protocol payload has been sent,
-       since:
-       - it cannot be reused for a different client
-       - we might have different TLV values for each query
-    */
-    if (d_ds && d_ds->useProxyProtocol) {
-      return false;
+    const auto& remote = conn->getRemote();
+    const auto& it = t_downstreamConnections.find(remote);
+    if (it != t_downstreamConnections.end()) {
+      auto& list = it->second;
+      if (list.size() >= s_maxCachedConnectionsPerDownstream) {
+        /* too many connections queued already */
+        conn.reset();
+        return;
+      }
+      list.push_back(std::move(conn));
+    }
+    else {
+      t_downstreamConnections[remote].push_back(std::move(conn));
     }
-    return true;
   }
 
-  bool matches(const std::shared_ptr<DownstreamState>& ds) const
+  static void cleanupClosedTCPConnections()
   {
-    if (!ds || !d_ds) {
-      return false;
+    for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
+      for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
+        if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
+          ++connIt;
+        }
+        else {
+          connIt = dsIt->second.erase(connIt);
+        }
+      }
+
+      if (!dsIt->second.empty()) {
+        ++dsIt;
+      }
+      else {
+        dsIt = t_downstreamConnections.erase(dsIt);
+      }
     }
-    return ds == d_ds;
   }
 
 private:
-  std::unique_ptr<Socket> d_socket{nullptr};
-  std::shared_ptr<DownstreamState> d_ds{nullptr};
-  struct timeval d_connectionStartTime;
-  uint64_t d_queries{0};
-  bool d_fresh{true};
-  bool d_enableFastOpen{false};
+  static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;
+  static const size_t s_maxCachedConnectionsPerDownstream;
 };
 
-static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;
+thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> DownstreamConnectionsManager::t_downstreamConnections;
+const size_t DownstreamConnectionsManager::s_maxCachedConnectionsPerDownstream{20};
 
-static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
+static void decrementTCPClientCount(const ComboAddress& client)
 {
-  std::unique_ptr<TCPConnectionToBackend> result;
-
-  const auto& it = t_downstreamConnections.find(ds->remote);
-  if (it != t_downstreamConnections.end()) {
-    auto& list = it->second;
-    if (!list.empty()) {
-      result = std::move(list.front());
-      list.pop_front();
-      result->setReused();
-      return result;
+  if (g_maxTCPConnectionsPerClient) {
+    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
+    tcpClientsCount.at(client)--;
+    if (tcpClientsCount[client] == 0) {
+      tcpClientsCount.erase(client);
     }
   }
-
-  return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
 }
 
-static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
+IncomingTCPConnectionState::~IncomingTCPConnectionState()
 {
-  if (conn == nullptr) {
-    return;
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  decrementTCPClientCount(d_ci.remote);
+  // DEBUG: cerr<<"decremented"<<endl;
+
+  if (d_ci.cs != nullptr) {
+    struct timeval now;
+    gettimeofday(&now, nullptr);
+
+    auto diff = now - d_connectionStartTime;
+    // DEBUG: cerr<<"updating tcp metrics"<<endl;
+    d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
+    // DEBUG: cerr<<"updated tcp metrics"<<endl;
   }
 
-  if (!conn->canBeReused()) {
-    conn.reset();
-    return;
+#if 0
+  if (d_ds != nullptr) {
+
+    if (d_downstreamConnection) {
+      try {
+        if (d_lastIOState == IOState::NeedRead) {
+          // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
+          d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
+        }
+        else if (d_lastIOState == IOState::NeedWrite) {
+          // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
+          d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
+        }
+      }
+      catch(const FDMultiplexerException& e) {
+        vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
+      }
+      catch(const std::runtime_error& e) {
+        /* might be thrown by getHandle() */
+        vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
+      }
+    }
   }
+#endif
 
-  const auto& remote = conn->getRemote();
-  const auto& it = t_downstreamConnections.find(remote);
-  if (it != t_downstreamConnections.end()) {
-    auto& list = it->second;
-    if (list.size() >= g_maxCachedConnectionsPerDownstream) {
-      /* too many connections queued already */
-      conn.reset();
-      return;
+  // DEBUG: cerr<<"about to remove left over FDs"<<endl;
+  try {
+    if (d_lastIOState == IOState::NeedRead) {
+      // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover client read FD "<<d_ci.fd<<endl;
+      d_threadData.mplexer->removeReadFD(d_ci.fd);
+    }
+    else if (d_lastIOState == IOState::NeedWrite) {
+      // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover client write FD "<<d_ci.fd<<endl;
+      d_threadData.mplexer->removeWriteFD(d_ci.fd);
     }
-    list.push_back(std::move(conn));
   }
-  else {
-    t_downstreamConnections[remote].push_back(std::move(conn));
+  catch (const FDMultiplexerException& e) {
+    vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+  }
+  catch (...) {
+    vinfolog("Got an unknown exception when trying to remove a pending IO operation on an incoming TCP connection from %s", d_ci.remote.toStringWithPort());
   }
+  // DEBUG: cerr<<"done"<<endl;
 }
 
-struct ConnectionInfo
+std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
 {
-  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
-  {
-  }
-  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
-  {
-    rhs.cs = nullptr;
-    rhs.fd = -1;
-  }
+  std::shared_ptr<TCPConnectionToBackend> downstream{nullptr};
 
-  ConnectionInfo(const ConnectionInfo& rhs) = delete;
-  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
-
-  ConnectionInfo& operator=(ConnectionInfo&& rhs)
-  {
-    remote = rhs.remote;
-    cs = rhs.cs;
-    rhs.cs = nullptr;
-    fd = rhs.fd;
-    rhs.fd = -1;
-    return *this;
+  if (!ds->useProxyProtocol || !d_proxyProtocolPayloadHasTLV) {
+    downstream = getActiveDownstreamConnection(ds);
   }
 
-  ~ConnectionInfo()
-  {
-    if (fd != -1) {
-      close(fd);
-      fd = -1;
-    }
-    if (cs) {
-      --cs->tcpCurrentConnections;
-    }
+  if (!downstream) {
+    /* we don't have a connection to this backend active yet, let's ask one (it might not be a fresh one, though) */
+    downstream = DownstreamConnectionsManager::getConnectionToDownstream(d_threadData.mplexer, ds, now);
   }
 
-  ComboAddress remote;
-  ClientState* cs{nullptr};
-  int fd{-1};
-};
-
-void tcpClientThread(int pipefd);
-
-static void decrementTCPClientCount(const ComboAddress& client)
-{
-  if (g_maxTCPConnectionsPerClient) {
-    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
-    tcpClientsCount[client]--;
-    if (tcpClientsCount[client] == 0) {
-      tcpClientsCount.erase(client);
-    }
-  }
+  return downstream;
 }
 
+static void tcpClientThread(int pipefd);
+
 void TCPClientCollection::addTCPClientThread()
 {
   int pipefds[2] = { -1, -1};
@@ -349,7 +280,7 @@ void TCPClientCollection::addTCPClientThread()
     }
 
     try {
-      thread t1(tcpClientThread, pipefds[0]);
+      std::thread t1(tcpClientThread, pipefds[0]);
       t1.detach();
     }
     catch(const std::runtime_error& e) {
@@ -367,27 +298,6 @@ void TCPClientCollection::addTCPClientThread()
   }
 }
 
-static void cleanupClosedTCPConnections()
-{
-  for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
-    for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
-      if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
-        ++connIt;
-      }
-      else {
-        connIt = dsIt->second.erase(connIt);
-      }
-    }
-
-    if (!dsIt->second.empty()) {
-      ++dsIt;
-    }
-    else {
-      dsIt = t_downstreamConnections.erase(dsIt);
-    }
-  }
-}
-
 /* Tries to read exactly toRead bytes into the buffer, starting at position pos.
    Updates pos everytime a successful read occurs,
    throws an std::runtime_error in case of IO error,
@@ -395,7 +305,7 @@ static void cleanupClosedTCPConnections()
    would block.
 */
 // XXX could probably be implemented as a TCPIOHandler
-static IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
+IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
 {
   if (buffer.size() < (pos + toRead)) {
     throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
@@ -426,374 +336,155 @@ static IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t
 
 std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
 
-class TCPClientThreadData
+static IOState handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
 {
-public:
-  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
-  {
-  }
-
-  LocalHolders holders;
-  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
-  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
-};
-
-static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);
-
-class IncomingTCPConnectionState
-{
-public:
-  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_responseBuffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
-  {
-    d_ids.origDest.reset();
-    d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
-    socklen_t socklen = d_ids.origDest.getSocklen();
-    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
-      d_ids.origDest = d_ci.cs->local;
-    }
-  }
-
-  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
-  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
-
-  ~IncomingTCPConnectionState()
-  {
-    decrementTCPClientCount(d_ci.remote);
-    if (d_ci.cs != nullptr) {
-      struct timeval now;
-      gettimeofday(&now, nullptr);
-
-      auto diff = now - d_connectionStartTime;
-      d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
-    }
-
-    if (d_ds != nullptr) {
-      if (d_outstanding) {
-        --d_ds->outstanding;
-        d_outstanding = false;
-      }
-
-      if (d_downstreamConnection) {
-        try {
-          if (d_lastIOState == IOState::NeedRead) {
-            cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
-            d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
-          }
-          else if (d_lastIOState == IOState::NeedWrite) {
-            cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
-            d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
-          }
-        }
-        catch(const FDMultiplexerException& e) {
-          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
-        }
-        catch(const std::runtime_error& e) {
-          /* might be thrown by getHandle() */
-          vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
-        }
-      }
-    }
-
-    try {
-      if (d_lastIOState == IOState::NeedRead) {
-        cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
-        d_threadData.mplexer->removeReadFD(d_ci.fd);
-      }
-      else if (d_lastIOState == IOState::NeedWrite) {
-        cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
-        d_threadData.mplexer->removeWriteFD(d_ci.fd);
-      }
-    }
-    catch(const FDMultiplexerException& e) {
-      vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
-    }
-  }
-
-  void resetForNewQuery()
-  {
-    d_buffer.resize(sizeof(uint16_t));
-    d_currentPos = 0;
-    d_querySize = 0;
-    d_responseSize = 0;
-    d_downstreamFailures = 0;
-    d_state = State::readingQuerySize;
-    d_lastIOState = IOState::Done;
-    d_selfGeneratedResponse = false;
-  }
-
-  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
-  {
-    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
-      return boost::none;
-    }
-
-    if (g_maxTCPConnectionDuration > 0) {
-      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
-      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
-        return now;
-      }
-      auto remaining = g_maxTCPConnectionDuration - elapsed;
-      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
-        now.tv_sec += remaining;
-        return now;
-      }
-    }
-
-    now.tv_sec += g_tcpRecvTimeout;
-    return now;
-  }
-
-  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
-  {
-    if (d_ds == nullptr) {
-      throw std::runtime_error("getBackendReadTTD() without any backend selected");
-    }
-    if (d_ds->tcpRecvTimeout == 0) {
-      return boost::none;
-    }
-
-    struct timeval res = now;
-    res.tv_sec += d_ds->tcpRecvTimeout;
-
-    return res;
-  }
-
-  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
-  {
-    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
-      return boost::none;
+  if (!state->d_isXFR) {
+    const auto& currentResponse = state->d_currentResponse;
+    if (state->d_selfGeneratedResponse == false && currentResponse.d_ds) {
+      struct timespec answertime;
+      gettime(&answertime);
+      const auto& ids = currentResponse.d_idstate;
+      double udiff = ids.sentTime.udiff();
+      g_rings.insertResponse(answertime, state->d_ci.remote, ids.qname, ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, currentResponse.d_ds->remote);
+      vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", currentResponse.d_ds->remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff);
+    }
+
+    switch (currentResponse.d_cleartextDH.rcode) {
+    case RCode::NXDomain:
+      ++g_stats.frontendNXDomain;
+      break;
+    case RCode::ServFail:
+      ++g_stats.servfailResponses;
+      ++g_stats.frontendServFail;
+      break;
+    case RCode::NoError:
+      ++g_stats.frontendNoError;
+      break;
     }
 
-    struct timeval res = now;
-
-    if (g_maxTCPConnectionDuration > 0) {
-      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
-      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
-        return res;
-      }
-      auto remaining = g_maxTCPConnectionDuration - elapsed;
-      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
-        res.tv_sec += remaining;
-        return res;
-      }
+    if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
+      vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
+      return IOState::Done;
     }
 
-    res.tv_sec += g_tcpSendTimeout;
-    return res;
-  }
-
-  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
-  {
-    if (d_ds == nullptr) {
-      throw std::runtime_error("getBackendReadTTD() called without any backend selected");
-    }
-    if (d_ds->tcpSendTimeout == 0) {
-      return boost::none;
+    if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+      vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+      return IOState::Done;
     }
-
-    struct timeval res = now;
-    res.tv_sec += d_ds->tcpSendTimeout;
-
-    return res;
   }
 
-  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
-  {
-    if (maxConnectionDuration) {
-      time_t curtime = now.tv_sec;
-      unsigned int elapsed = 0;
-      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
-        elapsed = curtime - d_connectionStartTime.tv_sec;
-      }
-      if (elapsed >= maxConnectionDuration) {
-        return true;
-      }
-      d_remainingTime = maxConnectionDuration - elapsed;
+  if (state->d_queuedResponses.empty()) {
+    // DEBUG: cerr<<"no response remaining"<<endl;
+    if (state->d_isXFR) {
+      /* we should still be reading from the backend, and we don't want to read from the client */
+      state->d_state = IncomingTCPConnectionState::State::idle;
+      state->d_currentPos = 0;
+      // DEBUG: cerr<<"idling for XFR completion"<<endl;
+      return IOState::Done;
+    } else {
+      // DEBUG: cerr<<"reading new queries if any"<<endl;
+      state->resetForNewQuery();
+      return IOState::NeedRead;
     }
-
-    return false;
   }
-
-  void dump() const
-  {
-    static std::mutex s_mutex;
-
-    struct timeval now;
-    gettimeofday(&now, 0);
-
-    {
-      std::lock_guard<std::mutex> lock(s_mutex);
-      fprintf(stderr, "State is %p\n", this);
-      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
-      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
-      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
-      if (d_state > State::doingHandshake) {
-        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
-      }
-      if (d_state > State::readingQuerySize) {
-        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
-      }
-      if (d_state > State::readingQuerySize) {
-        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
-      }
-      if (d_state > State::readingQuery) {
-        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
-      }
-      if (d_state > State::sendingQueryToBackend) {
-        cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
-      }
-      if (d_state > State::readingResponseFromBackend) {
-        cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
-      }
-    }
+  else {
+    // DEBUG: cerr<<"queue size is "<<state->d_queuedResponses.size()<<endl;
+    TCPResponse resp = std::move(state->d_queuedResponses.front());
+    state->d_queuedResponses.pop_front();
+    state->d_state = IncomingTCPConnectionState::State::idle;
+    state->sendResponse(state, now, std::move(resp));
+    return IOState::NeedWrite;
   }
+}
 
-  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };
-
-  std::vector<uint8_t> d_buffer;
-  std::vector<uint8_t> d_responseBuffer;
-  TCPClientThreadData& d_threadData;
-  IDState d_ids;
-  ConnectionInfo d_ci;
-  TCPIOHandler d_handler;
-  std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
-  std::shared_ptr<DownstreamState> d_ds{nullptr};
-  dnsheader d_cleartextDH;
-  struct timeval d_connectionStartTime;
-  struct timeval d_handshakeDoneTime;
-  struct timeval d_firstQuerySizeReadTime;
-  struct timeval d_querySizeReadTime;
-  struct timeval d_queryReadTime;
-  struct timeval d_querySentTime;
-  struct timeval d_responseReadTime;
-  size_t d_currentPos{0};
-  size_t d_queriesCount{0};
-  unsigned int d_remainingTime{0};
-  uint16_t d_querySize{0};
-  uint16_t d_responseSize{0};
-  uint16_t d_downstreamFailures{0};
-  State d_state{State::doingHandshake};
-  IOState d_lastIOState{IOState::Done};
-  bool d_readingFirstQuery{true};
-  bool d_outstanding{false};
-  bool d_firstResponsePacket{true};
-  bool d_isXFR{false};
-  bool d_xfrStarted{false};
-  bool d_selfGeneratedResponse{false};
-  bool d_proxyProtocolPayloadAdded{false};
-  bool d_proxyProtocolPayloadHasTLV{false};
-};
-
-static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
-static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
-static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
-static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
-
-static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+void IncomingTCPConnectionState::resetForNewQuery()
 {
-  handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);
+  d_buffer.resize(sizeof(uint16_t));
+  d_currentPos = 0;
+  d_querySize = 0;
+  //d_responseSize = 0;
+  d_downstreamFailures = 0;
+  d_state = State::readingQuerySize;
+  d_lastIOState = IOState::Done;
+  d_selfGeneratedResponse = false;
+}
 
-  if (state->d_isXFR && state->d_downstreamConnection) {
-    /* we need to resume reading from the backend! */
-    state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+/* this version is called when the buffer has been set and the rules have been processed */
+void IncomingTCPConnectionState::sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
+{
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  // if we already reading a query (not the query size, mind you), or sending a response we need to either queue the response
+  // otherwise we can start sending it right away
+  if (state->d_state == IncomingTCPConnectionState::State::idle ||
+      state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
+
+    state->d_state = IncomingTCPConnectionState::State::sendingResponse;
+
+    uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size());
+    const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) };
+    /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+       that could occur if we had to deal with the size during the processing,
+       especially alignment issues */
+    response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2);
     state->d_currentPos = 0;
-    handleDownstreamIO(state, now);
-    return;
-  }
-
-  if (state->d_selfGeneratedResponse == false && state->d_ds) {
-    /* if we have no downstream server selected, this was a self-answered response
-       but cache hits have a selected server as well, so be careful */
-    struct timespec answertime;
-    gettime(&answertime);
-    double udiff = state->d_ids.sentTime.udiff();
-    g_rings.insertResponse(answertime, state->d_ci.remote, state->d_ids.qname, state->d_ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), state->d_cleartextDH, state->d_ds->remote);
-    vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", state->d_ds->remote.toStringWithPort(), state->d_ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff);
-  }
+    state->d_currentResponse = std::move(response);
 
-  switch (state->d_cleartextDH.rcode) {
-  case RCode::NXDomain:
-    ++g_stats.frontendNXDomain;
-    break;
-  case RCode::ServFail:
-    ++g_stats.servfailResponses;
-    ++g_stats.frontendServFail;
-    break;
-  case RCode::NoError:
-    ++g_stats.frontendNoError;
-    break;
-  }
-
-  if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
-    vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
-    return;
+    //IncomingTCPConnectionState::handleIO(state, now);
+     state->d_ioState->update(IOState::NeedWrite, handleIOCallback, state, getClientWriteTTD(now));
+    // DEBUG: cerr<<"updated IO state"<<endl;
   }
-
-  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
-    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
-    return;
+  else {
+    // queue response
+    state->d_queuedResponses.push_back(std::move(response));
+    // DEBUG: cerr<<"queueing response because state is "<<(int)state->d_state<<", queue size is now "<<state->d_queuedResponses.size()<<endl;
   }
-
-  state->resetForNewQuery();
-
-  handleIO(state, now);
-}
-
-static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
-{
-  state->d_state = IncomingTCPConnectionState::State::sendingResponse;
-  const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
-  /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
-     that could occur if we had to deal with the size during the processing,
-     especially alignment issues */
-  state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);
-
-  state->d_currentPos = 0;
-
-  handleIO(state, now);
 }
 
-static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+/* this version is called from the backend code when a new response has been received */
+void IncomingTCPConnectionState::handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
 {
-  if (state->d_responseSize < sizeof(dnsheader) || !state->d_ds) {
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  if (response.d_buffer.size() < sizeof(dnsheader)) {
+    // DEBUG: cerr<<"too small"<<endl;
     return;
   }
 
-  auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
+  uint16_t responseSize = response.d_buffer.size();
+  response.d_buffer.resize(responseSize + static_cast<size_t>(512));
+  size_t responseCapacity = response.d_buffer.size();
+  auto responseAsCharArray = reinterpret_cast<char*>(&response.d_buffer.at(0));
+
+  auto& ids = response.d_idstate;
+  // DEBUG: cerr<<"IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<<endl;
   unsigned int consumed;
-  if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
+  // DEBUG: cerr<<"about to match response for "<<ids.qname<<endl;
+  if (!responseContentMatches(responseAsCharArray, responseSize, ids.qname, ids.qtype, ids.qclass, response.d_ds->remote, consumed)) {
+    // DEBUG: cerr<<"content does not match"<<endl;
     return;
   }
-  state->d_firstResponsePacket = false;
-
-  if (state->d_outstanding) {
-    --state->d_ds->outstanding;
-    state->d_outstanding = false;
-  }
 
-  auto dh = reinterpret_cast<struct dnsheader*>(response);
+  auto dh = reinterpret_cast<struct dnsheader*>(responseAsCharArray);
   uint16_t addRoom = 0;
-  DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
+  DNSResponse dr = makeDNSResponseFromIDState(ids, dh, responseCapacity, responseSize, true);
   if (dr.dnsCryptQuery) {
     addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
   }
 
-  memcpy(&state->d_cleartextDH, dr.dh, sizeof(state->d_cleartextDH));
+  memcpy(&response.d_cleartextDH, dr.dh, sizeof(response.d_cleartextDH));
 
   std::vector<uint8_t> rewrittenResponse;
-  size_t responseSize = state->d_responseBuffer.size();
-  if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
+  if (!processResponse(&responseAsCharArray, &responseSize, &responseCapacity, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
+    // DEBUG: cerr<<"process said to drop it"<<endl;
     return;
   }
 
   if (!rewrittenResponse.empty()) {
     /* responseSize has been updated as well but we don't really care since it will match
        the capacity of rewrittenResponse anyway */
-    state->d_responseBuffer = std::move(rewrittenResponse);
-    state->d_responseSize = state->d_responseBuffer.size();
+    response.d_buffer = std::move(rewrittenResponse);
   } else {
     /* the size might have been updated (shrinked) if we removed the whole OPT RR, for example) */
-    state->d_responseBuffer.resize(state->d_responseSize);
+    response.d_buffer.resize(responseSize);
   }
 
   if (state->d_isXFR && !state->d_xfrStarted) {
@@ -801,67 +492,23 @@ static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, s
     state->d_xfrStarted = true;
     ++g_stats.responses;
     ++state->d_ci.cs->responses;
-    ++state->d_ds->responses;
+    ++response.d_ds->responses;
   }
 
   if (!state->d_isXFR) {
     ++g_stats.responses;
     ++state->d_ci.cs->responses;
-    ++state->d_ds->responses;
+    ++response.d_ds->responses;
   }
 
-  sendResponse(state, now);
+  sendResponse(state, now, std::move(response));
 }
 
-static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
-{
-  auto ds = state->d_ds;
-  state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
-  state->d_currentPos = 0;
-  state->d_firstResponsePacket = true;
-
-  if (state->d_xfrStarted) {
-    /* sorry, but we are not going to resume a XFR if we have already sent some packets
-       to the client */
-    return;
-  }
-
-  if (!state->d_downstreamConnection) {
-    if (state->d_downstreamFailures < state->d_ds->retries) {
-      try {
-        state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);
-      }
-      catch (const std::runtime_error& e) {
-        state->d_downstreamConnection.reset();
-      }
-    }
-
-    if (!state->d_downstreamConnection) {
-      ++ds->tcpGaveUp;
-      ++state->d_ci.cs->tcpGaveUp;
-      vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
-      return;
-    }
-
-    if (ds->useProxyProtocol && !state->d_proxyProtocolPayloadAdded) {
-      /* we know there is no TLV values to add, otherwise we would not have tried
-         to reuse the connection and d_proxyProtocolPayloadAdded would be true already */
-      addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, std::vector<ProxyProtocolValue>());
-      state->d_proxyProtocolPayloadAdded = true;
-    }
-  }
-
-  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", state->d_ids.qname.toLogString(), QType(state->d_ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName());
-
-  handleDownstreamIO(state, now);
-  return;
-}
-
-static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+static bool handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
 {
   if (state->d_querySize < sizeof(dnsheader)) {
     ++g_stats.nonCompliantQueries;
-    return;
+    return true;
   }
 
   state->d_readingFirstQuery = false;
@@ -900,21 +547,24 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, stru
   std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
   auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
   if (dnsCryptResponse) {
-    state->d_responseBuffer = std::move(*dnsCryptResponse);
-    state->d_responseSize = state->d_responseBuffer.size();
-    sendResponse(state, now);
-    return;
+    //state->d_responseBuffer = std::move(*dnsCryptResponse);
+    //state->d_responseSize = state->d_responseBuffer.size();
+    TCPResponse response;
+    response.d_buffer = std::move(*dnsCryptResponse);
+    state->d_state = IncomingTCPConnectionState::State::idle;
+    state->sendResponse(state, now, std::move(response));
+    return false;
   }
 
   const auto& dh = reinterpret_cast<dnsheader*>(query);
   if (!checkQueryHeaders(dh)) {
-    return;
+    return true;
   }
 
   uint16_t qtype, qclass;
   unsigned int consumed = 0;
   DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
-  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
+  DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
   dq.dnsCryptQuery = std::move(dnsCryptQuery);
   dq.sni = state->d_handler.getServerNameIndication();
 
@@ -923,27 +573,39 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, stru
     dq.skipCache = true;
   }
 
-  state->d_ds.reset();
-  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);
+  std::shared_ptr<DownstreamState> ds;
+  auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, ds);
 
   if (result == ProcessQueryResult::Drop) {
-    return;
+    return true;
   }
 
   if (result == ProcessQueryResult::SendAnswer) {
     state->d_selfGeneratedResponse = true;
     state->d_buffer.resize(dq.len);
-    state->d_responseBuffer = std::move(state->d_buffer);
-    state->d_responseSize = state->d_responseBuffer.size();
-    sendResponse(state, now);
-    return;
+    TCPResponse response;
+    response.d_buffer = std::move(state->d_buffer);
+    state->d_state = IncomingTCPConnectionState::State::idle;
+    state->sendResponse(state, now, std::move(response));
+    return false;
   }
 
-  if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
-    return;
+  if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
+    return true;
+  }
+
+#warning move this, we should just never read another question again on this client connection
+  if (state->d_xfrStarted) {
+    /* sorry, but we are not going to resume a XFR if we have already sent some packets
+       to the client */
+    return true;
   }
 
-  setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));
+  IDState ids;
+  // DEBUG: cerr<<"DQ has "<<(dq.qTag?" TAGS ": "NO TAGS")<<endl;
+  setIDStateFromDNSQuestion(ids, dq, std::move(qname));
+  // DEBUG: cerr<<"query IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<<endl;
+  ids.origID = ntohs(dh->id);
 
   const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
   /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
@@ -955,350 +617,271 @@ static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, stru
   dq.size = state->d_buffer.size();
   state->d_buffer.resize(dq.len);
 
-  if (state->d_ds->useProxyProtocol) {
+  bool proxyProtocolPayloadAdded = false;
+  std::string proxyProtocolPayload;
+
+  if (ds->useProxyProtocol) {
     /* if we ever sent a TLV over a connection, we can never go back */
     if (!state->d_proxyProtocolPayloadHasTLV) {
       state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
     }
 
-    if (state->d_downstreamConnection && !state->d_proxyProtocolPayloadHasTLV && state->d_downstreamConnection->matches(state->d_ds)) {
-      /* we have an existing connection, on which we already sent a Proxy Protocol header with no values
-         (in the previous query had TLV values we would have reset the connection afterwards),
-         so let's reuse it as long as we still don't have any values */
-      state->d_proxyProtocolPayloadAdded = false;
-    }
-    else {
-      state->d_downstreamConnection.reset();
-      addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector<ProxyProtocolValue>());
-      state->d_proxyProtocolPayloadAdded = true;
+    proxyProtocolPayload = getProxyProtocolPayload(dq);
+
+    if (state->d_proxyProtocolPayloadHasTLV) {
+      /* we will not be able to reuse an existing connection anyway so let's add the payload right now */
+      addProxyProtocol(state->d_buffer, proxyProtocolPayload);
+      proxyProtocolPayloadAdded = true;
     }
   }
 
-  sendQueryToBackend(state, now);
-}
-
-static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
-{
-  //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;
+  auto downstreamConnection = state->getDownstreamConnection(ds, now);
+  downstreamConnection->assignToClientConnection(state, state->d_isXFR);
 
-  if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
-    state->d_threadData.mplexer->removeReadFD(fd);
-    //cerr<<__func__<<": remove read FD "<<fd<<endl;
-    state->d_lastIOState = IOState::Done;
+  if (proxyProtocolPayloadAdded) {
+    downstreamConnection->setProxyProtocolPayloadAdded(true);
   }
-  else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
-    state->d_threadData.mplexer->removeWriteFD(fd);
-    //cerr<<__func__<<": remove write FD "<<fd<<endl;
-    state->d_lastIOState = IOState::Done;
+  else {
+    downstreamConnection->setProxyProtocolPayload(std::move(proxyProtocolPayload));
   }
 
-  if (iostate == IOState::NeedRead) {
-    if (state->d_lastIOState == IOState::NeedRead) {
-      if (ttd) {
-        /* let's update the TTD ! */
-        state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
-      }
-      return;
-    }
+  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName());
 
-    state->d_lastIOState = IOState::NeedRead;
-    //cerr<<__func__<<": add read FD "<<fd<<endl;
-    state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
-  }
-  else if (iostate == IOState::NeedWrite) {
-    if (state->d_lastIOState == IOState::NeedWrite) {
-      return;
-    }
+// DEBUG: cerr<<"about to be queued query IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<<endl;
+  downstreamConnection->queueQuery(TCPQuery(std::move(state->d_buffer), std::move(ids)), downstreamConnection);
 
-    state->d_lastIOState = IOState::NeedWrite;
-    //cerr<<__func__<<": add write FD "<<fd<<endl;
-    state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
-  }
-  else if (iostate == IOState::Done) {
-    state->d_lastIOState = IOState::Done;
-  }
+  //sendQueryToBackend(state, now);
+  // DEBUG: cerr<<"out of "<<__PRETTY_FUNCTION__<<endl;
+  return true;
 }
 
-static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
 {
-  if (state->d_downstreamConnection == nullptr) {
-    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+  auto conn = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+  if (fd != conn->d_ci.fd) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_ci.fd));
   }
 
-  int fd = state->d_downstreamConnection->getHandle();
-  IOState iostate = IOState::Done;
-  bool connectionDied = false;
+  struct timeval now;
+  gettimeofday(&now, 0);
+  handleIO(conn, now);
+}
 
-  try {
-    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
-      int socketFlags = 0;
-#ifdef MSG_FASTOPEN
-      if (state->d_downstreamConnection->isFastOpenEnabled()) {
-        socketFlags |= MSG_FASTOPEN;
-      }
-#endif /* MSG_FASTOPEN */
-
-      size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, socketFlags);
-      if (sent == state->d_buffer.size()) {
-        /* request sent ! */
-        state->d_downstreamConnection->incQueries();
-        state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
-        state->d_currentPos = 0;
-        state->d_querySentTime = now;
-        iostate = IOState::NeedRead;
-        if (!state->d_isXFR && !state->d_outstanding) {
-          /* don't bother with the outstanding count for XFR queries */
-          ++state->d_ds->outstanding;
-          state->d_outstanding = true;
-        }
-      }
-      else {
-        state->d_currentPos += sent;
-        iostate = IOState::NeedWrite;
-        /* disable fast open on partial write */
-        state->d_downstreamConnection->disableFastOpen();
-      }
-    }
+void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
+{
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
+  // even though the underlying socket is not ready, so we need to actually ask for the data first
+  bool wouldBlock = false;
+  IOState iostate = IOState::Done;
+  do {
+    iostate = IOState::Done;
+    IOStateGuard ioGuard(state->d_ioState);
 
-    if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
-      // then we need to allocate a new buffer (new because we might need to re-send the query if the
-      // backend dies on us
-      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
-      // should very likely be a TCPIOHandler d_downstreamHandler
-      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
-      if (iostate == IOState::Done) {
-        state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
-        state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
-        state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
-        state->d_currentPos = 0;
-      }
+    if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+      vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+      // will be handled by the ioGuard
+      //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+      return;
     }
 
-    if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
-      iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
-      if (iostate == IOState::Done) {
-        handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);
-
-        if (state->d_isXFR) {
-          /* Don't reuse the TCP connection after an {A,I}XFR */
-          /* but don't reset it either, we will need to read more messages */
-        }
-        else {
-          /* if we did not send a Proxy Protocol header, let's pool the connection */
-          if (state->d_ds && state->d_ds->useProxyProtocol == false) {
-            releaseDownstreamConnection(std::move(state->d_downstreamConnection));
-          }
-          else {
-            if (state->d_proxyProtocolPayloadHasTLV) {
-              /* sent a Proxy Protocol header with TLV values, we can't reuse it */
-              state->d_downstreamConnection.reset();
+    try {
+      if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
+        // DEBUG: cerr<<"doing handshake"<<endl;
+        iostate = state->d_handler.tryHandshake();
+        if (iostate == IOState::Done) {
+          // DEBUG: cerr<<"handshake done"<<endl;
+          if (state->d_handler.isTLS()) {
+            if (!state->d_handler.hasTLSSessionBeenResumed()) {
+              ++state->d_ci.cs->tlsNewSessions;
             }
             else {
-              /* if we did but there was no TLV values, let's try to reuse it but only
-                 for this incoming connection */
+              ++state->d_ci.cs->tlsResumptions;
+            }
+            if (state->d_handler.getResumedFromInactiveTicketKey()) {
+              ++state->d_ci.cs->tlsInactiveTicketKey;
+            }
+            if (state->d_handler.getUnknownTicketKey()) {
+              ++state->d_ci.cs->tlsUnknownTicketKey;
             }
           }
-        }
-        fd = -1;
 
-        state->d_responseReadTime = now;
-        try {
-          handleResponse(state, now);
+          state->d_handshakeDoneTime = now;
+          state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
         }
-        catch (const std::exception& e) {
-          vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", state->d_ds ? state->d_ds->getName() : "unknown", state->d_ci.remote.toStringWithPort(), e.what());
+        else {
+          wouldBlock = true;
         }
-        return;
       }
-    }
 
-    if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
-        state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
-        state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
-      vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
-    }
-  }
-  catch(const std::exception& e) {
-    /* most likely an EOF because the other end closed the connection,
-       but it might also be a real IO error or something else.
-       Let's just drop the connection
-    */
-    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
-    if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
-      ++state->d_ds->tcpDiedSendingQuery;
-    }
-    else {
-      ++state->d_ds->tcpDiedReadingResponse;
-    }
-
-    /* don't increase this counter when reusing connections */
-    if (state->d_downstreamConnection && state->d_downstreamConnection->isFresh()) {
-      ++state->d_downstreamFailures;
-    }
-
-    if (state->d_outstanding) {
-      state->d_outstanding = false;
+      if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
+        // DEBUG: cerr<<"reading query size"<<endl;
+        iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
+        if (iostate == IOState::Done) {
+          // DEBUG: cerr<<"query size received"<<endl;
+          state->d_state = IncomingTCPConnectionState::State::readingQuery;
+          state->d_querySizeReadTime = now;
+          if (state->d_queriesCount == 0) {
+            state->d_firstQuerySizeReadTime = now;
+          }
+          state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
+          if (state->d_querySize < sizeof(dnsheader)) {
+            /* go away */
+            // will be handled by the guard
+            //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+            return;
+          }
 
-      if (state->d_ds != nullptr) {
-        --state->d_ds->outstanding;
+          /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
+             or to add ECS without allocating a new buffer */
+          state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
+          state->d_currentPos = 0;
+        }
+        else {
+          wouldBlock = true;
+        }
       }
-    }
-    /* remove this FD from the IO multiplexer */
-    iostate = IOState::Done;
-    connectionDied = true;
-  }
 
-  if (iostate == IOState::Done) {
-    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
-  }
-  else {
-    handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));
-  }
-
-  if (connectionDied) {
-    state->d_downstreamConnection.reset();
-    sendQueryToBackend(state, now);
-  }
-}
-
-static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
-{
-  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
-  if (state->d_downstreamConnection == nullptr) {
-    throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
-  }
-  if (fd != state->d_downstreamConnection->getHandle()) {
-    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));
-  }
-
-  struct timeval now;
-  gettimeofday(&now, 0);
-  handleDownstreamIO(state, now);
-}
-
-static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
-{
-  int fd = state->d_ci.fd;
-  IOState iostate = IOState::Done;
-
-  if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
-    vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
-    handleNewIOState(state, IOState::Done, fd, handleIOCallback);
-    return;
-  }
-
-  try {
-    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
-      iostate = state->d_handler.tryHandshake();
-      if (iostate == IOState::Done) {
-        if (state->d_handler.isTLS()) {
-          if (!state->d_handler.hasTLSSessionBeenResumed()) {
-            ++state->d_ci.cs->tlsNewSessions;
+      if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+        // DEBUG: cerr<<"reading query"<<endl;
+        iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
+        if (iostate == IOState::Done) {
+          // DEBUG: cerr<<"query received"<<endl;
+          //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+          if (handleQuery(state, now)) {
+            // DEBUG: cerr<<"handle query returned true"<<endl;
+            // if the query has been passed to a backend, or dropped, we can start
+            // reading again, or sending queued responses
+            if (state->d_queuedResponses.empty()) {
+              state->resetForNewQuery();
+              // DEBUG: cerr<<__LINE__<<endl;
+              iostate = IOState::NeedRead;
+              //state->d_ioState->update(IOState::NeedRead, handleIOCallback, state, state->getClientReadTTD(now));
+              // DEBUG: cerr<<__LINE__<<endl;
+              //ioGuard.release();
+            }
+            else {
+              TCPResponse resp = std::move(state->d_queuedResponses.front());
+              state->d_queuedResponses.pop_front();
+              ioGuard.release();
+              state->sendResponse(state, now, std::move(resp));
+              return;
+            }
           }
           else {
-            ++state->d_ci.cs->tlsResumptions;
-          }
-          if (state->d_handler.getResumedFromInactiveTicketKey()) {
-            ++state->d_ci.cs->tlsInactiveTicketKey;
-          }
-          if (state->d_handler.getUnknownTicketKey()) {
-            ++state->d_ci.cs->tlsUnknownTicketKey;
+            /* otherwise the state should already be waiting for
+               the socket to be writable */
+            // DEBUG: cerr<<"should be waiting for writable socket"<<endl;
+            ioGuard.release();
+            return;
           }
         }
-
-        state->d_handshakeDoneTime = now;
-        state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+        else {
+          wouldBlock = true;
+        }
       }
-    }
 
-    if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
-      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
-      if (iostate == IOState::Done) {
-        state->d_state = IncomingTCPConnectionState::State::readingQuery;
-        state->d_querySizeReadTime = now;
-        if (state->d_queriesCount == 0) {
-          state->d_firstQuerySizeReadTime = now;
-        }
-        state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
-        if (state->d_querySize < sizeof(dnsheader)) {
-          /* go away */
-          handleNewIOState(state, IOState::Done, fd, handleIOCallback);
-          return;
+      if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+        // DEBUG: cerr<<"sending response"<<endl;
+        iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
+        if (iostate == IOState::Done) {
+          // DEBUG: cerr<<"response sent"<<endl;
+          iostate = handleResponseSent(state, now);
+        } else {
+          wouldBlock = true;
+          // DEBUG: cerr<<"partial write"<<endl;
         }
-
-        /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
-           or to add ECS without allocating a new buffer */
-        state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
-        state->d_currentPos = 0;
+        // DEBUG: cerr<<__LINE__<<endl;
+        //state->d_ioState->update(IOState::NeedRead, handleIOCallback, state, state->getClientReadTTD(now));
+        //// DEBUG: cerr<<__LINE__<<endl;
       }
-    }
 
-    if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
-      iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
-      if (iostate == IOState::Done) {
-        handleNewIOState(state, IOState::Done, fd, handleIOCallback);
-        handleQuery(state, now);
-        return;
+      if (state->d_state != IncomingTCPConnectionState::State::idle &&
+          state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
+          state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
+          state->d_state != IncomingTCPConnectionState::State::readingQuery &&
+          state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
+        vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
       }
     }
-
-    if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
-      iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
-      if (iostate == IOState::Done) {
-        handleResponseSent(state, now);
-        return;
+    catch(const std::exception& e) {
+      /* most likely an EOF because the other end closed the connection,
+         but it might also be a real IO error or something else.
+         Let's just drop the connection
+      */
+      if (state->d_state == IncomingTCPConnectionState::State::idle ||
+          state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
+          state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
+          state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+        ++state->d_ci.cs->tcpDiedReadingQuery;
+      }
+      else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+        ++state->d_ci.cs->tcpDiedSendingResponse;
       }
-    }
 
-    if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
-        state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
-        state->d_state != IncomingTCPConnectionState::State::readingQuery &&
-        state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
-      vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
-    }
-  }
-  catch(const std::exception& e) {
-    /* most likely an EOF because the other end closed the connection,
-       but it might also be a real IO error or something else.
-       Let's just drop the connection
-    */
-    if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
-        state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
-        state->d_state == IncomingTCPConnectionState::State::readingQuery) {
-      ++state->d_ci.cs->tcpDiedReadingQuery;
-    }
-    else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
-      ++state->d_ci.cs->tcpDiedSendingResponse;
+      if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
+        // DEBUG: cerr<<"Got an exception while handling TCP query: "<<e.what()<<endl;
+        vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+      }
+      else {
+        vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
+        // DEBUG: cerr<<"Closing TCP client connection: "<<e.what()<<endl;
+      }
+      /* remove this FD from the IO multiplexer */
+      iostate = IOState::Done;
     }
 
-    if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
-      vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+    if (iostate == IOState::Done) {
+      // DEBUG: cerr<<__LINE__<<endl;
+      state->d_ioState->update(iostate, handleIOCallback, state);
+      // DEBUG: cerr<<__LINE__<<endl;
     }
     else {
-      vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
+      // DEBUG: cerr<<__LINE__<<endl;
+      state->d_ioState->update(iostate, handleIOCallback, state, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
+      // DEBUG: cerr<<__LINE__<<endl;
     }
-    /* remove this FD from the IO multiplexer */
-    iostate = IOState::Done;
+    ioGuard.release();
+  }
+  while (state->d_state == IncomingTCPConnectionState::State::readingQuerySize && iostate == IOState::NeedRead && !wouldBlock);
+}
+
+void IncomingTCPConnectionState::notifyIOError(std::shared_ptr<IncomingTCPConnectionState>& state, IDState&& query, const struct timeval& now)
+{
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  if (d_isXFR) {
+    d_xfrDone = true;
   }
 
-  if (iostate == IOState::Done) {
-    handleNewIOState(state, iostate, fd, handleIOCallback);
+  if (d_state == State::sendingResponse) {
+    /* if we have responses to send, let's do that first */
+  }
+  else if (!d_queuedResponses.empty()) {
+    /* stop reading and send what we have */
+    TCPResponse resp = std::move(d_queuedResponses.front());
+    d_queuedResponses.pop_front();
+    sendResponse(state, now, std::move(resp));
   }
   else {
-    handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
+    // the backend code already tried to reconnect if it was possible
+    d_lastIOState = IOState::Done;
+    d_ioState->reset();
   }
+
+  // DEBUG: cerr<<"out "<<__PRETTY_FUNCTION__<<endl;
 }
 
-static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+void IncomingTCPConnectionState::handleXFRResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
 {
-  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
-  if (fd != state->d_ci.fd) {
-    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
-  }
-  struct timeval now;
-  gettimeofday(&now, 0);
+  sendResponse(state, now, std::move(response));
+}
 
-  handleIO(state, now);
+void IncomingTCPConnectionState::handleTimeout(bool write)
+{
+  // DEBUG: cerr<<"client timeout"<<endl;
+  ++d_ci.cs->tcpClientTimeouts;
+  d_lastIOState = IOState::Done;
+  d_ioState->reset();
 }
 
 static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
@@ -1333,7 +916,7 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param
     /* let's update the remaining time */
     state->d_remainingTime = g_maxTCPConnectionDuration;
 
-    handleIO(state, now);
+    IncomingTCPConnectionState::handleIO(state, now);
   }
   catch(...) {
     delete citmp;
@@ -1342,7 +925,7 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param
   }
 }
 
-void tcpClientThread(int pipefd)
+static void tcpClientThread(int pipefd)
 {
   /* we get launched with a pipe on which we receive file descriptors from clients that we own
      from that point on */
@@ -1361,42 +944,50 @@ void tcpClientThread(int pipefd)
     data.mplexer->run(&now);
 
     if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
-      cleanupClosedTCPConnections();
+      DownstreamConnectionsManager::cleanupClosedTCPConnections();
       lastTCPCleanup = now.tv_sec;
     }
 
     if (now.tv_sec > lastTimeoutScan) {
       lastTimeoutScan = now.tv_sec;
       auto expiredReadConns = data.mplexer->getTimeouts(now, false);
-      for(const auto& conn : expiredReadConns) {
-        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
-        if (conn.first == state->d_ci.fd) {
-          vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
-          ++state->d_ci.cs->tcpClientTimeouts;
+      for (const auto& cbData : expiredReadConns) {
+        if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
+          auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
+          if (cbData.first == state->d_ci.fd) {
+            vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+            state->handleTimeout(false);
+          }
         }
-        else if (state->d_ds) {
-          vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
-          ++state->d_ci.cs->tcpDownstreamTimeouts;
-          ++state->d_ds->tcpReadTimeouts;
+        else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
+          auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
+          vinfolog("Timeout (read) from remote backend %s", conn->getBackendName());
+          conn->handleTimeout(now, false);
         }
-        data.mplexer->removeReadFD(conn.first);
-        state->d_lastIOState = IOState::Done;
       }
 
       auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
-      for(const auto& conn : expiredWriteConns) {
-        auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
-        if (conn.first == state->d_ci.fd) {
-          vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
-          ++state->d_ci.cs->tcpClientTimeouts;
+      for (const auto& cbData : expiredWriteConns) {
+        if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
+          auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
+          if (cbData.first == state->d_ci.fd) {
+            vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+            state->handleTimeout(true);
+          }
+        }
+        else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
+          auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
+          vinfolog("Timeout (write) from remote backend %s", conn->getBackendName());
+          conn->handleTimeout(now, true);
         }
-        else if (state->d_ds) {
-          vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
-          ++state->d_ci.cs->tcpDownstreamTimeouts;
-          ++state->d_ds->tcpWriteTimeouts;
+#if 0
+        try {
+          data.mplexer->removeWriteFD(cbData.first);
+        }
+        catch (const FDMultiplexerException& fde) {
+          warnlog("Exception while removing a socket (%d) after a write timeout: %s", cbData.first, fde.what());
         }
-        data.mplexer->removeWriteFD(conn.first);
-        state->d_lastIOState = IOState::Done;
+#endif
       }
     }
   }
index 55c0868647b1cf02fe4db9c04d3fb584ca2d6f79..0dee6cc8df28633fd562ea92cf2ade619d249df1 100644 (file)
@@ -465,13 +465,60 @@ struct ClientState;
 struct IDState
 {
   IDState(): sentTime(true), delayMsec(0), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;}
-  IDState(const IDState& orig): origRemote(orig.origRemote), origDest(orig.origDest), age(orig.age)
+  IDState(const IDState& orig) = delete;
+  IDState(IDState&& rhs): origRemote(rhs.origRemote), origDest(rhs.origDest), sentTime(rhs.sentTime), qname(std::move(rhs.qname)), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), subnet(rhs.subnet), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), age(rhs.age), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), origFD(rhs.origFD), delayMsec(rhs.delayMsec), tempFailureTTL(rhs.tempFailureTTL), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope)
   {
-    usageIndicator.store(orig.usageIndicator.load());
-    origFD = orig.origFD;
-    origID = orig.origID;
-    delayMsec = orig.delayMsec;
-    tempFailureTTL = orig.tempFailureTTL;
+    if (rhs.isInUse()) {
+      throw std::runtime_error("Trying to move an in-use IDState");
+    }
+
+#ifdef HAVE_PROTOBUF
+    uniqueId = std::move(rhs.uniqueId);
+#endif
+  }
+
+  IDState& operator=(IDState&& rhs)
+  {
+    if (isInUse()) {
+      throw std::runtime_error("Trying to overwrite an in-use IDState");
+    }
+
+    if (rhs.isInUse()) {
+      throw std::runtime_error("Trying to move an in-use IDState");
+    }
+
+    origRemote = rhs.origRemote;
+    origDest = rhs.origDest;
+    sentTime = rhs.sentTime;
+    qname = std::move(rhs.qname);
+    dnsCryptQuery = std::move(rhs.dnsCryptQuery);
+    subnet = rhs.subnet;
+    packetCache = std::move(rhs.packetCache);
+    qTag = std::move(rhs.qTag);
+    cs = rhs.cs;
+    du = std::move(rhs.du);
+    cacheKey = rhs.cacheKey;
+    cacheKeyNoECS = rhs.cacheKeyNoECS;
+    age = rhs.age;
+    qtype = rhs.qtype;
+    qclass = rhs.qclass;
+    origID = rhs.origID;
+    origFlags = rhs.origFlags;
+    origFD = rhs.origFD;
+    delayMsec = rhs.delayMsec;
+    tempFailureTTL = rhs.tempFailureTTL;
+    ednsAdded = rhs.ednsAdded;
+    ecsAdded = rhs.ecsAdded;
+    skipCache = rhs.skipCache;
+    destHarvested = rhs.destHarvested;
+    dnssecOK = rhs.dnssecOK;
+    useZeroScope = rhs.useZeroScope;
+
+#ifdef HAVE_PROTOBUF
+    uniqueId = std::move(rhs.uniqueId);
+#endif
+
+    return *this;
   }
 
   static const int64_t unusedIndicator = -1;
@@ -563,14 +610,14 @@ struct IDState
   uint16_t origID;                                            // 2
   uint16_t origFlags;                                         // 2
   int origFD{-1};
-  int delayMsec;
+  int delayMsec{0};
   boost::optional<uint32_t> tempFailureTTL;
   bool ednsAdded{false};
   bool ecsAdded{false};
   bool skipCache{false};
   bool destHarvested{false}; // if true, origDest holds the original dest addr, otherwise the listening addr
   bool dnssecOK{false};
-  bool useZeroScope;
+  bool useZeroScope{false};
 };
 
 typedef std::unordered_map<string, unsigned int> QueryCountRecords;
index 1bea3b333010735a8452785d8bda01b8e1e0ce34..ca5a07939583c7ba067e8d7182d27f387256d2f4 100644 (file)
@@ -159,6 +159,8 @@ dnsdist_SOURCES = \
        dnsdist-snmp.cc dnsdist-snmp.hh \
        dnsdist-systemd.cc dnsdist-systemd.hh \
        dnsdist-tcp.cc \
+       dnsdist-tcp-downstream.cc dnsdist-tcp-downstream.hh \
+       dnsdist-tcp-upstream.hh \
        dnsdist-web.cc dnsdist-web.hh \
        dnsdist-xpf.cc dnsdist-xpf.hh \
        dnsdist.cc dnsdist.hh \
@@ -200,6 +202,7 @@ dnsdist_SOURCES = \
        statnode.cc statnode.hh \
        svc-records.cc svc-records.hh \
        tcpiohandler.cc tcpiohandler.hh \
+       tcpiohandler-mplexer.hh \
        threadname.hh threadname.cc \
        uuid-utils.hh uuid-utils.cc \
        views.hh \
index 38398e28eb11406b3128e5d2d57481d994300943..2eb19a8c59d343105f130dfc64443803e0f47454 100644 (file)
@@ -150,7 +150,7 @@ void DownstreamState::setWeight(int newWeight)
   }
 }
 
-DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, const std::string& sourceItfName_, size_t numberOfSockets, bool connect=true): sourceItfName(sourceItfName_), remote(remote_), sourceAddr(sourceAddr_), sourceItf(sourceItf_), name(remote_.toStringWithPort()), nameWithAddr(remote_.toStringWithPort())
+DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, const std::string& sourceItfName_, size_t numberOfSockets, bool connect=true): sourceItfName(sourceItfName_), remote(remote_), idStates(g_maxOutstanding), sourceAddr(sourceAddr_), sourceItf(sourceItf_), name(remote_.toStringWithPort()), nameWithAddr(remote_.toStringWithPort())
 {
   id = getUniqueID();
   threadStarted.clear();
@@ -164,7 +164,6 @@ DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress
 
   if (connect && !IsAnyAddress(remote)) {
     reconnect();
-    idStates.resize(g_maxOutstanding);
     sw.start();
   }
 }
index 083b0d345af3cddc890182f152815f6fd0b6460b..e7773a2003fb69c9d9619c6fefe905d0190b2d98 100644 (file)
 
 #include "dnsdist-proxy-protocol.hh"
 
-bool addProxyProtocol(DNSQuestion& dq)
+std::string getProxyProtocolPayload(const DNSQuestion& dq)
+{
+  return makeProxyHeader(dq.tcp, *dq.remote, *dq.local, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector<ProxyProtocolValue>());
+}
+
+bool addProxyProtocol(DNSQuestion& dq, const std::string& payload)
 {
-  auto payload = makeProxyHeader(dq.tcp, *dq.remote, *dq.local, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector<ProxyProtocolValue>());
   if ((dq.size - dq.len) < payload.size()) {
     return false;
   }
@@ -36,10 +40,14 @@ bool addProxyProtocol(DNSQuestion& dq)
   return true;
 }
 
-bool addProxyProtocol(std::vector<uint8_t>& buffer, bool tcp, const ComboAddress& source, const ComboAddress& destination, const std::vector<ProxyProtocolValue>& values)
+bool addProxyProtocol(DNSQuestion& dq)
 {
-  auto payload = makeProxyHeader(tcp, source, destination, values);
+  auto payload = getProxyProtocolPayload(dq);
+  return addProxyProtocol(dq, payload);
+}
 
+bool addProxyProtocol(std::vector<uint8_t>& buffer, const std::string& payload)
+{
   auto previousSize = buffer.size();
   if (payload.size() > (std::numeric_limits<size_t>::max() - previousSize)) {
     return false;
@@ -51,3 +59,9 @@ bool addProxyProtocol(std::vector<uint8_t>& buffer, bool tcp, const ComboAddress
 
   return true;
 }
+
+bool addProxyProtocol(std::vector<uint8_t>& buffer, bool tcp, const ComboAddress& source, const ComboAddress& destination, const std::vector<ProxyProtocolValue>& values)
+{
+  auto payload = makeProxyHeader(tcp, source, destination, values);
+  return addProxyProtocol(buffer, payload);
+}
index 433a7d2394e931242163f1afc5d74d30b9a542be..a218a403d411a12261351e193a281a4ea95a5694 100644 (file)
@@ -23,5 +23,9 @@
 
 #include "dnsdist.hh"
 
+std::string getProxyProtocolPayload(const DNSQuestion& dq);
+
 bool addProxyProtocol(DNSQuestion& dq);
+bool addProxyProtocol(DNSQuestion& dq, const std::string& payload);
+bool addProxyProtocol(std::vector<uint8_t>& buffer, const std::string& payload);
 bool addProxyProtocol(std::vector<uint8_t>& buffer, bool tcp, const ComboAddress& source, const ComboAddress& destination, const std::vector<ProxyProtocolValue>& values);
diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc
new file mode 100644 (file)
index 0000000..86dee17
--- /dev/null
@@ -0,0 +1,469 @@
+
+#include "dnsdist-tcp-downstream.hh"
+#include "dnsdist-tcp-upstream.hh"
+
+const uint16_t TCPConnectionToBackend::s_xfrID = 0;
+
+void TCPConnectionToBackend::assignToClientConnection(std::shared_ptr<IncomingTCPConnectionState>& clientConn, bool isXFR)
+{
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  if (isXFR) {
+    d_usedForXFR = true;
+  }
+
+  d_clientConn = clientConn;
+  d_ioState = make_unique<IOStateHandler>(clientConn->getIOMPlexer(), d_socket->getHandle());
+}
+
+IOState TCPConnectionToBackend::sendNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
+{
+  conn->d_currentQuery = std::move(conn->d_pendingQueries.front());
+  conn->d_pendingQueries.pop_front();
+  conn->d_state = State::sendingQueryToBackend;
+  return IOState::NeedWrite;
+}
+
+void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now)
+{
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  if (conn->d_socket == nullptr) {
+    throw std::runtime_error("No downstream socket in " + std::string(__PRETTY_FUNCTION__) + "!");
+  }
+
+  bool connectionDied = false;
+  IOState iostate = IOState::Done;
+  IOStateGuard ioGuard(conn->d_ioState);
+  int fd = conn->d_socket->getHandle();
+
+  try {
+    if (conn->d_state == State::sendingQueryToBackend) {
+      // DEBUG: cerr<<"sending query to backend over FD "<<fd<<endl;
+      int socketFlags = 0;
+#ifdef MSG_FASTOPEN
+      if (conn->isFastOpenEnabled()) {
+        socketFlags |= MSG_FASTOPEN;
+      }
+#endif /* MSG_FASTOPEN */
+
+      size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&conn->d_currentQuery.d_buffer.at(conn->d_currentPos)), conn->d_currentQuery.d_buffer.size() - conn->d_currentPos, &conn->d_ds->remote, &conn->d_ds->sourceAddr, conn->d_ds->sourceItf, socketFlags);
+      if (sent == conn->d_currentQuery.d_buffer.size()) {
+        // DEBUG: cerr<<"query sent to backend"<<endl;
+        /* request sent ! */
+        conn->incQueries();
+        conn->d_currentPos = 0;
+        //conn->d_currentQuery.d_querySentTime = now;
+        // DEBUG: cerr<<"adding a pending response for ID "<<conn->d_currentQuery.d_idstate.origID<<" and QNAME "<<conn->d_currentQuery.d_idstate.qname<<endl;
+        // DEBUG: cerr<<"IDS has "<<(conn->d_currentQuery.d_idstate.qTag?"tags":"no tags")<<endl;
+        conn->d_pendingResponses[conn->d_currentQuery.d_idstate.origID] = std::move(conn->d_currentQuery);
+        conn->d_currentQuery.d_buffer.clear();
+#if 0
+        if (!conn->d_usedForXFR) {
+          /* don't bother with the outstanding count for XFR queries */
+          ++conn->d_ds->outstanding;
+          ++conn->d_outstanding;
+        }
+#endif
+
+        if (conn->d_pendingQueries.empty()) {
+          conn->d_state = State::readingResponseSizeFromBackend;
+          conn->d_currentPos = 0;
+          conn->d_responseBuffer.resize(sizeof(uint16_t));
+          iostate = IOState::NeedRead;
+        }
+        else {
+          iostate = sendNextQuery(conn);
+        }
+      }
+      else {
+        conn->d_currentPos += sent;
+        iostate = IOState::NeedWrite;
+        /* disable fast open on partial write */
+        conn->disableFastOpen();
+      }
+    }
+
+    if (conn->d_state == State::readingResponseSizeFromBackend) {
+      // DEBUG: cerr<<"reading response size from backend"<<endl;
+      // then we need to allocate a new buffer (new because we might need to re-send the query if the
+      // backend dies on us)
+      // We also might need to read and send to the client more than one response in case of XFR (yeah!)
+      // should very likely be a TCPIOHandler d_downstreamHandler
+      conn->d_responseBuffer.resize(sizeof(uint16_t));
+      iostate = tryRead(fd, conn->d_responseBuffer, conn->d_currentPos, sizeof(uint16_t) - conn->d_currentPos);
+      if (iostate == IOState::Done) {
+        // DEBUG: cerr<<"got response size from backend"<<endl;
+        conn->d_state = State::readingResponseFromBackend;
+        conn->d_responseSize = conn->d_responseBuffer.at(0) * 256 + conn->d_responseBuffer.at(1);
+        conn->d_responseBuffer.reserve(conn->d_responseSize + /* we will need to prepend the size later */ 2);
+        conn->d_responseBuffer.resize(conn->d_responseSize);
+        conn->d_currentPos = 0;
+      }
+    }
+
+    if (conn->d_state == State::readingResponseFromBackend) {
+      // DEBUG: cerr<<"reading response from backend"<<endl;
+      iostate = tryRead(fd, conn->d_responseBuffer, conn->d_currentPos, conn->d_responseSize - conn->d_currentPos);
+      if (iostate == IOState::Done) {
+        // DEBUG: cerr<<"got response from backend"<<endl;
+        //conn->d_responseReadTime = now;
+        try {
+          iostate = conn->handleResponse(now);
+        }
+        catch (const std::exception& e) {
+          vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", conn->d_ds ? conn->d_ds->getName() : "unknown", conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what());
+        }
+        //return;
+      }
+    }
+
+    if (conn->d_state != State::idle &&
+        conn->d_state != State::sendingQueryToBackend &&
+        conn->d_state != State::readingResponseSizeFromBackend &&
+        conn->d_state != State::readingResponseFromBackend) {
+      vinfolog("Unexpected state %d in TCPConnectionToBackend::handleIO", static_cast<int>(conn->d_state));
+    }
+  }
+  catch (const std::exception& e) {
+    /* most likely an EOF because the other end closed the connection,
+       but it might also be a real IO error or something else.
+       Let's just drop the connection
+    */
+    vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_ioState->getState() == IOState::NeedRead ? "reading from" : "writing to"), conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what());
+    if (conn->d_state == State::sendingQueryToBackend) {
+      ++conn->d_ds->tcpDiedSendingQuery;
+    }
+    else {
+      ++conn->d_ds->tcpDiedReadingResponse;
+    }
+
+    /* don't increase this counter when reusing connections */
+    if (conn->d_fresh) {
+      ++conn->d_downstreamFailures;
+    }
+
+#if 0
+    if (conn->d_outstanding) {
+      conn->d_outstanding = false;
+
+      if (conn->d_ds != nullptr) {
+        --conn->d_ds->outstanding;
+      }
+    }
+#endif
+    /* remove this FD from the IO multiplexer */
+    iostate = IOState::Done;
+    connectionDied = true;
+  }
+
+  if (connectionDied) {
+    bool reconnected = false;
+    // DEBUG: cerr<<"connection died, number of failures is "<<conn->d_downstreamFailures<<", retries is "<<conn->d_ds->retries<<endl;
+
+    if ((!conn->d_usedForXFR || conn->d_queries == 0) && conn->d_downstreamFailures < conn->d_ds->retries) {
+      // DEBUG: cerr<<"reconnecting"<<endl;
+      conn->d_ioState->reset();
+      ioGuard.release();
+
+      if (conn->reconnect()) {
+        // DEBUG: cerr<<"reconnected"<<endl;
+
+        conn->d_ioState = make_unique<IOStateHandler>(conn->d_clientConn->getIOMPlexer(), conn->d_socket->getHandle());
+        // DEBUG: cerr<<"new state"<<endl;
+
+        for (auto& pending : conn->d_pendingResponses) {
+          conn->d_pendingQueries.push_back(std::move(pending.second));
+        }
+        conn->d_pendingResponses.clear();
+        conn->d_currentPos = 0;
+
+        if (conn->d_state == State::doingHandshake ||
+            conn->d_state == State::sendingQueryToBackend) {
+          iostate = IOState::NeedWrite;
+          // resume sending query
+        }
+        else {
+          // DEBUG: cerr<<"sending next query"<<endl;
+          iostate = sendNextQuery(conn);
+          // DEBUG: cerr<<"after call to sendNextQuery"<<endl;
+        }
+
+        if (!conn->d_proxyProtocolPayloadAdded && !conn->d_proxyProtocolPayload.empty()) {
+          conn->d_currentQuery.d_buffer.insert(conn->d_currentQuery.d_buffer.begin(), conn->d_proxyProtocolPayload.begin(), conn->d_proxyProtocolPayload.end());
+          conn->d_proxyProtocolPayloadAdded = true;
+        }
+
+        reconnected = true;
+      }
+    }
+
+    if (!reconnected) {
+      /* reconnect failed, we give up */
+      conn->d_connectionDied = true;
+      conn->notifyAllQueriesFailed(now);
+    }
+  }
+
+  if (iostate == IOState::Done) {
+    // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<", done"<<endl;
+    conn->d_ioState->update(iostate, handleIOCallback, conn);
+  }
+  else {
+    // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<", updating to "<<(int)iostate<<endl;
+    conn->d_ioState->update(iostate, handleIOCallback, conn, iostate == IOState::NeedRead ? conn->getBackendReadTTD(now) : conn->getBackendWriteTTD(now));
+  }
+  ioGuard.release();
+
+}
+
+void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+  auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param);
+  if (fd != conn->getHandle()) {
+    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle()));
+  }
+
+  struct timeval now;
+  gettimeofday(&now, 0);
+  handleIO(conn, now);
+}
+
+void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptr<TCPConnectionToBackend>& sharedSelf)
+{
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  // DEBUG: cerr<<"IDS has "<<(query.d_idstate.qTag?"tags":"no tags")<<endl;
+  if (d_ioState == nullptr) {
+    throw std::runtime_error("Trying to queue a query to a TCP connection that has no incoming client connection assigned");
+  }
+
+  // if we are not already sending a query or in the middle of reading a response (so idle or doingHandshake),
+  // start sending the query
+  if (d_state == State::idle || d_state == State::waitingForResponseFromBackend) {
+    d_state = State::sendingQueryToBackend;
+    d_currentQuery = std::move(query);
+    // DEBUG: cerr<<"need write"<<endl;
+
+    struct timeval now;
+    gettimeofday(&now, 0);
+
+    d_ioState->update(IOState::NeedWrite, handleIOCallback, sharedSelf, getBackendWriteTTD(now));
+  }
+  else {
+    // store query in the list of queries to send
+    d_pendingQueries.push_back(std::move(query));
+  }
+  // DEBUG: cerr<<"out of "<<__PRETTY_FUNCTION__<<endl;
+}
+
+bool TCPConnectionToBackend::reconnect()
+{
+  std::unique_ptr<Socket> result;
+
+  if (d_socket) {
+    // DEBUG: cerr<<"closing socket "<<d_socket->getHandle()<<endl;
+    shutdown(d_socket->getHandle(), SHUT_RDWR);
+    d_socket.reset();
+    d_ioState.reset();
+    --d_ds->tcpCurrentConnections;
+  }
+
+  do {
+    vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures);
+    try {
+      result = std::unique_ptr<Socket>(new Socket(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0));
+      // DEBUG: cerr<<"result of connect is "<<result->getHandle()<<endl;
+      if (!IsAnyAddress(d_ds->sourceAddr)) {
+        SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
+#ifdef IP_BIND_ADDRESS_NO_PORT
+        if (d_ds->ipBindAddrNoPort) {
+          SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+        }
+#endif
+#ifdef SO_BINDTODEVICE
+        if (!d_ds->sourceItfName.empty()) {
+          int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length());
+          if (res != 0) {
+            vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror());
+          }
+        }
+#endif
+        result->bind(d_ds->sourceAddr, false);
+      }
+      result->setNonBlocking();
+#ifdef MSG_FASTOPEN
+      if (!d_ds->tcpFastOpen || !isFastOpenEnabled()) {
+        SConnectWithTimeout(result->getHandle(), d_ds->remote, /* no timeout, we will handle it ourselves */ 0);
+      }
+#else
+      SConnectWithTimeout(result->getHandle(), d_ds->remote, /* no timeout, we will handle it ourselves */ 0);
+#endif /* MSG_FASTOPEN */
+
+      d_socket = std::move(result);
+      // DEBUG: cerr<<"connected new socket "<<d_socket->getHandle()<<endl;
+      ++d_ds->tcpCurrentConnections;
+      return true;
+    }
+    catch(const std::runtime_error& e) {
+      vinfolog("Connection to downstream server %s failed: %s", d_ds->getName(), e.what());
+      d_downstreamFailures++;
+      if (d_downstreamFailures > d_ds->retries) {
+        throw;
+      }
+    }
+  }
+  while (d_downstreamFailures <= d_ds->retries);
+
+  return false;
+}
+
+void TCPConnectionToBackend::handleTimeout(const struct timeval& now, bool write)
+{
+  if (write) {
+    ++d_ds->tcpWriteTimeouts;
+  }
+  else {
+    ++d_ds->tcpReadTimeouts;
+  }
+
+  if (d_ioState) {
+    d_ioState->reset();
+  }
+
+  notifyAllQueriesFailed(now, true);
+}
+
+void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, bool timeout)
+{
+  d_connectionDied = true;
+  //auto clientConn = d_clientConn.lock();
+  //if (!clientConn) {
+  //  d_clientConn.reset();
+  //  return;
+  //}
+  auto& clientConn = d_clientConn;
+  if (!clientConn->active()) {
+    // a client timeout occured, or something like that */
+    d_connectionDied = true;
+    d_clientConn.reset();
+    return;
+  }
+
+  if (timeout) {
+    ++clientConn->d_ci.cs->tcpDownstreamTimeouts;
+  }
+
+  if (d_state == State::doingHandshake || d_state == State::sendingQueryToBackend) {
+    clientConn->notifyIOError(clientConn, std::move(d_currentQuery.d_idstate), now);
+  }
+
+  for (auto& query : d_pendingQueries) {
+    clientConn->notifyIOError(clientConn, std::move(query.d_idstate), now);
+  }
+
+  for (auto& response : d_pendingResponses) {
+    clientConn->notifyIOError(clientConn, std::move(response.second.d_idstate), now);
+  }
+
+  d_pendingQueries.clear();
+  d_pendingResponses.clear();
+
+  d_clientConn.reset();
+}
+
+IOState TCPConnectionToBackend::handleResponse(const struct timeval& now)
+{
+  // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  //auto clientConn = d_clientConn.lock();
+  //if (!clientConn) {
+  //d_clientConn.reset();
+  //  d_connectionDied = true;
+  //  // DEBUG: cerr<<"connection to client died, bye bye"<<endl;
+  //  return IOState::Done;
+  //}
+
+  auto& clientConn = d_clientConn;
+  if (!clientConn->active()) {
+    // DEBUG: cerr<<"client is not active"<<endl;
+    // a client timeout occured, or something like that */
+    d_connectionDied = true;
+    d_clientConn.reset();
+    return IOState::Done;
+  }
+
+  if (d_usedForXFR) {
+    // DEBUG: cerr<<"XFR!"<<endl;
+    TCPResponse response;
+    response.d_buffer = std::move(d_responseBuffer);
+    response.d_ds = d_ds;
+    clientConn->handleXFRResponse(clientConn, now, std::move(response));
+    d_state = State::readingResponseSizeFromBackend;
+    d_currentPos = 0;
+    d_responseBuffer.resize(sizeof(uint16_t));
+    return IOState::NeedRead;
+    // get ready to read the next packet, if any
+  }
+  else {
+    // DEBUG: cerr<<"not XFR, phew"<<endl;
+    uint16_t queryId = 0;
+    try {
+      queryId = getQueryIdFromResponse();
+    }
+    catch (const std::exception& e) {
+      notifyAllQueriesFailed(now);
+      throw;
+    }
+
+    auto it = d_pendingResponses.find(queryId);
+    if (it == d_pendingResponses.end()) {
+      // DEBUG: cerr<<"could not found any corresponding query for ID "<<queryId<<endl;
+      notifyAllQueriesFailed(now);
+      return IOState::Done;
+    }
+    auto ids = std::move(it->second.d_idstate);
+    // DEBUG: cerr<<"IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<<endl;
+    // DEBUG: cerr<<"passing response to client connection for "<<ids.qname<<endl;
+    clientConn->handleResponse(clientConn, now, TCPResponse(std::move(d_responseBuffer), std::move(ids), d_ds));
+    d_pendingResponses.erase(it);
+
+    if (!d_pendingQueries.empty()) {
+      // DEBUG: cerr<<"still have some queries to send"<<endl;
+      d_state = State::sendingQueryToBackend;
+      d_currentQuery = std::move(d_pendingQueries.front());
+      d_pendingQueries.pop_front();
+      return IOState::NeedWrite;
+    }
+    else if (!d_pendingResponses.empty()) {
+      // DEBUG: cerr<<"still have some responses to read"<<endl;
+      d_state = State::readingResponseSizeFromBackend;
+      d_currentPos = 0;
+      d_responseBuffer.resize(sizeof(uint16_t));
+      return IOState::NeedRead;
+    }
+    else {
+      // DEBUG: cerr<<"nothing to do, phewwwww"<<endl;
+      d_state = State::idle;
+      d_clientConn.reset();
+      return IOState::Done;
+    }
+  }
+}
+
+uint16_t TCPConnectionToBackend::getQueryIdFromResponse()
+{
+  if (d_responseBuffer.size() < sizeof(dnsheader)) {
+    throw std::runtime_error("Unable to get query ID in a too small (" + std::to_string(d_responseBuffer.size()) + ") response from " + d_ds->getNameWithAddr());
+  }
+
+  dnsheader dh;
+  memcpy(&dh, &d_responseBuffer.at(0), sizeof(dh));
+  return ntohs(dh.id);
+}
+
+void TCPConnectionToBackend::setProxyProtocolPayload(std::string&& payload)
+{
+  d_proxyProtocolPayload = std::move(payload);
+}
+
+void TCPConnectionToBackend::setProxyProtocolPayloadAdded(bool added)
+{
+  d_proxyProtocolPayloadAdded = added;
+}
diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh
new file mode 100644 (file)
index 0000000..9d24de5
--- /dev/null
@@ -0,0 +1,212 @@
+#pragma once
+
+#include <queue>
+
+#include "sstuff.hh"
+#include "tcpiohandler-mplexer.hh"
+#include "dnsdist.hh"
+
+struct TCPQuery
+{
+  TCPQuery()
+  {
+  }
+
+  TCPQuery(std::vector<uint8_t>&& buffer, IDState&& state): d_idstate(std::move(state)), d_buffer(std::move(buffer))
+  {
+  }
+
+  IDState d_idstate;
+  std::vector<uint8_t> d_buffer;
+};
+
+struct TCPResponse : public TCPQuery
+{
+  TCPResponse()
+  {
+  }
+
+  TCPResponse(std::vector<uint8_t>&& buffer, IDState&& state, std::shared_ptr<DownstreamState> ds): TCPQuery(std::move(buffer), std::move(state)), d_ds(ds)
+  {
+  }
+
+  std::shared_ptr<DownstreamState> d_ds{nullptr};
+  dnsheader d_cleartextDH;
+  bool d_selfGenerated{false};
+};
+
+class IncomingTCPConnectionState;
+
+class TCPConnectionToBackend
+{
+public:
+  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, const struct timeval& now): d_responseBuffer(s_maxPacketCacheEntrySize), d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen)
+  {
+    reconnect();
+  }
+
+  ~TCPConnectionToBackend()
+  {
+    if (d_ds && d_socket) {
+      --d_ds->tcpCurrentConnections;
+      struct timeval now;
+      gettimeofday(&now, nullptr);
+
+      auto diff = now - d_connectionStartTime;
+      d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
+    }
+  }
+
+  void assignToClientConnection(std::shared_ptr<IncomingTCPConnectionState>& clientConn, bool isXFR);
+
+  int getHandle() const
+  {
+    if (!d_socket) {
+      throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
+    }
+
+    return d_socket->getHandle();
+  }
+
+  const ComboAddress& getRemote() const
+  {
+    return d_ds->remote;
+  }
+
+  const std::string& getBackendName() const
+  {
+    return d_ds->getName();
+  }
+
+  bool isFresh() const
+  {
+    return d_fresh;
+  }
+
+  void incQueries()
+  {
+    ++d_queries;
+  }
+
+  void setReused()
+  {
+    d_fresh = false;
+  }
+
+  void disableFastOpen()
+  {
+    d_enableFastOpen = false;
+  }
+
+  bool isFastOpenEnabled()
+  {
+    return d_enableFastOpen;
+  }
+
+  bool canAcceptNewQueries() const
+  {
+    if (d_usedForXFR || d_connectionDied) {
+      /* Don't reuse the TCP connection after an {A,I}XFR */
+      /* but don't reset it either, we will need to read more messages */
+      return false;
+    }
+#warning FIXME: maximum number of pending queries
+    return true;
+  }
+
+  bool canBeReused() const
+  {
+    if (d_usedForXFR || d_connectionDied) {
+      return false;
+    }
+    /* we can't reuse a connection where a proxy protocol payload has been sent,
+       since:
+       - it cannot be reused for a different client
+       - we might have different TLV values for each query
+    */
+    if (d_ds && d_ds->useProxyProtocol == true) {
+      return false;
+    }
+    return true;
+  }
+
+  bool matches(const std::shared_ptr<DownstreamState>& ds) const
+  {
+    if (!ds || !d_ds) {
+      return false;
+    }
+    return ds == d_ds;
+  }
+
+  static void handleIO(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
+  static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+  static IOState sendNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn);
+
+  void queueQuery(TCPQuery&& query, std::shared_ptr<TCPConnectionToBackend>& sharedSelf);
+  void handleTimeout(const struct timeval& now, bool write);
+  IOState handleResponse(const struct timeval& now);
+  void setProxyProtocolPayload(std::string&& payload);
+  void setProxyProtocolPayloadAdded(bool added);
+
+private:
+  uint16_t getQueryIdFromResponse();
+  bool reconnect();
+  void notifyAllQueriesFailed(const struct timeval& now, bool timeout = false);
+
+  boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
+  {
+    if (d_ds == nullptr) {
+      throw std::runtime_error("getBackendReadTTD() without any backend selected");
+    }
+    if (d_ds->tcpRecvTimeout == 0) {
+      return boost::none;
+    }
+
+    struct timeval res = now;
+    res.tv_sec += d_ds->tcpRecvTimeout;
+
+    return res;
+  }
+
+  boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
+  {
+    if (d_ds == nullptr) {
+      throw std::runtime_error("getBackendWriteTTD() called without any backend selected");
+    }
+    if (d_ds->tcpSendTimeout == 0) {
+      return boost::none;
+    }
+
+    struct timeval res = now;
+    res.tv_sec += d_ds->tcpSendTimeout;
+
+    return res;
+  }
+
+  /* waitingForResponseFromBackend is a state where we have not yet started reading the size,
+     so we can still switch to sending instead */
+  enum class State { idle, doingHandshake, sendingQueryToBackend, waitingForResponseFromBackend, readingResponseSizeFromBackend, readingResponseFromBackend };
+  static const uint16_t s_xfrID;
+
+  std::vector<uint8_t> d_responseBuffer;
+  std::deque<TCPQuery> d_pendingQueries;
+  std::unordered_map<uint16_t, TCPQuery> d_pendingResponses;
+  std::unique_ptr<Socket> d_socket{nullptr};
+  std::unique_ptr<IOStateHandler> d_ioState{nullptr};
+  std::shared_ptr<DownstreamState> d_ds{nullptr};
+  //std::weak_ptr<IncomingTCPConnectionState> d_clientConn;
+  std::shared_ptr<IncomingTCPConnectionState> d_clientConn;
+  std::string d_proxyProtocolPayload;
+  TCPQuery d_currentQuery;
+  struct timeval d_connectionStartTime;
+  size_t d_currentPos{0};
+  uint64_t d_queries{0};
+  uint64_t d_downstreamFailures{0};
+  uint16_t d_responseSize{0};
+  State d_state{State::idle};
+  bool d_fresh{true};
+  bool d_enableFastOpen{false};
+  bool d_connectionDied{true};
+  bool d_usedForXFR{false};
+  bool d_proxyProtocolPayloadAdded{false};
+};
diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh
new file mode 100644 (file)
index 0000000..9479f0a
--- /dev/null
@@ -0,0 +1,228 @@
+#pragma once
+
+#include "dolog.hh"
+
+class TCPClientThreadData
+{
+public:
+  TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+  {
+  }
+
+  LocalHolders holders;
+  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
+  std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+};
+
+struct ConnectionInfo
+{
+  ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
+  {
+  }
+  ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
+  {
+    rhs.cs = nullptr;
+    rhs.fd = -1;
+  }
+
+  ConnectionInfo(const ConnectionInfo& rhs) = delete;
+  ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
+
+  ConnectionInfo& operator=(ConnectionInfo&& rhs)
+  {
+    remote = rhs.remote;
+    cs = rhs.cs;
+    rhs.cs = nullptr;
+    fd = rhs.fd;
+    rhs.fd = -1;
+    return *this;
+  }
+
+  ~ConnectionInfo()
+  {
+    if (fd != -1) {
+      close(fd);
+      fd = -1;
+    }
+    if (cs) {
+      --cs->tcpCurrentConnections;
+    }
+  }
+
+  ComboAddress remote;
+  ClientState* cs{nullptr};
+  int fd{-1};
+};
+
+class IncomingTCPConnectionState
+{
+public:
+  //IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_responseBuffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_ioState(threadData.mplexer, d_ci.fd), _connectionStartTime(now)
+  IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_ioState(make_unique<IOStateHandler>(threadData.mplexer, d_ci.fd)), d_connectionStartTime(now)
+  {
+    d_origDest.reset();
+    d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
+    socklen_t socklen = d_origDest.getSocklen();
+    if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_origDest), &socklen)) {
+      d_origDest = d_ci.cs->local;
+    }
+  }
+
+  IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
+  IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
+
+  ~IncomingTCPConnectionState();
+
+  void resetForNewQuery();
+
+  boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
+  {
+    if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
+      return boost::none;
+    }
+
+    if (g_maxTCPConnectionDuration > 0) {
+      auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
+      if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
+        return now;
+      }
+      auto remaining = g_maxTCPConnectionDuration - elapsed;
+      if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
+        now.tv_sec += remaining;
+        return now;
+      }
+    }
+
+    now.tv_sec += g_tcpRecvTimeout;
+    return now;
+  }
+
+  boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
+  {
+    if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
+      return boost::none;
+    }
+
+    struct timeval res = now;
+
+    if (g_maxTCPConnectionDuration > 0) {
+      auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
+      if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
+        return res;
+      }
+      auto remaining = g_maxTCPConnectionDuration - elapsed;
+      if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
+        res.tv_sec += remaining;
+        return res;
+      }
+    }
+
+    res.tv_sec += g_tcpSendTimeout;
+    return res;
+  }
+
+  bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
+  {
+    if (maxConnectionDuration) {
+      time_t curtime = now.tv_sec;
+      unsigned int elapsed = 0;
+      if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
+        elapsed = curtime - d_connectionStartTime.tv_sec;
+      }
+      if (elapsed >= maxConnectionDuration) {
+        return true;
+      }
+      d_remainingTime = maxConnectionDuration - elapsed;
+    }
+
+    return false;
+  }
+
+  void dump() const
+  {
+    static std::mutex s_mutex;
+
+    struct timeval now;
+    gettimeofday(&now, 0);
+
+    {
+      std::lock_guard<std::mutex> lock(s_mutex);
+      fprintf(stderr, "State is %p\n", this);
+      cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
+      cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
+      cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
+      if (d_state > State::doingHandshake) {
+        cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuerySize) {
+        cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuerySize) {
+        cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
+      }
+      if (d_state > State::readingQuery) {
+        cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
+      }
+    }
+  }
+
+  std::shared_ptr<TCPConnectionToBackend> getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds)
+  {
+#warning TODO: we need to find a connection to this DS, usable (no TLV values sent) and supporting OOR
+    return nullptr;
+  }
+
+  std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const struct timeval& now);
+
+  std::unique_ptr<FDMultiplexer>& getIOMPlexer() const
+  {
+    return d_threadData.mplexer;
+  }
+
+  static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& conn, const struct timeval& now);
+  static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+
+  void queueQuery(TCPQuery&& query);
+  void notifyIOError(std::shared_ptr<IncomingTCPConnectionState>& state, IDState&& query, const struct timeval& now);
+  void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
+  void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
+  void handleXFRResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
+  void handleTimeout(bool write);
+
+  bool active() const
+  {
+    return d_ioState != nullptr;
+  }
+
+  enum class State { doingHandshake, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
+
+  std::vector<uint8_t> d_buffer;
+  std::deque<TCPResponse> d_queuedResponses;
+  TCPClientThreadData& d_threadData;
+  TCPResponse d_currentResponse;
+  ConnectionInfo d_ci;
+  ComboAddress d_origDest;
+  TCPIOHandler d_handler;
+  std::unique_ptr<IOStateHandler> d_ioState{nullptr};
+  struct timeval d_connectionStartTime;
+  struct timeval d_handshakeDoneTime;
+  struct timeval d_firstQuerySizeReadTime;
+  struct timeval d_querySizeReadTime;
+  struct timeval d_queryReadTime;
+  size_t d_currentPos{0};
+  size_t d_queriesCount{0};
+  unsigned int d_remainingTime{0};
+  uint16_t d_querySize{0};
+  uint16_t d_downstreamFailures{0};
+  State d_state{State::doingHandshake};
+  IOState d_lastIOState{IOState::Done};
+  bool d_readingFirstQuery{true};
+  bool d_isXFR{false};
+  bool d_xfrStarted{false};
+  bool d_xfrDone{false};
+  bool d_selfGeneratedResponse{false};
+  bool d_proxyProtocolPayloadAdded{false};
+  bool d_proxyProtocolPayloadHasTLV{false};
+};
+
+IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead);
index 080c95e3b8174f07197f8225e7ac4ba1ea7d8bc0..57c21606e027bae3d48f76caf9044b9d8336156e 100644 (file)
@@ -866,7 +866,7 @@ try
       }
 
       string decoded;
-      /* rough estimate so we hopefully don't need a need allocation later */
+      /* rough estimate so we hopefully don't need a new allocation later */
       /* We reserve at least 512 additional bytes to be able to add EDNS, but we also want
          at least s_maxPacketCacheEntrySize bytes to be able to fill the answer from the packet cache */
       const size_t estimate = ((sdns.size() * 3) / 4);
diff --git a/pdns/dnsdistdist/tcpiohandler-mplexer.hh b/pdns/dnsdistdist/tcpiohandler-mplexer.hh
new file mode 100644 (file)
index 0000000..fdb1c5e
--- /dev/null
@@ -0,0 +1,121 @@
+
+#pragma once
+
+#include "mplexer.hh"
+#include "tcpiohandler.hh"
+
+class IOStateHandler
+{
+public:
+  IOStateHandler(std::unique_ptr<FDMultiplexer>& mplexer, const int fd): d_mplexer(mplexer), d_fd(fd), d_currentState(IOState::Done)
+  {
+  }
+
+  IOStateHandler(std::unique_ptr<FDMultiplexer>& mplexer): d_mplexer(mplexer), d_fd(-1), d_currentState(IOState::Done)
+  {
+  }
+
+  ~IOStateHandler()
+  {
+    /* be careful that this won't save us if the callback is still registered to the multiplexer,
+       because in that case the shared pointer count will never reach zero so this destructor won't
+       be called */
+    reset();
+  }
+
+  IOState getState() const
+  {
+    return d_currentState;
+  }
+
+  void setSocket(int fd)
+  {
+    if (d_fd != -1) {
+      throw std::runtime_error("Trying to set the socket descriptor on an already initialized IOStateHandler");
+    }
+    d_fd = fd;
+  }
+
+  void reset()
+  {
+    update(IOState::Done);
+  }
+
+  void update(IOState iostate, FDMultiplexer::callbackfunc_t callback = FDMultiplexer::callbackfunc_t(), FDMultiplexer::funcparam_t callbackData = boost::any(), boost::optional<struct timeval> ttd = boost::none)
+  {
+    cerr<<"in "<<__PRETTY_FUNCTION__<<" for fd "<<d_fd<<", last state was "<<(int)d_currentState<<", new state is "<<(int)iostate<<endl;
+    if (d_currentState == IOState::NeedRead && iostate != IOState::NeedRead) {
+      cerr<<__PRETTY_FUNCTION__<<": remove read FD "<<d_fd<<endl;
+      d_mplexer->removeReadFD(d_fd);
+      d_currentState = IOState::Done;
+    }
+    else if (d_currentState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
+      cerr<<__PRETTY_FUNCTION__<<": remove write FD "<<d_fd<<endl;
+      d_mplexer->removeWriteFD(d_fd);
+      d_currentState = IOState::Done;
+    }
+
+    if (iostate == IOState::NeedRead) {
+      if (d_currentState == IOState::NeedRead) {
+        if (ttd) {
+          /* let's update the TTD ! */
+          d_mplexer->setReadTTD(d_fd, *ttd, /* we pass 0 here because we already have a TTD */0);
+        }
+        return;
+      }
+
+      d_currentState = IOState::NeedRead;
+      cerr<<__PRETTY_FUNCTION__<<": add read FD "<<d_fd<<endl;
+      d_mplexer->addReadFD(d_fd, callback, callbackData, ttd ? &*ttd : nullptr);
+    }
+    else if (iostate == IOState::NeedWrite) {
+      if (d_currentState == IOState::NeedWrite) {
+        return;
+      }
+
+      d_currentState = IOState::NeedWrite;
+      cerr<<__PRETTY_FUNCTION__<<": add write FD "<<d_fd<<endl;
+      d_mplexer->addWriteFD(d_fd, callback, callbackData, ttd ? &*ttd : nullptr);
+    }
+    else if (iostate == IOState::Done) {
+      d_currentState = IOState::Done;
+      cerr<<__PRETTY_FUNCTION__<<": done"<<endl;
+    }
+  }
+
+private:
+  std::unique_ptr<FDMultiplexer>& d_mplexer;
+  int d_fd;
+  IOState d_currentState;
+};
+
+class IOStateGuard
+{
+public:
+  /* this class is using RAII to make sure we don't forget to release an IOStateHandler
+     from the IO multiplexer in case of exception / error handling */
+  IOStateGuard(std::unique_ptr<IOStateHandler>& handler): d_handler(handler), d_enabled(true)
+  {
+  }
+
+  ~IOStateGuard()
+  {
+    /* if we are still owning the state when we go out of scope,
+       let's reset the state so it's not registered to the IO multiplexer anymore
+       and its reference count goes to zero */
+    if (d_enabled && d_handler) {
+      cerr<<"IOStateGuard destroyed while holding a state, let's reset it"<<endl;
+      d_handler->reset();
+      d_enabled = false;
+    }
+  }
+
+  void release()
+  {
+    d_enabled = false;
+  }
+
+private:
+  std::unique_ptr<IOStateHandler>& d_handler;
+  bool d_enabled;
+};
index de5aa4cdd3eac234c0e70723e64ed36174700374..55a0f5c93d58874c2aec8f706e381a78d92f4c81 100644 (file)
@@ -71,8 +71,8 @@ class TestTCPShort(DNSDistTest):
         # send announcedSize bytes minus 1 so we get a second read
         conn.send(wire)
         time.sleep(1)
-        # send 1024 bytes
-        conn.send(b'A' * 1024)
+        # send the remaining byte
+        conn.send(b'A')
 
         (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, True)
         conn.close()
@@ -112,8 +112,8 @@ class TestTCPShort(DNSDistTest):
         # send announcedSize bytes minus 1 so we get a second read
         conn.send(wire)
         time.sleep(1)
-        # send 1024 bytes
-        conn.send(b'A' * 1024)
+        # send the remaining byte
+        conn.send(b'A')
 
         (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, True)
         conn.close()
index 9a50c2b299d7989c0ab396f9fd60bf4c5d6fcdd3..246416f6201457c4e48d8c33c94d5fce9a38cd17 100644 (file)
@@ -3,7 +3,7 @@ import dns
 import clientsubnetoption
 from dnsdisttests import DNSDistTest
 
-class TestBasics(DNSDistTest):
+class TestTags(DNSDistTest):
 
     _config_template = """
     newServer{address="127.0.0.1:%s"}