git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Use the same outgoing TCP connection for different clients
author: Remi Gacogne <remi.gacogne@powerdns.com>
Fri, 15 Oct 2021 15:36:16 +0000 (17:36 +0200)
committer: Remi Gacogne <remi.gacogne@powerdns.com>
Tue, 26 Oct 2021 15:07:19 +0000 (17:07 +0200)
12 files changed:
pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-nghttp2.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.hh
pdns/dnsdistdist/dnsdist-tcp-upstream.hh
pdns/dnsdistdist/dnsdist-tcp.hh
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
regression-tests.dnsdist/dnsdisttests.py
regression-tests.dnsdist/test_AXFR.py
regression-tests.dnsdist/test_DynBlocks.py
regression-tests.dnsdist/test_OutgoingTLS.py
regression-tests.dnsdist/test_TCPOnly.py

index d70f7692e582787935a9d83a6d53b2f10f524f58..33079802403f17bc7a65a4cc1fefca37fff00ba1 100644 (file)
@@ -113,12 +113,14 @@ std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstrea
 {
   std::shared_ptr<TCPConnectionToBackend> downstream{nullptr};
 
-  downstream = getActiveDownstreamConnection(ds, tlvs);
+  downstream = getOwnedDownstreamConnection(ds, tlvs);
 
   if (!downstream) {
-    /* we don't have a connection to this backend active yet, let's get one (it might not be a fresh one, though) */
+    /* we don't have a connection to this backend owned yet, let's get one (it might not be a fresh one, though) */
     downstream = DownstreamConnectionsManager::getConnectionToDownstream(d_threadData.mplexer, ds, now);
-    registerActiveDownstreamConnection(downstream);
+    if (ds->useProxyProtocol) {
+      registerOwnedDownstreamConnection(downstream);
+    }
   }
 
   return downstream;
@@ -307,17 +309,17 @@ void IncomingTCPConnectionState::resetForNewQuery()
   d_state = State::waitingForQuery;
 }
 
-std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs)
+std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs)
 {
-  auto it = d_activeConnectionsToBackend.find(ds);
-  if (it == d_activeConnectionsToBackend.end()) {
-    DEBUGLOG("no active connection found for "<<ds->getName());
+  auto it = d_ownedConnectionsToBackend.find(ds);
+  if (it == d_ownedConnectionsToBackend.end()) {
+    DEBUGLOG("no owned connection found for "<<ds->getName());
     return nullptr;
   }
 
   for (auto& conn : it->second) {
-    if (conn->canAcceptNewQueries() && conn->matchesTLVs(tlvs)) {
-      DEBUGLOG("Got one active connection accepting more for "<<ds->getName());
+    if (conn->canBeReused(true) && conn->matchesTLVs(tlvs)) {
+      DEBUGLOG("Got one owned connection accepting more for "<<ds->getName());
       conn->setReused();
       return conn;
     }
@@ -327,9 +329,9 @@ std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getActiveDow
   return nullptr;
 }
 
-void IncomingTCPConnectionState::registerActiveDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn)
+void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn)
 {
-  d_activeConnectionsToBackend[conn->getDS()].push_front(conn);
+  d_ownedConnectionsToBackend[conn->getDS()].push_front(conn);
 }
 
 /* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */
@@ -375,7 +377,12 @@ void IncomingTCPConnectionState::terminateClientConnection()
   d_queuedResponses.clear();
   /* we have already released idle connections that could be reused,
      we don't care about the ones still waiting for responses */
-  d_activeConnectionsToBackend.clear();
+  for (auto& backend : d_ownedConnectionsToBackend) {
+    for (auto& conn : backend.second) {
+      conn->release();
+    }
+  }
+  d_ownedConnectionsToBackend.clear();
   /* meaning we will no longer be 'active' when the backend
      response or timeout comes in */
   d_ioState.reset();
@@ -419,18 +426,18 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe
 {
   std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
 
-  if (response.d_connection && response.d_connection->isIdle()) {
-    // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool yet, no one else will be able to use it anyway
-    if (response.d_connection->canBeReused()) {
-      const auto connIt = state->d_activeConnectionsToBackend.find(response.d_connection->getDS());
-      if (connIt != state->d_activeConnectionsToBackend.end()) {
+  if (response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->useProxyProtocol) {
+    // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool as no one else will be able to use it anyway
+    if (!response.d_connection->willBeReusable(true)) {
+      // if it can't be reused even by us, well
+      const auto connIt = state->d_ownedConnectionsToBackend.find(response.d_connection->getDS());
+      if (connIt != state->d_ownedConnectionsToBackend.end()) {
         auto& list = connIt->second;
 
         for (auto it = list.begin(); it != list.end(); ++it) {
           if (*it == response.d_connection) {
             try {
               response.d_connection->release();
-              DownstreamConnectionsManager::releaseDownstreamConnection(std::move(*it));
             }
             catch (const std::exception& e) {
               vinfolog("Error releasing connection: %s", e.what());
@@ -1056,7 +1063,7 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnec
 {
   vinfolog("Timeout while %s TCP client %s", (write ? "writing to" : "reading from"), state->d_ci.remote.toStringWithPort());
   DEBUGLOG("client timeout");
-  DEBUGLOG("Processed "<<state->d_queriesCount<<" queries, current count is "<<state->d_currentQueriesCount<<", "<<state->d_activeConnectionsToBackend.size()<<" active connections, "<<state->d_queuedResponses.size()<<" response queued");
+  DEBUGLOG("Processed "<<state->d_queriesCount<<" queries, current count is "<<state->d_currentQueriesCount<<", "<<state->d_ownedConnectionsToBackend.size()<<" owned connections, "<<state->d_queuedResponses.size()<<" response queued");
 
   if (write || state->d_currentQueriesCount == 0) {
     ++state->d_ci.cs->tcpClientTimeouts;
@@ -1067,14 +1074,6 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnec
     /* we still have some queries in flight, let's just stop reading for now */
     state->d_state = IncomingTCPConnectionState::State::idle;
     state->d_ioState->update(IOState::Done, handleIOCallback, state);
-
-#ifdef DEBUGLOG_ENABLED
-    for (const auto& active : state->d_activeConnectionsToBackend) {
-      for (const auto& conn: active.second) {
-        DEBUGLOG("Connection to "<<active.first->getName()<<" is "<<(conn->isIdle() ? "idle" : "not idle"));
-      }
-    }
-#endif
   }
 }
 
index 4768b2a7e7319274ed42758fd37dce8ab4884e82..03abd4332a763ded6979d73ea410edea83e1f7cf 100644 (file)
@@ -43,7 +43,7 @@ std::unique_ptr<DoHClientCollection> g_dohClientThreads{nullptr};
 std::optional<uint16_t> g_outgoingDoHWorkerThreads{std::nullopt};
 
 #ifdef HAVE_NGHTTP2
-class DoHConnectionToBackend : public TCPConnectionToBackend
+class DoHConnectionToBackend : public ConnectionToBackend
 {
 public:
   DoHConnectionToBackend(std::shared_ptr<DownstreamState> ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now, std::string&& proxyProtocolPayload);
@@ -58,17 +58,14 @@ public:
     return o.str();
   }
 
-  bool reachedMaxStreamID() const;
-  bool canBeReused() const override;
-  /* full now but will become usable later */
-  bool willBeReusable() const;
-
   void setHealthCheck(bool h)
   {
     d_healthCheckQuery = h;
   }
 
   void stopIO();
+  bool reachedMaxConcurrentQueries() const override;
+  bool reachedMaxStreamID() const override;
 
 private:
   static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data);
@@ -79,7 +76,6 @@ private:
   static int on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data);
   static void handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param);
   static void handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param);
-  static void handleIO(std::shared_ptr<DoHConnectionToBackend>& conn, const struct timeval& now);
 
   static void addStaticHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& valueKey);
   static void addDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& nameKey, const std::string& value);
@@ -117,7 +113,6 @@ private:
   std::unique_ptr<nghttp2_session, void (*)(nghttp2_session*)> d_session{nullptr, nghttp2_session_del};
   size_t d_outPos{0};
   size_t d_inPos{0};
-  uint32_t d_highestStreamID{0};
   bool d_healthCheckQuery{false};
   bool d_firstWrite{true};
 };
@@ -213,34 +208,12 @@ bool DoHConnectionToBackend::reachedMaxStreamID() const
   return d_highestStreamID == maximumStreamID;
 }
 
-bool DoHConnectionToBackend::canBeReused() const
+bool DoHConnectionToBackend::reachedMaxConcurrentQueries() const
 {
-  if (d_connectionDied) {
-    return false;
-  }
-
-  if (!d_proxyProtocolPayload.empty()) {
-    return false;
-  }
-
-  if (reachedMaxStreamID()) {
-    return false;
-  }
-
   //cerr<<"Got "<<getConcurrentStreamsCount()<<" concurrent streams, max is "<<nghttp2_session_get_remote_settings(d_session.get(), NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS)<<endl;
   if (nghttp2_session_get_remote_settings(d_session.get(), NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS) <= getConcurrentStreamsCount()) {
-    return false;
-  }
-
-  return true;
-}
-
-bool DoHConnectionToBackend::willBeReusable() const
-{
-  if (!d_connectionDied && d_proxyProtocolPayload.empty() && !reachedMaxStreamID()) {
     return true;
   }
-
   return false;
 }
 
@@ -399,10 +372,6 @@ public:
   std::unique_ptr<FDMultiplexer> mplexer{nullptr};
 };
 
-void DoHConnectionToBackend::handleIO(std::shared_ptr<DoHConnectionToBackend>& conn, const struct timeval& now)
-{
-}
-
 void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param)
 {
   auto conn = boost::any_cast<std::shared_ptr<DoHConnectionToBackend>>(param);
@@ -505,7 +474,7 @@ void DoHConnectionToBackend::stopIO()
 {
   d_ioState->reset();
 
-  if (d_connectionDied) {
+  if (!willBeReusable(false)) {
     /* remove ourselves from the connection cache, this might mean that our
        reference count drops to zero after that, so we need to be careful */
     auto shared = std::dynamic_pointer_cast<DoHConnectionToBackend>(shared_from_this());
@@ -547,7 +516,7 @@ void DoHConnectionToBackend::updateIO(IOState newState, FDMultiplexer::callbackf
 
 void DoHConnectionToBackend::watchForRemoteHostClosingConnection()
 {
-  if (willBeReusable() && !d_healthCheckQuery) {
+  if (willBeReusable(false) && !d_healthCheckQuery) {
     updateIO(IOState::NeedRead, handleReadableIOCallback, false);
   }
 }
@@ -820,9 +789,9 @@ int DoHConnectionToBackend::on_error_callback(nghttp2_session* session, int lib_
 }
 
 DoHConnectionToBackend::DoHConnectionToBackend(std::shared_ptr<DownstreamState> ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now, std::string&& proxyProtocolPayload) :
-  TCPConnectionToBackend(ds, mplexer, now), d_proxyProtocolPayload(std::move(proxyProtocolPayload))
+  ConnectionToBackend(ds, mplexer, now), d_proxyProtocolPayload(std::move(proxyProtocolPayload))
 {
-  // inherit most of the stuff from the TCPConnectionToBackend()
+  // inherit most of the stuff from the ConnectionToBackend()
   d_ioState = make_unique<IOStateHandler>(*d_mplexer, d_handler->getDescriptor());
 
   nghttp2_session_callbacks* cbs = nullptr;
@@ -973,7 +942,7 @@ std::shared_ptr<DoHConnectionToBackend> DownstreamDoHConnectionsManager::getConn
       for (auto listIt = list.begin(); listIt != list.end();) {
         auto& entry = *listIt;
         if (!entry->canBeReused()) {
-          if (!entry->willBeReusable()) {
+          if (!entry->willBeReusable(false)) {
             listIt = list.erase(listIt);
           }
           else {
@@ -1003,7 +972,7 @@ std::shared_ptr<DoHConnectionToBackend> DownstreamDoHConnectionsManager::getConn
 
   auto newConnection = std::make_shared<DoHConnectionToBackend>(ds, mplexer, now, std::move(proxyProtocolPayload));
   if (!haveProxyProtocol) {
-    t_downstreamConnections[backendId].push_back(newConnection);
+    t_downstreamConnections[backendId].push_front(newConnection);
   }
   return newConnection;
 }
index 657c9a44c4c035a5c88ebfd374a56a49bab39dfe..8bb84145e096de45203fd9f25ce5e70b76a9fb25 100644 (file)
@@ -5,12 +5,8 @@
 
 #include "dnsparser.hh"
 
-TCPConnectionToBackend::~TCPConnectionToBackend()
+ConnectionToBackend::~ConnectionToBackend()
 {
-  if (d_ds && !d_pendingResponses.empty()) {
-    d_ds->outstanding -= d_pendingResponses.size();
-  }
-
   if (d_ds && d_handler) {
     --d_ds->tcpCurrentConnections;
     struct timeval now;
@@ -36,6 +32,100 @@ TCPConnectionToBackend::~TCPConnectionToBackend()
   }
 }
 
+bool ConnectionToBackend::reconnect()
+{
+  std::unique_ptr<TLSSession> tlsSession{nullptr};
+  if (d_handler) {
+    DEBUGLOG("closing socket "<<d_handler->getDescriptor());
+    if (d_handler->isTLS()) {
+      if (d_handler->hasTLSSessionBeenResumed()) {
+        ++d_ds->tlsResumptions;
+      }
+      try {
+        auto sessions = d_handler->getTLSSessions();
+        if (!sessions.empty()) {
+          tlsSession = std::move(sessions.back());
+          sessions.pop_back();
+          if (!sessions.empty()) {
+            g_sessionCache.putSessions(d_ds->getID(), time(nullptr), std::move(sessions));
+          }
+        }
+      }
+      catch (const std::exception& e) {
+        vinfolog("Unable to get a TLS session to resume: %s", e.what());
+      }
+    }
+    d_handler->close();
+    d_ioState.reset();
+    --d_ds->tcpCurrentConnections;
+  }
+
+  d_fresh = true;
+  d_highestStreamID = 0;
+  d_proxyProtocolPayloadSent = false;
+
+  do {
+    vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures);
+    DEBUGLOG("Opening TCP connection to backend "<<d_ds->getNameWithAddr());
+    ++d_ds->tcpNewConnections;
+    try {
+      auto socket = std::make_unique<Socket>(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0);
+      DEBUGLOG("result of socket() is "<<socket->getHandle());
+
+      if (!IsAnyAddress(d_ds->sourceAddr)) {
+        SSetsockopt(socket->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
+#ifdef IP_BIND_ADDRESS_NO_PORT
+        if (d_ds->ipBindAddrNoPort) {
+          SSetsockopt(socket->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+        }
+#endif
+#ifdef SO_BINDTODEVICE
+        if (!d_ds->sourceItfName.empty()) {
+          int res = setsockopt(socket->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length());
+          if (res != 0) {
+            vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror());
+          }
+        }
+#endif
+        socket->bind(d_ds->sourceAddr, false);
+      }
+      socket->setNonBlocking();
+
+      gettimeofday(&d_connectionStartTime, nullptr);
+      auto handler = std::make_unique<TCPIOHandler>(d_ds->d_tlsSubjectName, socket->releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec);
+      if (!tlsSession && d_ds->d_tlsCtx) {
+        tlsSession = g_sessionCache.getSession(d_ds->getID(), d_connectionStartTime.tv_sec);
+      }
+      if (tlsSession) {
+        handler->setTLSSession(tlsSession);
+      }
+      handler->tryConnect(d_ds->tcpFastOpen && isFastOpenEnabled(), d_ds->remote);
+      d_queries = 0;
+
+      d_handler = std::move(handler);
+      d_ds->incCurrentConnectionsCount();
+      return true;
+    }
+    catch (const std::runtime_error& e) {
+      vinfolog("Connection to downstream server %s failed: %s", d_ds->getName(), e.what());
+      d_downstreamFailures++;
+      if (d_downstreamFailures >= d_ds->d_retries) {
+        throw;
+      }
+    }
+  }
+  while (d_downstreamFailures < d_ds->d_retries);
+
+  return false;
+}
+
+TCPConnectionToBackend::~TCPConnectionToBackend()
+{
+  if (d_ds && !d_pendingResponses.empty()) {
+    d_ds->outstanding -= d_pendingResponses.size();
+  }
+}
+
 void TCPConnectionToBackend::release()
 {
   d_ds->outstanding -= d_pendingResponses.size();
@@ -43,7 +133,6 @@ void TCPConnectionToBackend::release()
   d_pendingResponses.clear();
   d_pendingQueries.clear();
 
-  d_sender.reset();
   if (d_ioState) {
     d_ioState.reset();
   }
@@ -52,6 +141,9 @@ void TCPConnectionToBackend::release()
 IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
 {
   conn->d_currentQuery = std::move(conn->d_pendingQueries.front());
+  dnsheader* dh = reinterpret_cast<dnsheader*>(&conn->d_currentQuery.d_query.d_buffer.at(sizeof(uint16_t) + (conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded ? conn->d_currentQuery.d_query.d_proxyProtocolPayload.size() : 0)));
+  uint16_t id = conn->d_highestStreamID;
+  dh->id = htons(id);
   conn->d_pendingQueries.pop_front();
   conn->d_state = State::sendingQueryToBackend;
   conn->d_currentPos = 0;
@@ -63,7 +155,7 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptr<TCPConnectionToBackend
 {
   DEBUGLOG("sending query to backend "<<conn->getDS()->getName()<<" over FD "<<conn->d_handler->getDescriptor());
 
-  IOState state = conn->d_handler->tryWrite(conn->d_currentQuery.d_buffer, conn->d_currentPos, conn->d_currentQuery.d_buffer.size());
+  IOState state = conn->d_handler->tryWrite(conn->d_currentQuery.d_query.d_buffer, conn->d_currentPos, conn->d_currentQuery.d_query.d_buffer.size());
 
   if (state != IOState::Done) {
     return state;
@@ -71,20 +163,22 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptr<TCPConnectionToBackend
 
   DEBUGLOG("query sent to backend");
   /* request sent ! */
-  if (conn->d_currentQuery.d_proxyProtocolPayloadAdded) {
+  if (conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded) {
     conn->d_proxyProtocolPayloadSent = true;
   }
   ++conn->d_queries;
   conn->d_currentPos = 0;
 
-  DEBUGLOG("adding a pending response for ID "<<ntohs(conn->d_currentQuery.d_idstate.origID)<<" and QNAME "<<conn->d_currentQuery.d_idstate.qname);
-  auto res = conn->d_pendingResponses.insert({ntohs(conn->d_currentQuery.d_idstate.origID), std::move(conn->d_currentQuery)});
-  /* if there was already a pending response with that ID, the client messed up and we don't expect more
+  DEBUGLOG("adding a pending response for ID "<<conn->d_highestStreamID<<" and QNAME "<<conn->d_currentQuery.d_query.d_idstate.qname);
+  auto res = conn->d_pendingResponses.insert({conn->d_highestStreamID, std::move(conn->d_currentQuery)});
+  /* if there was already a pending response with that ID, we messed up and we don't expect more
      than one response */
   if (res.second) {
     ++conn->d_ds->outstanding;
   }
-  conn->d_currentQuery.d_buffer.clear();
+  ++conn->d_highestStreamID;
+  conn->d_currentQuery.d_sender.reset();
+  conn->d_currentQuery.d_query.d_buffer.clear();
 
   return state;
 }
@@ -152,7 +246,7 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
             iostate = conn->handleResponse(conn, now);
           }
           catch (const std::exception& e) {
-            vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", conn->d_ds ? conn->d_ds->getName() : "unknown", conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what());
+            vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", conn->d_ds ? conn->d_ds->getName() : "unknown", conn->d_currentQuery.d_query.d_idstate.origRemote.toStringWithPort(), e.what());
             ioGuard.release();
             conn->release();
             return;
@@ -173,7 +267,7 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
          but it might also be a real IO error or something else.
          Let's just drop the connection
       */
-      vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_state == State::sendingQueryToBackend ? "writing to" : "reading from"), conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what());
+      vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_state == State::sendingQueryToBackend ? "writing to" : "reading from"), conn->d_currentQuery.d_query.d_idstate.origRemote.toStringWithPort(), e.what());
 
       if (conn->d_state == State::sendingQueryToBackend) {
         ++conn->d_ds->tcpDiedSendingQuery;
@@ -206,14 +300,23 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
             conn->d_ioState = make_unique<IOStateHandler>(*conn->d_mplexer, conn->d_handler->getDescriptor());
 
             /* we need to resend the queries that were in flight, if any */
+            if (conn->d_state == State::sendingQueryToBackend) {
+              /* we need to edit this query so it has the correct ID */
+              auto query = std::move(conn->d_currentQuery);
+              dnsheader* dh = reinterpret_cast<dnsheader*>(&query.d_query.d_buffer.at(sizeof(uint16_t) + (query.d_query.d_proxyProtocolPayloadAdded ? query.d_query.d_proxyProtocolPayload.size() : 0)));
+              uint16_t id = conn->d_highestStreamID;
+              dh->id = htons(id);
+              conn->d_currentQuery = std::move(query);
+            }
+
             for (auto& pending : conn->d_pendingResponses) {
               --conn->d_ds->outstanding;
 
-              if (pending.second.isXFR() && pending.second.d_xfrStarted) {
+              if (pending.second.d_query.isXFR() && pending.second.d_query.d_xfrStarted) {
                 /* this one can't be restarted, sorry */
                 DEBUGLOG("A XFR for which a response has already been sent cannot be restarted");
                 try {
-                  conn->d_sender->notifyIOError(std::move(pending.second.d_idstate), now);
+                  pending.second.d_sender->notifyIOError(std::move(pending.second.d_query.d_idstate), now);
                 }
                 catch (const std::exception& e) {
                   vinfolog("Got an exception while notifying: %s", e.what());
@@ -241,9 +344,9 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& c
               iostate = queueNextQuery(conn);
             }
 
-            if (conn->needProxyProtocolPayload() && !conn->d_currentQuery.d_proxyProtocolPayloadAdded && !conn->d_currentQuery.d_proxyProtocolPayload.empty()) {
-              conn->d_currentQuery.d_buffer.insert(conn->d_currentQuery.d_buffer.begin(), conn->d_currentQuery.d_proxyProtocolPayload.begin(), conn->d_currentQuery.d_proxyProtocolPayload.end());
-              conn->d_currentQuery.d_proxyProtocolPayloadAdded = true;
+            if (conn->needProxyProtocolPayload() && !conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !conn->d_currentQuery.d_query.d_proxyProtocolPayload.empty()) {
+              conn->d_currentQuery.d_query.d_buffer.insert(conn->d_currentQuery.d_query.d_buffer.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.end());
+              conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true;
             }
 
             reconnected = true;
@@ -304,125 +407,39 @@ void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t
 
 void TCPConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query)
 {
-  if (!d_sender) {
-    d_sender = sender;
+  if (!d_ioState) {
     d_ioState = make_unique<IOStateHandler>(*d_mplexer, d_handler->getDescriptor());
   }
-  else if (d_sender != sender) {
-    throw std::runtime_error("Assigning a query from a different client to an existing backend connection with pending queries");
-  }
 
   // if we are not already sending a query or in the middle of reading a response (so idle),
   // start sending the query
   if (d_state == State::idle || d_state == State::waitingForResponseFromBackend) {
-    DEBUGLOG("Sending new query to backend right away");
+    DEBUGLOG("Sending new query to backend right away, with ID "<<d_highestStreamID);
     d_state = State::sendingQueryToBackend;
     d_currentPos = 0;
-    d_currentQuery = std::move(query);
-    if (needProxyProtocolPayload() && !d_currentQuery.d_proxyProtocolPayloadAdded && !d_currentQuery.d_proxyProtocolPayload.empty()) {
-      d_currentQuery.d_buffer.insert(d_currentQuery.d_buffer.begin(), d_currentQuery.d_proxyProtocolPayload.begin(), d_currentQuery.d_proxyProtocolPayload.end());
-      d_currentQuery.d_proxyProtocolPayloadAdded = true;
+    dnsheader* dh = reinterpret_cast<dnsheader*>(&query.d_buffer.at(sizeof(uint16_t) + (query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0)));
+    uint16_t id = d_highestStreamID;
+    dh->id = htons(id);
+    d_currentQuery = PendingRequest({sender, std::move(query)});
+
+    if (needProxyProtocolPayload() && !d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !d_currentQuery.d_query.d_proxyProtocolPayload.empty()) {
+      d_currentQuery.d_query.d_buffer.insert(d_currentQuery.d_query.d_buffer.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.end());
+      d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true;
     }
 
     struct timeval now;
     gettimeofday(&now, 0);
 
-    auto shared = shared_from_this();
+    auto shared = std::dynamic_pointer_cast<TCPConnectionToBackend>(shared_from_this());
     handleIO(shared, now);
   }
   else {
     DEBUGLOG("Adding new query to the queue because we are in state "<<(int)d_state);
     // store query in the list of queries to send
-    d_pendingQueries.push_back(std::move(query));
+    d_pendingQueries.push_back(PendingRequest({sender, std::move(query)}));
   }
 }
 
-bool TCPConnectionToBackend::reconnect()
-{
-  std::unique_ptr<TLSSession> tlsSession{nullptr};
-  if (d_handler) {
-    DEBUGLOG("closing socket "<<d_handler->getDescriptor());
-    if (d_handler->isTLS()) {
-      if (d_handler->hasTLSSessionBeenResumed()) {
-        ++d_ds->tlsResumptions;
-      }
-      try {
-        auto sessions = d_handler->getTLSSessions();
-        if (!sessions.empty()) {
-          tlsSession = std::move(sessions.back());
-          sessions.pop_back();
-          if (!sessions.empty()) {
-            g_sessionCache.putSessions(d_ds->getID(), time(nullptr), std::move(sessions));
-          }
-        }
-      }
-      catch (const std::exception& e) {
-        vinfolog("Unable to get a TLS session to resume: %s", e.what());
-      }
-    }
-    d_handler->close();
-    d_ioState.reset();
-    --d_ds->tcpCurrentConnections;
-  }
-
-  d_fresh = true;
-  d_proxyProtocolPayloadSent = false;
-
-  do {
-    vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures);
-    DEBUGLOG("Opening TCP connection to backend "<<d_ds->getNameWithAddr());
-    ++d_ds->tcpNewConnections;
-    try {
-      auto socket = std::make_unique<Socket>(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0);
-      DEBUGLOG("result of socket() is "<<socket->getHandle());
-
-      if (!IsAnyAddress(d_ds->sourceAddr)) {
-        SSetsockopt(socket->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
-#ifdef IP_BIND_ADDRESS_NO_PORT
-        if (d_ds->ipBindAddrNoPort) {
-          SSetsockopt(socket->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
-        }
-#endif
-#ifdef SO_BINDTODEVICE
-        if (!d_ds->sourceItfName.empty()) {
-          int res = setsockopt(socket->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length());
-          if (res != 0) {
-            vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror());
-          }
-        }
-#endif
-        socket->bind(d_ds->sourceAddr, false);
-      }
-      socket->setNonBlocking();
-
-      gettimeofday(&d_connectionStartTime, nullptr);
-      auto handler = std::make_unique<TCPIOHandler>(d_ds->d_tlsSubjectName, socket->releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec);
-      if (!tlsSession && d_ds->d_tlsCtx) {
-        tlsSession = g_sessionCache.getSession(d_ds->getID(), d_connectionStartTime.tv_sec);
-      }
-      if (tlsSession) {
-        handler->setTLSSession(tlsSession);
-      }
-      handler->tryConnect(d_ds->tcpFastOpen && isFastOpenEnabled(), d_ds->remote);
-      d_queries = 0;
-
-      d_handler = std::move(handler);
-      d_ds->incCurrentConnectionsCount();
-      return true;
-    }
-    catch (const std::runtime_error& e) {
-      vinfolog("Connection to downstream server %s failed: %s", d_ds->getName(), e.what());
-      d_downstreamFailures++;
-      if (d_downstreamFailures >= d_ds->d_retries) {
-        throw;
-      }
-    }
-  }
-  while (d_downstreamFailures < d_ds->d_retries);
-
-  return false;
-}
-
 void TCPConnectionToBackend::handleTimeout(const struct timeval& now, bool write)
 {
   /* in some cases we could retry, here, reconnecting and sending our pending responses again */
@@ -458,37 +475,49 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F
 {
   d_connectionDied = true;
 
-  auto& sender = d_sender;
-  if (!sender->active()) {
-    // a client timeout occurred, or something like that */
-    d_sender.reset();
-    return;
-  }
-
-  if (reason == FailureReason::timeout) {
-    const ClientState* cs = sender->getClientState();
-    if (cs) {
-      ++cs->tcpDownstreamTimeouts;
+  /* we might be terminated while notifying a query sender */
+  d_ds->outstanding -= d_pendingResponses.size();
+  auto pendingQueries = std::move(d_pendingQueries);
+  auto pendingResponses = std::move(d_pendingResponses);
+
+  auto increaseCounters = [reason](std::shared_ptr<TCPQuerySender>& sender) {
+    if (reason == FailureReason::timeout) {
+      const ClientState* cs = sender->getClientState();
+      if (cs) {
+        ++cs->tcpDownstreamTimeouts;
+      }
     }
-  }
-  else if (reason == FailureReason::gaveUp) {
-    const ClientState* cs = sender->getClientState();
-    if (cs) {
-      ++cs->tcpGaveUp;
+    else if (reason == FailureReason::gaveUp) {
+      const ClientState* cs = sender->getClientState();
+      if (cs) {
+        ++cs->tcpGaveUp;
+      }
     }
-  }
+  };
 
   try {
     if (d_state == State::sendingQueryToBackend) {
-      sender->notifyIOError(std::move(d_currentQuery.d_idstate), now);
+      auto sender = d_currentQuery.d_sender;
+      if (sender->active()) {
+        increaseCounters(sender);
+        sender->notifyIOError(std::move(d_currentQuery.d_query.d_idstate), now);
+      }
     }
 
-    for (auto& query : d_pendingQueries) {
-      sender->notifyIOError(std::move(query.d_idstate), now);
+    for (auto& query : pendingQueries) {
+      auto sender = query.d_sender;
+      if (sender->active()) {
+        increaseCounters(sender);
+        sender->notifyIOError(std::move(query.d_query.d_idstate), now);
+      }
     }
 
-    for (auto& response : d_pendingResponses) {
-      sender->notifyIOError(std::move(response.second.d_idstate), now);
+    for (auto& response : pendingResponses) {
+      auto sender = response.second.d_sender;
+      if (sender->active()) {
+        increaseCounters(sender);
+        sender->notifyIOError(std::move(response.second.d_query.d_idstate), now);
+      }
     }
   }
   catch (const std::exception& e) {
@@ -527,16 +556,6 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
 {
   d_downstreamFailures = 0;
 
-  auto& sender = d_sender;
-  if (!sender || !sender->active()) {
-    // a client timeout occurred, or something like that */
-    d_connectionDied = true;
-
-    release();
-
-    return IOState::Done;
-  }
-
   uint16_t queryId = 0;
   try {
     queryId = getQueryIdFromResponse();
@@ -554,19 +573,24 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
     return IOState::Done;
   }
 
-  if (it->second.isXFR()) {
+  auto dh = reinterpret_cast<dnsheader*>(d_responseBuffer.data());
+  dh->id = it->second.d_query.d_idstate.origID;
+
+  auto sender = it->second.d_sender;
+
+  if (sender->active() && it->second.d_query.isXFR()) {
     DEBUGLOG("XFR!");
     bool done = false;
     TCPResponse response;
     response.d_buffer = std::move(d_responseBuffer);
     response.d_connection = conn;
     /* we don't move the whole IDS because we will need for the responses to come */
-    response.d_idstate.qtype = it->second.d_idstate.qtype;
-    response.d_idstate.qname = it->second.d_idstate.qname;
+    response.d_idstate.qtype = it->second.d_query.d_idstate.qtype;
+    response.d_idstate.qname = it->second.d_query.d_idstate.qname;
     DEBUGLOG("passing XFRresponse to client connection for "<<response.d_idstate.qname);
 
-    it->second.d_xfrStarted = true;
-    done = isXFRFinished(response, it->second);
+    it->second.d_query.d_xfrStarted = true;
+    done = isXFRFinished(response, it->second.d_query);
 
     if (done) {
       d_pendingResponses.erase(it);
@@ -580,7 +604,6 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
     sender->handleXFRResponse(now, std::move(response));
     if (done) {
       d_state = State::idle;
-      d_sender.reset();
       return IOState::Done;
     }
 
@@ -592,26 +615,23 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
   }
 
   --conn->d_ds->outstanding;
-  auto ids = std::move(it->second.d_idstate);
+  auto ids = std::move(it->second.d_query.d_idstate);
   d_pendingResponses.erase(it);
   /* marking as idle for now, so we can accept new queries if our queues are empty */
   if (d_pendingQueries.empty() && d_pendingResponses.empty()) {
     d_state = State::idle;
   }
 
-  DEBUGLOG("passing response to client connection for "<<ids.qname);
-  // make sure that we still exist after calling handleResponse()
-  auto shared = shared_from_this();
-  bool release = canBeReused() && sender->releaseConnection();
-  sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn));
+  auto shared = conn;
+  if (sender->active()) {
+    DEBUGLOG("passing response to client connection for "<<ids.qname);
+    // make sure that we still exist after calling handleResponse()
+    sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn));
+  }
 
   if (!d_pendingQueries.empty()) {
     DEBUGLOG("still have some queries to send");
-    d_state = State::sendingQueryToBackend;
-    d_currentQuery = std::move(d_pendingQueries.front());
-    d_currentPos = 0;
-    d_pendingQueries.pop_front();
-    return IOState::NeedWrite;
+    return queueNextQuery(shared);
   }
   else if (!d_pendingResponses.empty()) {
     DEBUGLOG("still have some responses to read");
@@ -623,10 +643,6 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBa
   else {
     DEBUGLOG("nothing to do, waiting for a new query");
     d_state = State::idle;
-    d_sender.reset();
-    if (release) {
-      DownstreamConnectionsManager::releaseDownstreamConnection(std::move(shared));
-    }
     return IOState::Done;
   }
 }
@@ -719,7 +735,6 @@ bool TCPConnectionToBackend::isXFRFinished(const TCPResponse& response, TCPQuery
 
 std::shared_ptr<TCPConnectionToBackend> DownstreamConnectionsManager::getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
 {
-  std::shared_ptr<TCPConnectionToBackend> result;
   struct timeval freshCutOff = now;
   freshCutOff.tv_sec -= 1;
 
@@ -731,53 +746,47 @@ std::shared_ptr<TCPConnectionToBackend> DownstreamConnectionsManager::getConnect
     const auto& it = t_downstreamConnections.find(backendId);
     if (it != t_downstreamConnections.end()) {
       auto& list = it->second;
-      while (!list.empty()) {
-        result = std::move(list.back());
-        list.pop_back();
+      for (auto listIt = list.begin(); listIt != list.end(); ) {
+        auto& entry = *listIt;
+        if (!entry->canBeReused()) {
+          if (!entry->willBeReusable(false)) {
+            listIt = list.erase(listIt);
+          }
+          else {
+            ++listIt;
+          }
+          continue;
+        }
 
-        result->setReused();
+        entry->setReused();
         /* for connections that have not been used very recently,
            check whether they have been closed in the meantime */
-        if (freshCutOff < result->getLastDataReceivedTime()) {
+        if (freshCutOff < entry->getLastDataReceivedTime()) {
           /* used recently enough, skip the check */
           ++ds->tcpReusedConnections;
-          return result;
+          return entry;
         }
 
-        if (isTCPSocketUsable(result->getHandle())) {
+        if (isTCPSocketUsable(entry->getHandle())) {
           ++ds->tcpReusedConnections;
-          return result;
+          return entry;
+        }
+        else {
+          listIt = list.erase(listIt);
+          continue;
         }
 
         /* otherwise let's try the next one, if any */
+        ++listIt;
       }
     }
   }
 
-  return std::make_shared<TCPConnectionToBackend>(ds, mplexer, now);
-}
-
-void DownstreamConnectionsManager::releaseDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>&& conn)
-{
-  if (conn == nullptr) {
-    return;
-  }
-
-  if (!conn->canBeReused()) {
-    conn.reset();
-    return;
-  }
-
-  const auto& ds = conn->getDS();
-  {
-    auto& list = t_downstreamConnections[ds->getID()];
-    while (list.size() >= s_maxCachedConnectionsPerDownstream) {
-      /* too many connections queued already */
-      list.pop_front();
-    }
-
-    list.push_back(std::move(conn));
+  auto newConnection = std::make_shared<TCPConnectionToBackend>(ds, mplexer, now);
+  if (!ds->useProxyProtocol) {
+    t_downstreamConnections[backendId].push_front(newConnection);
   }
+  return newConnection;
 }
 
 void DownstreamConnectionsManager::cleanupClosedTCPConnections(struct timeval now)
index c05c1264dfca31facfe2080960b1dd6d08df017a..0ea546b3dfd7e7a72adf00bf8e7967034e6b72e2 100644 (file)
@@ -7,15 +7,15 @@
 #include "dnsdist.hh"
 #include "dnsdist-tcp.hh"
 
-class TCPConnectionToBackend : public std::enable_shared_from_this<TCPConnectionToBackend>
+class ConnectionToBackend : public std::enable_shared_from_this<ConnectionToBackend>
 {
 public:
-  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now): d_connectionStartTime(now), d_lastDataReceivedTime(now), d_ds(ds), d_responseBuffer(s_maxPacketCacheEntrySize), d_mplexer(mplexer), d_enableFastOpen(ds->tcpFastOpen)
+  ConnectionToBackend(std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now): d_connectionStartTime(now), d_lastDataReceivedTime(now), d_ds(ds), d_mplexer(mplexer), d_enableFastOpen(ds->tcpFastOpen)
   {
     reconnect();
   }
 
-  virtual ~TCPConnectionToBackend();
+  virtual ~ConnectionToBackend();
 
   int getHandle() const
   {
@@ -61,43 +61,52 @@ public:
     return d_enableFastOpen;
   }
 
-  /* whether we can accept new queries FOR THE SAME CLIENT */
-  bool canAcceptNewQueries() const
+  /* whether a connection can be used now */
+  bool canBeReused(bool sameClient = false) const
   {
     if (d_connectionDied) {
       return false;
     }
 
-    if ((d_pendingQueries.size() + d_pendingResponses.size()) >= d_ds->d_maxInFlightQueriesPerConn) {
+    /* we can't reuse a connection where a proxy protocol payload has been sent,
+       since:
+       - it cannot be reused for a different client
+       - we might have different TLV values for each query
+    */
+    if (d_ds && d_ds->useProxyProtocol == true && !sameClient) {
       return false;
     }
 
-    return true;
-  }
+    if (reachedMaxStreamID()) {
+      return false;
+    }
 
-  bool isIdle() const
-  {
-    return d_state == State::idle && d_pendingQueries.size() == 0 && d_pendingResponses.size() == 0;
+    if (reachedMaxConcurrentQueries()) {
+      return false;
+    }
+
+    return true;
   }
 
-  /* whether a connection can be reused for a different client */
-  virtual bool canBeReused() const
+  /* full now but will become usable later */
+  bool willBeReusable(bool sameClient) const
   {
-    if (d_connectionDied) {
+    if (d_connectionDied || reachedMaxStreamID()) {
       return false;
     }
-    /* we can't reuse a connection where a proxy protocol payload has been sent,
-       since:
-       - it cannot be reused for a different client
-       - we might have different TLV values for each query
-    */
+
     if (d_ds && d_ds->useProxyProtocol == true) {
-      return false;
+      return sameClient;
     }
+
     return true;
   }
 
-  bool matchesTLVs(const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs) const;
+  virtual bool reachedMaxStreamID() const = 0;
+  virtual bool reachedMaxConcurrentQueries() const = 0;
+  virtual void release()
+  {
+  }
 
   bool matches(const std::shared_ptr<DownstreamState>& ds) const
   {
@@ -107,44 +116,18 @@ public:
     return ds == d_ds;
   }
 
-  virtual void queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query);
-  virtual void handleTimeout(const struct timeval& now, bool write);
-  void release();
-
-  void setProxyProtocolValuesSent(std::unique_ptr<std::vector<ProxyProtocolValue>>&& proxyProtocolValuesSent);
+  virtual void queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query) = 0;
+  virtual void handleTimeout(const struct timeval& now, bool write) = 0;
 
   struct timeval getLastDataReceivedTime() const
   {
     return d_lastDataReceivedTime;
   }
 
-  virtual std::string toString() const
-  {
-    ostringstream o;
-    o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses, linked to "<<(d_sender ? " a client" : "no client");
-    return o.str();
-  }
+  virtual std::string toString() const = 0;
 
 protected:
-  /* waitingForResponseFromBackend is a state where we have not yet started reading the size,
-     so we can still switch to sending instead */
-  enum class State : uint8_t { idle, sendingQueryToBackend, waitingForResponseFromBackend, readingResponseSizeFromBackend, readingResponseFromBackend };
-  enum class FailureReason : uint8_t { /* too many attempts */ gaveUp, timeout, unexpectedQueryID };
-
-  static void handleIO(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
-  static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
-  static IOState queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn);
-  static IOState sendQuery(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
-  static bool isXFRFinished(const TCPResponse& response, TCPQuery& query);
-
-  IOState handleResponse(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
-  uint16_t getQueryIdFromResponse() const;
   bool reconnect();
-  void notifyAllQueriesFailed(const struct timeval& now, FailureReason reason);
-  bool needProxyProtocolPayload() const
-  {
-    return !d_proxyProtocolPayloadSent && (d_ds && d_ds->useProxyProtocol);
-  }
 
   boost::optional<struct timeval> getBackendHealthCheckTTD(const struct timeval& now) const
   {
@@ -206,34 +189,107 @@ protected:
     return res;
   }
 
-  TCPQuery d_currentQuery;
-  std::deque<TCPQuery> d_pendingQueries;
-  std::unordered_map<uint16_t, TCPQuery> d_pendingResponses;
   struct timeval d_connectionStartTime;
   struct timeval d_lastDataReceivedTime;
   std::shared_ptr<DownstreamState> d_ds{nullptr};
   std::shared_ptr<TCPQuerySender> d_sender{nullptr};
-  PacketBuffer d_responseBuffer;
   std::unique_ptr<FDMultiplexer>& d_mplexer;
-  std::unique_ptr<std::vector<ProxyProtocolValue>> d_proxyProtocolValuesSent{nullptr};
   std::unique_ptr<TCPIOHandler> d_handler{nullptr};
   std::unique_ptr<IOStateHandler> d_ioState{nullptr};
-  size_t d_currentPos{0};
   uint64_t d_queries{0};
-  uint64_t d_downstreamFailures{0};
-  uint16_t d_responseSize{0};
-  State d_state{State::idle};
-  bool d_fresh{true};
+  uint32_t d_highestStreamID{0};
+  uint16_t d_downstreamFailures{0};
+  bool d_proxyProtocolPayloadSent{false};
   bool d_enableFastOpen{false};
   bool d_connectionDied{false};
-  bool d_proxyProtocolPayloadSent{false};
+  bool d_fresh{true};
+};
+
+class TCPConnectionToBackend : public ConnectionToBackend
+{
+public:
+  TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, std::unique_ptr<FDMultiplexer>& mplexer, const struct timeval& now): ConnectionToBackend(ds, mplexer, now), d_responseBuffer(s_maxPacketCacheEntrySize)
+  {
+  }
+
+  virtual ~TCPConnectionToBackend();
+
+  bool isIdle() const
+  {
+    return d_state == State::idle && d_pendingQueries.size() == 0 && d_pendingResponses.size() == 0;
+  }
+
+  bool reachedMaxStreamID() const override
+  {
+    /* TCP/DoT has only 2^16 usable identifiers, DoH has 2^32 */
+    const uint32_t maximumStreamID = std::numeric_limits<uint16_t>::max() - 1;
+    return d_highestStreamID == maximumStreamID;
+  }
+
+  bool reachedMaxConcurrentQueries() const override
+  {
+    const size_t concurrent = d_pendingQueries.size() + d_pendingResponses.size();
+    if (concurrent > 0 && concurrent >= d_ds->d_maxInFlightQueriesPerConn) {
+      return true;
+    }
+    return false;
+  }
+  bool matchesTLVs(const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs) const;
+
+  void queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query) override;
+  void handleTimeout(const struct timeval& now, bool write) override;
+  void release() override;
+
+  std::string toString() const override
+  {
+    ostringstream o;
+    o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<<d_queries<<", pending queries count is "<<d_pendingQueries.size()<<", "<<d_pendingResponses.size()<<" pending responses";
+    return o.str();
+  }
+
+  void setProxyProtocolValuesSent(std::unique_ptr<std::vector<ProxyProtocolValue>>&& proxyProtocolValuesSent);
+
+private:
+  /* waitingForResponseFromBackend is a state where we have not yet started reading the size,
+     so we can still switch to sending instead */
+  enum class State : uint8_t { idle, sendingQueryToBackend, waitingForResponseFromBackend, readingResponseSizeFromBackend, readingResponseFromBackend };
+  enum class FailureReason : uint8_t { /* too many attempts */ gaveUp, timeout, unexpectedQueryID };
+
+  static void handleIO(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
+  static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+  static IOState queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn);
+  static IOState sendQuery(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
+  static bool isXFRFinished(const TCPResponse& response, TCPQuery& query);
+
+  IOState handleResponse(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now);
+  uint16_t getQueryIdFromResponse() const;
+  void notifyAllQueriesFailed(const struct timeval& now, FailureReason reason);
+  bool needProxyProtocolPayload() const
+  {
+    return !d_proxyProtocolPayloadSent && (d_ds && d_ds->useProxyProtocol);
+  }
+
+  class PendingRequest
+  {
+  public:
+    std::shared_ptr<TCPQuerySender> d_sender{nullptr};
+    TCPQuery d_query;
+  };
+
+  PacketBuffer d_responseBuffer;
+  std::deque<PendingRequest> d_pendingQueries;
+  std::unordered_map<uint16_t, PendingRequest> d_pendingResponses;
+  std::unique_ptr<std::vector<ProxyProtocolValue>> d_proxyProtocolValuesSent{nullptr};
+  PendingRequest d_currentQuery;
+  size_t d_currentPos{0};
+  uint16_t d_responseSize{0};
+  State d_state{State::idle};
 };
 
 class DownstreamConnectionsManager
 {
 public:
   static std::shared_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<DownstreamState>& ds, const struct timeval& now);
-  static void releaseDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>&& conn);
   static void cleanupClosedTCPConnections(struct timeval now);
   static size_t clear();
 
index 99e933aa2c8917cea7b927d7d5a694b56d2a0d0d..9ed8b6b3fa4dbb697bd0bbe0714318020f1d857f 100644 (file)
@@ -105,9 +105,9 @@ public:
     return false;
   }
 
-  std::shared_ptr<TCPConnectionToBackend> getActiveDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs);
+  std::shared_ptr<TCPConnectionToBackend> getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs);
   std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now);
-  void registerActiveDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn);
+  void registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn);
 
   static size_t clearAllDownstreamConnections();
 
@@ -141,14 +141,14 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
   std::string toString() const
   {
     ostringstream o;
-    o << "Incoming TCP connection from "<<d_ci.remote.toStringWithPort()<<" over FD "<<d_handler.getDescriptor()<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<<d_queriesCount<<", current queries count is "<<d_currentQueriesCount<<", "<<d_queuedResponses.size()<<" queued responses, "<<d_activeConnectionsToBackend.size()<<" active connections to a backend";
+    o << "Incoming TCP connection from "<<d_ci.remote.toStringWithPort()<<" over FD "<<d_handler.getDescriptor()<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<<d_queriesCount<<", current queries count is "<<d_currentQueriesCount<<", "<<d_queuedResponses.size()<<" queued responses, "<<d_ownedConnectionsToBackend.size()<<" owned connections to a backend";
     return o.str();
   }
 
   enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
 
   TCPResponse d_currentResponse;
-  std::map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> d_activeConnectionsToBackend;
+  std::map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> d_ownedConnectionsToBackend;
   std::deque<TCPResponse> d_queuedResponses;
   PacketBuffer d_buffer;
   ConnectionInfo d_ci;
index b9c86c916fdc7db184c7cb5c1b6521c109afae72..9154f2f650f2d237a9e6a4a8f3a2b8149b191242 100644 (file)
@@ -121,7 +121,7 @@ struct InternalQuery
 
 using TCPQuery = InternalQuery;
 
-class TCPConnectionToBackend;
+class ConnectionToBackend;
 
 struct TCPResponse : public TCPQuery
 {
@@ -131,13 +131,13 @@ struct TCPResponse : public TCPQuery
     memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
   }
 
-  TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr<TCPConnectionToBackend> conn) :
+  TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr<ConnectionToBackend> conn) :
     TCPQuery(std::move(buffer), std::move(state)), d_connection(conn)
   {
     memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
   }
 
-  std::shared_ptr<TCPConnectionToBackend> d_connection{nullptr};
+  std::shared_ptr<ConnectionToBackend> d_connection{nullptr};
   dnsheader d_cleartextDH;
   bool d_selfGenerated{false};
 };
index b17b1944c7801e69e42da92111dbb1e5a1a43900..4d6224149cc25d022d518eb949d4e118f6b7d32f 100644 (file)
@@ -419,9 +419,25 @@ static ComboAddress getBackendAddress(const std::string& lastDigit, uint16_t por
   return ComboAddress("192.0.2." + lastDigit, port);
 }
 
+static void appendPayloadEditingID(PacketBuffer& buffer, const PacketBuffer& payload, uint16_t newID)
+{
+  PacketBuffer newPayload(payload);
+  auto dh = reinterpret_cast<dnsheader*>(&newPayload.at(sizeof(uint16_t)));
+  dh->id = htons(newID);
+  buffer.insert(buffer.end(), newPayload.begin(), newPayload.end());
+}
+
+static void prependPayloadEditingID(PacketBuffer& buffer, const PacketBuffer& payload, uint16_t newID)
+{
+  PacketBuffer newPayload(payload);
+  auto dh = reinterpret_cast<dnsheader*>(&newPayload.at(sizeof(uint16_t)));
+  dh->id = htons(newID);
+  buffer.insert(buffer.begin(), newPayload.begin(), newPayload.end());
+}
+
 static void testInit(const std::string& name, TCPClientThreadData& threadData)
 {
-#if 0
+#ifdef DEBUGLOG_ENABLED
   cerr<<name<<endl;
 #else
   (void) name;
@@ -435,6 +451,7 @@ static void testInit(const std::string& name, TCPClientThreadData& threadData)
 
   g_proxyProtocolACL.clear();
   g_verbose = false;
+  IncomingTCPConnectionState::clearAllDownstreamConnections();
 
   threadData.mplexer = std::make_unique<MockupFDMultiplexer>();
 }
@@ -786,8 +803,6 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered)
     s_readBuffer = query;
     // preprend the proxy protocol payload
     s_readBuffer.insert(s_readBuffer.begin(), proxyPayload.begin(), proxyPayload.end());
-    // append a second query
-    s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end());
 
     s_steps = {
       { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
@@ -836,6 +851,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
   PacketBuffer query;
   GenericDNSPacketWriter<PacketBuffer> pwQ(query, DNSName("powerdns.com."), QType::A, QClass::IN, 0);
   pwQ.getHeader()->rd = 1;
+  pwQ.getHeader()->id = 0;
 
   auto shortQuery = query;
   shortQuery.resize(sizeof(dnsheader) - 1);
@@ -1083,11 +1099,11 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     TEST_INIT("=> Short read and write to backend");
     s_readBuffer = query;
     // append a second query
-    s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end());
+    appendPayloadEditingID(s_readBuffer, query, 1);
 
     s_backendReadBuffer = query;
     // append a second query
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), query.begin(), query.end());
+    appendPayloadEditingID(s_backendReadBuffer, query, 1);
 
     s_steps = {
       { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
@@ -1629,8 +1645,8 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR)
     s_readBuffer = query;
 
     for (size_t idx = 0; idx < count; idx++) {
-      s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end());
-      s_backendReadBuffer.insert(s_backendReadBuffer.end(), query.begin(), query.end());
+      appendPayloadEditingID(s_readBuffer, query, idx);
+      appendPayloadEditingID(s_backendReadBuffer, query, idx);
     }
 
     s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
@@ -1716,7 +1732,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
   for (auto& query : queries) {
     GenericDNSPacketWriter<PacketBuffer> pwQ(query, DNSName("powerdns" + std::to_string(counter) + ".com."), QType::A, QClass::IN, 0);
     pwQ.getHeader()->rd = 1;
-    pwQ.getHeader()->id = counter;
+    pwQ.getHeader()->id = htons(counter);
     uint16_t querySize = static_cast<uint16_t>(query.size());
     const uint8_t sizeBytes[] = { static_cast<uint8_t>(querySize / 256), static_cast<uint8_t>(querySize % 256) };
     query.insert(query.begin(), sizeBytes, sizeBytes + 2);
@@ -1732,7 +1748,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     pwR.getHeader()->qr = 1;
     pwR.getHeader()->rd = 1;
     pwR.getHeader()->ra = 1;
-    pwR.getHeader()->id = counter;
+    pwR.getHeader()->id = htons(counter);
     pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER);
     pwR.xfr32BitInt(0x01020304);
     pwR.commit();
@@ -1749,16 +1765,18 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     PacketBuffer expectedWriteBuffer;
     PacketBuffer expectedBackendWriteBuffer;
 
+    uint16_t backendCounter = 0;
     for (const auto& query : queries) {
       s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end());
+      appendPayloadEditingID(expectedBackendWriteBuffer, query, backendCounter++);
     }
-    expectedBackendWriteBuffer = s_readBuffer;
 
+    backendCounter = 0;
     for (const auto& response : responses) {
       /* reverse order */
-      s_backendReadBuffer.insert(s_backendReadBuffer.begin(), response.begin(), response.end());
+      prependPayloadEditingID(s_backendReadBuffer, response, backendCounter++);
+      expectedWriteBuffer.insert(expectedWriteBuffer.begin(), response.begin(), response.end());
     }
-    expectedWriteBuffer = s_backendReadBuffer;
 
     s_steps = {
       { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
@@ -1884,8 +1902,14 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end());
     }
 
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(4).begin(), responses.at(4).end());
+    uint16_t backendCounter = 0;
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), backendCounter++);
+
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(4), 3);
 
     /* self-answered */
     expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(3).begin(), responses.at(3).end());
@@ -1893,12 +1917,6 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(0).begin(), responses.at(0).end());
     expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(4).begin(), responses.at(4).end());
 
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(4).begin(), queries.at(4).end());
-
-
     bool timeout = false;
     s_steps = {
       { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
@@ -2027,13 +2045,24 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     for (const auto& query : queries) {
       s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end());
     }
-    expectedBackendWriteBuffer = s_readBuffer;
-
     for (const auto& response : responses) {
       expectedWriteBuffer.insert(expectedWriteBuffer.end(), response.begin(), response.end());
     }
 
-    s_backendReadBuffer = expectedWriteBuffer;
+    uint16_t backendCounter = 0;
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), backendCounter);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(0), backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), backendCounter);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(1), backendCounter++);
+
+    // new connection
+    backendCounter = 0;
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), backendCounter);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(2), backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(3), backendCounter);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(3), backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), backendCounter);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(4), backendCounter++);
 
     bool timeout = false;
     int backendDesc;
@@ -2341,15 +2370,18 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     s_readBuffer.insert(s_readBuffer.end(), queries.at(1).begin(), queries.at(1).end());
     s_readBuffer.insert(s_readBuffer.end(), queries.at(4).begin(), queries.at(4).end());
 
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(4).begin(), queries.at(4).end());
+    uint16_t backendCounter = 0;
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), backendCounter++);
 
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(1).begin(), responses.at(1).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(4).begin(), responses.at(4).end());
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(1), 1);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(4), 2);
 
-    expectedWriteBuffer = s_backendReadBuffer;
+    appendPayloadEditingID(expectedWriteBuffer, responses.at(1), 1);
+    appendPayloadEditingID(expectedWriteBuffer, responses.at(0), 0);
+    appendPayloadEditingID(expectedWriteBuffer, responses.at(4), 4);
 
     /* make sure that the backend's timeout is longer than the client's */
     backend->tcpRecvTimeout = 30;
@@ -2713,14 +2745,16 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     s_readBuffer = axfrQuery;
     s_readBuffer.insert(s_readBuffer.end(), secondQuery.begin(), secondQuery.end());
 
-    expectedBackendWriteBuffer = s_readBuffer;
+    uint16_t backendCounter = 0;
+    appendPayloadEditingID(expectedBackendWriteBuffer, axfrQuery, backendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, secondQuery, backendCounter++);
 
     for (const auto& response : axfrResponses) {
-      s_backendReadBuffer.insert(s_backendReadBuffer.end(), response.begin(), response.end());
+      appendPayloadEditingID(s_backendReadBuffer, response, 0);
+      expectedWriteBuffer.insert(expectedWriteBuffer.end(), response.begin(), response.end());
     }
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), secondResponse.begin(), secondResponse.end());
-
-    expectedWriteBuffer = s_backendReadBuffer;
+    appendPayloadEditingID(s_backendReadBuffer, secondResponse, 1);
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), secondResponse.begin(), secondResponse.end());
 
     bool timeout = false;
     s_steps = {
@@ -2973,15 +3007,18 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     s_readBuffer.insert(s_readBuffer.end(), ixfrQuery.begin(), ixfrQuery.end());
     s_readBuffer.insert(s_readBuffer.end(), secondQuery.begin(), secondQuery.end());
 
-    expectedBackendWriteBuffer = s_readBuffer;
+    appendPayloadEditingID(expectedBackendWriteBuffer, firstQuery, 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, ixfrQuery, 1);
+    appendPayloadEditingID(expectedBackendWriteBuffer, secondQuery, 2);
 
-    s_backendReadBuffer = firstResponse;
+    appendPayloadEditingID(s_backendReadBuffer, firstResponse, 0);
+    expectedWriteBuffer.insert(expectedWriteBuffer.begin(), firstResponse.begin(), firstResponse.end());
     for (const auto& response : ixfrResponses) {
-      s_backendReadBuffer.insert(s_backendReadBuffer.end(), response.begin(), response.end());
+      appendPayloadEditingID(s_backendReadBuffer, response, 1);
+      expectedWriteBuffer.insert(expectedWriteBuffer.end(), response.begin(), response.end());
     }
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), secondResponse.begin(), secondResponse.end());
-
-    expectedWriteBuffer = s_backendReadBuffer;
+    appendPayloadEditingID(s_backendReadBuffer, secondResponse, 2);
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), secondResponse.begin(), secondResponse.end());
 
     bool timeout = false;
     s_steps = {
@@ -3083,19 +3120,22 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     proxyEnabledBackend->useProxyProtocol = true;
 
     expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end());
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), 1);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), 2);
     expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end());
     /* we are using an unordered_map, so all bets are off here :-/ */
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end());
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), 1);
 
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(1).begin(), responses.at(1).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(2).begin(), responses.at(2).end());
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0);
+    /* after the reconnection */
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(1), 1);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(2), 0);
 
-    expectedWriteBuffer = s_backendReadBuffer;
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(0).begin(), responses.at(0).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(1).begin(), responses.at(1).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(2).begin(), responses.at(2).end());
 
     s_steps = {
       { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
@@ -3204,16 +3244,12 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     proxyEnabledBackend->d_tlsCtx = tlsCtx;
     /* enable out-of-order on the backend side as well */
     proxyEnabledBackend->d_maxInFlightQueriesPerConn = 65536;
-    proxyEnabledBackend-> useProxyProtocol = true;
-
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end());
+    proxyEnabledBackend->useProxyProtocol = true;
 
     expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end());
-    expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end());
-    //s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(2).begin(), responses.at(2).end());
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), 1);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), 2);
 
     s_steps = {
       { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done },
@@ -3245,31 +3281,10 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       }},
       /* client closes the connection */
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 },
-      /* closing the client connection */
-      { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 },
-      /* try to read response from backend, connection has been closed */
-      { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, 0 },
-      //{ ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(2).size() },
       /* closing the backend connection */
       { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done },
-      { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done },
-      { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done },
-      { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done },
-      { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done },
-      /* sending query (3) to the backend */
-      { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, proxyPayload.size() + queries.at(2).size() },
-      /* sending query (2) to the backend */
-      { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 },
-      { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 },
+      /* closing the client connection */
+      { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 },
     };
 
     s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend) -> ProcessQueryResult {
@@ -3371,8 +3386,8 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     g_tcpRecvTimeout = 2;
 
     /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */
-    /* we should have nothing to clear since the connection cannot be reused due to the Proxy Protocol payload */
-    BOOST_CHECK_EQUAL(IncomingTCPConnectionState::clearAllDownstreamConnections(), 0U);
+    /* we have one connection to clear, no proxy protocol */
+    BOOST_CHECK_EQUAL(IncomingTCPConnectionState::clearAllDownstreamConnections(), 1U);
   }
 
   {
@@ -3383,15 +3398,29 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
     for (const auto& query : queries) {
       s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end());
     }
-    expectedBackendWriteBuffer = s_readBuffer;
 
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(1).begin(), responses.at(1).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(2).begin(), responses.at(2).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(4).begin(), responses.at(4).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(3).begin(), responses.at(3).end());
+    /* queries 0, 1 and 4 are sent to the first backend, 2 and 3 to the second */
+    uint16_t firstBackendCounter = 0;
+    uint16_t secondBackendCounter = 0;
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), firstBackendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), firstBackendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), secondBackendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(3), secondBackendCounter++);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), firstBackendCounter++);
+
+    firstBackendCounter = 0;
+    secondBackendCounter = 0;
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(0), firstBackendCounter++);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(1), firstBackendCounter++);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(2), secondBackendCounter++);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(4), firstBackendCounter++);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(3), secondBackendCounter++);
 
-    expectedWriteBuffer = s_backendReadBuffer;
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(0).begin(), responses.at(0).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(1).begin(), responses.at(1).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(2).begin(), responses.at(2).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(4).begin(), responses.at(4).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(3).begin(), responses.at(3).end());
 
     auto backend1 = std::make_shared<DownstreamState>(getBackendAddress("42", 53), ComboAddress("0.0.0.0:0"), 0, std::string(), 1, false);
     backend1->d_tlsCtx = tlsCtx;
@@ -3539,17 +3568,21 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
   }
 
   {
-    TEST_INIT("=> 2 OOOR queries to the backend with duplicated IDs, backend responds to only one of them");
+    TEST_INIT("=> 2 OOOR queries to the backend with duplicated IDs");
     PacketBuffer expectedWriteBuffer;
     PacketBuffer expectedBackendWriteBuffer;
 
     s_readBuffer.insert(s_readBuffer.end(), queries.at(0).begin(), queries.at(0).end());
     s_readBuffer.insert(s_readBuffer.end(), queries.at(0).begin(), queries.at(0).end());
 
-    expectedBackendWriteBuffer = s_readBuffer;
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 1);
 
-    s_backendReadBuffer.insert(s_backendReadBuffer.begin(), responses.at(0).begin(), responses.at(0).end());
-    expectedWriteBuffer = s_backendReadBuffer;
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 1);
+
+    appendPayloadEditingID(expectedWriteBuffer, responses.at(0), 0);
+    appendPayloadEditingID(expectedWriteBuffer, responses.at(0), 0);
 
     bool timeout = false;
     s_steps = {
@@ -3575,7 +3608,12 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR)
       /* nothing more from the client either */
       { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0 },
 
-      /* reading a response from the backend */
+      /* reading response (1) from the backend */
+      { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(0).size() - 2 },
+      { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(0).size()},
+      /* sending it to the client */
+      { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, responses.at(0).size()},
+      /* reading response (2) from the backend */
       { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(0).size() - 2 },
       { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(0).size(), [&threadData](int desc, const ExpectedStep& step) {
         dynamic_cast<MockupFDMultiplexer*>(threadData.mplexer.get())->setNotReady(desc);
@@ -3691,15 +3729,24 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR)
     for (const auto& query : queries) {
       s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end());
     }
-    expectedBackendWriteBuffer = s_readBuffer;
 
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(2).begin(), responses.at(2).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(1).begin(), responses.at(1).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(4).begin(), responses.at(4).end());
-    s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(3).begin(), responses.at(3).end());
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(3), 0);
+    appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), 0);
 
-    expectedWriteBuffer = s_backendReadBuffer;
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(2), 0);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(1), 0);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(4), 0);
+    appendPayloadEditingID(s_backendReadBuffer, responses.at(3), 0);
+
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(0).begin(), responses.at(0).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(2).begin(), responses.at(2).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(1).begin(), responses.at(1).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(4).begin(), responses.at(4).end());
+    expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(3).begin(), responses.at(3).end());
 
     std::vector<int> backendDescriptors = { -1, -1, -1, -1, -1 };
 
index b4ea7be4b9a58bc507b1c35d813893f9e1dfcf4c..d4897e86673beb18413f22490d5d838a6e80d42a 100644 (file)
@@ -240,12 +240,66 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
         sock.close()
 
     @classmethod
-    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None):
+    def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None):
+      ignoreTrailing = trailingDataResponse is True
+      data = conn.recv(2)
+      if not data:
+        conn.close()
+        return
+
+      (datalen,) = struct.unpack("!H", data)
+      data = conn.recv(datalen)
+      forceRcode = None
+      try:
+        request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
+      except dns.message.TrailingJunk as e:
+        if trailingDataResponse is False or forceRcode is True:
+          raise
+        print("TCP query with trailing data, synthesizing response")
+        request = dns.message.from_wire(data, ignore_trailing=True)
+        forceRcode = trailingDataResponse
+
+      if callback:
+        wire = callback(request)
+      else:
+        response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
+        if response:
+          wire = response.to_wire(max_size=65535)
+
+      if not wire:
+        conn.close()
+        return
+
+      conn.send(struct.pack("!H", len(wire)))
+      conn.send(wire)
+
+      while multipleResponses:
+        if fromQueue.empty():
+          break
+
+        response = fromQueue.get(True, cls._queueTimeout)
+        if not response:
+          break
+
+        response = copy.copy(response)
+        response.id = request.id
+        wire = response.to_wire(max_size=65535)
+        try:
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
+        except socket.error as e:
+          # some of the tests are going to close
+          # the connection on us, just deal with it
+          break
+
+      conn.close()
+
+    @classmethod
+    def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False):
         # trailingDataResponse=True means "ignore trailing data".
         # Other values are either False (meaning "raise an exception")
         # or are interpreted as a response RCODE for queries with trailing data.
         # callback is invoked for every -even healthcheck ones- query and should return a raw response
-        ignoreTrailing = trailingDataResponse is True
 
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@@ -269,57 +323,14 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase):
               continue
 
             conn.settimeout(5.0)
-            data = conn.recv(2)
-            if not data:
-                conn.close()
-                continue
-
-            (datalen,) = struct.unpack("!H", data)
-            data = conn.recv(datalen)
-            forceRcode = None
-            try:
-                request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing)
-            except dns.message.TrailingJunk as e:
-                if trailingDataResponse is False or forceRcode is True:
-                    raise
-                print("TCP query with trailing data, synthesizing response")
-                request = dns.message.from_wire(data, ignore_trailing=True)
-                forceRcode = trailingDataResponse
-
-            if callback:
-              wire = callback(request)
+            if multipleConnections:
+              thread = threading.Thread(name='TCP Connection Handler',
+                                        target=cls.handleTCPConnection,
+                                        args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback])
+              thread.setDaemon(True)
+              thread.start()
             else:
-              response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode)
-              if response:
-                wire = response.to_wire(max_size=65535)
-
-            if not wire:
-                conn.close()
-                continue
-
-            conn.send(struct.pack("!H", len(wire)))
-            conn.send(wire)
-
-            while multipleResponses:
-                if fromQueue.empty():
-                    break
-
-                response = fromQueue.get(True, cls._queueTimeout)
-                if not response:
-                    break
-
-                response = copy.copy(response)
-                response.id = request.id
-                wire = response.to_wire(max_size=65535)
-                try:
-                    conn.send(struct.pack("!H", len(wire)))
-                    conn.send(wire)
-                except socket.error as e:
-                    # some of the tests are going to close
-                    # the connection on us, just deal with it
-                    break
-
-            conn.close()
+              cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback)
 
         sock.close()
 
index c4e43c7bad3212efdde044330ab05493bccfe36d..8076dd1524af84dc807a96616a6b5cecf4f00ef8 100644 (file)
@@ -13,6 +13,7 @@ class TestAXFR(DNSDistTest):
     _config_template = """
     newServer{address="127.0.0.1:%s"}
     """
+
     @classmethod
     def startResponders(cls):
         print("Launching responders..")
@@ -20,7 +21,7 @@ class TestAXFR(DNSDistTest):
         cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
         cls._UDPResponder.setDaemon(True)
         cls._UDPResponder.start()
-        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, True])
+        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, True, None, None, True])
         cls._TCPResponder.setDaemon(True)
         cls._TCPResponder.start()
 
index 483d555fcd38b00a1c72d70fff649236f1976d6b..788a0b17d7341b4fc3ca93f1b2c5af8eb69b2e51 100644 (file)
@@ -780,6 +780,7 @@ class TestDynBlockQPSActionTruncated(DNSDistTest):
         # check over TCP, which should not be truncated
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
 
+        receivedQuery.id = query.id
         self.assertEqual(query, receivedQuery)
         self.assertEqual(receivedResponse, response)
 
@@ -798,6 +799,7 @@ class TestDynBlockQPSActionTruncated(DNSDistTest):
         for _ in range((self._dynBlockQPS * self._dynBlockPeriod) + 1):
             (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
             sent = sent + 1
+            receivedQuery.id = query.id
             self.assertEqual(query, receivedQuery)
             self.assertEqual(receivedResponse, response)
             receivedQuery.id = query.id
index 74998e87e4bc10853fb07d8b326cd8b262dd6f7a..53677a66123b55f3df0491482dc3aa838cd98608 100644 (file)
@@ -52,6 +52,7 @@ class OutgoingTLSTests(object):
         numberOfUDPQueries = 10
         for _ in range(numberOfUDPQueries):
             (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse)
+            receivedQuery.id = query.id
             self.assertEqual(query, receivedQuery)
             self.assertEqual(receivedResponse, expectedResponse)
 
@@ -82,6 +83,7 @@ class OutgoingTLSTests(object):
         expectedResponse.answer.append(rrset)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
+        receivedQuery.id = query.id
         self.assertEqual(query, receivedQuery)
         self.assertEqual(receivedResponse, expectedResponse)
         self.checkOnlyTLSResponderHit()
index 7ada4cacd6580eb482948f407e460ce0cdaa367d..91edd504479c2a07c5783a2c93e3b933aa8ee69b 100644 (file)
@@ -24,6 +24,7 @@ class TestTCPOnly(DNSDistTest):
         expectedResponse.answer.append(rrset)
 
         (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse)
+        receivedQuery.id = query.id
         self.assertEqual(query, receivedQuery)
         self.assertEqual(receivedResponse, expectedResponse)
 
@@ -47,6 +48,7 @@ class TestTCPOnly(DNSDistTest):
         expectedResponse.answer.append(rrset)
 
         (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse)
+        receivedQuery.id = query.id
         self.assertEqual(query, receivedQuery)
         self.assertEqual(receivedResponse, expectedResponse)
         if 'UDP Responder' in self._responsesCounter: