From: Remi Gacogne Date: Fri, 15 Oct 2021 15:36:16 +0000 (+0200) Subject: dnsdist: Use the same outgoing TCP connection for different clients X-Git-Tag: rec-4.6.0-beta1~28^2~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=645a1ca439fe83f742b7d5fe71042d38fdb575f4;p=thirdparty%2Fpdns.git dnsdist: Use the same outgoing TCP connection for different clients --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index d70f7692e5..3307980240 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -113,12 +113,14 @@ std::shared_ptr IncomingTCPConnectionState::getDownstrea { std::shared_ptr downstream{nullptr}; - downstream = getActiveDownstreamConnection(ds, tlvs); + downstream = getOwnedDownstreamConnection(ds, tlvs); if (!downstream) { - /* we don't have a connection to this backend active yet, let's get one (it might not be a fresh one, though) */ + /* we don't have a connection to this backend owned yet, let's get one (it might not be a fresh one, though) */ downstream = DownstreamConnectionsManager::getConnectionToDownstream(d_threadData.mplexer, ds, now); - registerActiveDownstreamConnection(downstream); + if (ds->useProxyProtocol) { + registerOwnedDownstreamConnection(downstream); + } } return downstream; @@ -307,17 +309,17 @@ void IncomingTCPConnectionState::resetForNewQuery() d_state = State::waitingForQuery; } -std::shared_ptr IncomingTCPConnectionState::getActiveDownstreamConnection(const std::shared_ptr& ds, const std::unique_ptr>& tlvs) +std::shared_ptr IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr& ds, const std::unique_ptr>& tlvs) { - auto it = d_activeConnectionsToBackend.find(ds); - if (it == d_activeConnectionsToBackend.end()) { - DEBUGLOG("no active connection found for "<getName()); + auto it = d_ownedConnectionsToBackend.find(ds); + if (it == d_ownedConnectionsToBackend.end()) { + DEBUGLOG("no owned connection found for "<getName()); return nullptr; } for (auto& conn : it->second) { - if (conn->canAcceptNewQueries() && conn->matchesTLVs(tlvs)) { - DEBUGLOG("Got one active connection accepting more for "<getName()); + if (conn->canBeReused(true) && conn->matchesTLVs(tlvs)) { + DEBUGLOG("Got one owned connection accepting more for "<getName()); conn->setReused(); return conn; } @@ -327,9 +329,9 @@ std::shared_ptr IncomingTCPConnectionState::getActiveDow return nullptr; } -void IncomingTCPConnectionState::registerActiveDownstreamConnection(std::shared_ptr& conn) +void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_ptr& conn) { - d_activeConnectionsToBackend[conn->getDS()].push_front(conn); + d_ownedConnectionsToBackend[conn->getDS()].push_front(conn); } /* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */ @@ -375,7 +377,12 @@ void IncomingTCPConnectionState::terminateClientConnection() d_queuedResponses.clear(); /* we have already released idle connections that could be reused, we don't care about the ones still waiting for responses */ - d_activeConnectionsToBackend.clear(); + for (auto& backend : d_ownedConnectionsToBackend) { + for (auto& conn : backend.second) { + conn->release(); + } + } + d_ownedConnectionsToBackend.clear(); /* meaning we will no longer be 'active' when the backend response or timeout comes in */ d_ioState.reset(); @@ -419,18 +426,18 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe { std::shared_ptr state = shared_from_this(); - if (response.d_connection && response.d_connection->isIdle()) { - // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool yet, no one else will be able to use it anyway - if (response.d_connection->canBeReused()) { - const auto connIt = state->d_activeConnectionsToBackend.find(response.d_connection->getDS()); - if (connIt != state->d_activeConnectionsToBackend.end()) { + if (response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->useProxyProtocol) { + // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool as no one else will be able to use it anyway + if (!response.d_connection->willBeReusable(true)) { + // if it can't be reused even by us, well + const auto connIt = state->d_ownedConnectionsToBackend.find(response.d_connection->getDS()); + if (connIt != state->d_ownedConnectionsToBackend.end()) { auto& list = connIt->second; for (auto it = list.begin(); it != list.end(); ++it) { if (*it == response.d_connection) { try { response.d_connection->release(); - DownstreamConnectionsManager::releaseDownstreamConnection(std::move(*it)); } catch (const std::exception& e) { vinfolog("Error releasing connection: %s", e.what()); @@ -1056,7 +1063,7 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptrd_ci.remote.toStringWithPort()); DEBUGLOG("client timeout"); - DEBUGLOG("Processed "<d_queriesCount<<" queries, current count is "<d_currentQueriesCount<<", "<d_activeConnectionsToBackend.size()<<" active connections, "<d_queuedResponses.size()<<" response queued"); + DEBUGLOG("Processed "<d_queriesCount<<" queries, current count is "<d_currentQueriesCount<<", "<d_ownedConnectionsToBackend.size()<<" owned connections, "<d_queuedResponses.size()<<" response queued"); if (write || state->d_currentQueriesCount == 0) { ++state->d_ci.cs->tcpClientTimeouts; @@ -1067,14 +1074,6 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptrd_state = IncomingTCPConnectionState::State::idle; state->d_ioState->update(IOState::Done, handleIOCallback, state); - -#ifdef DEBUGLOG_ENABLED - for (const auto& active : state->d_activeConnectionsToBackend) { - for (const auto& conn: active.second) { - DEBUGLOG("Connection to "<getName()<<" is "<<(conn->isIdle() ? "idle" : "not idle")); - } - } -#endif } } diff --git a/pdns/dnsdistdist/dnsdist-nghttp2.cc b/pdns/dnsdistdist/dnsdist-nghttp2.cc index 4768b2a7e7..03abd4332a 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2.cc @@ -43,7 +43,7 @@ std::unique_ptr g_dohClientThreads{nullptr}; std::optional g_outgoingDoHWorkerThreads{std::nullopt}; #ifdef HAVE_NGHTTP2 -class DoHConnectionToBackend : public TCPConnectionToBackend +class DoHConnectionToBackend : public ConnectionToBackend { public: DoHConnectionToBackend(std::shared_ptr ds, std::unique_ptr& mplexer, const struct timeval& now, std::string&& proxyProtocolPayload); @@ -58,17 +58,14 @@ public: return o.str(); } - bool reachedMaxStreamID() const; - bool canBeReused() const override; - /* full now but will become usable later */ - bool willBeReusable() const; - void setHealthCheck(bool h) { d_healthCheckQuery = h; } void stopIO(); + bool reachedMaxConcurrentQueries() const override; + bool reachedMaxStreamID() const override; private: static ssize_t send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data); @@ -79,7 +76,6 @@ private: static int on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data); static void handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param); static void handleWritableIOCallback(int fd, FDMultiplexer::funcparam_t& param); - static void handleIO(std::shared_ptr& conn, const struct timeval& now); static void addStaticHeader(std::vector& headers, const std::string& nameKey, const std::string& valueKey); static void addDynamicHeader(std::vector& headers, const std::string& nameKey, const std::string& value); @@ -117,7 +113,6 @@ private: std::unique_ptr d_session{nullptr, nghttp2_session_del}; size_t d_outPos{0}; size_t d_inPos{0}; - uint32_t d_highestStreamID{0}; bool d_healthCheckQuery{false}; bool d_firstWrite{true}; }; @@ -213,34 +208,12 @@ bool DoHConnectionToBackend::reachedMaxStreamID() const return d_highestStreamID == maximumStreamID; } -bool DoHConnectionToBackend::canBeReused() const +bool DoHConnectionToBackend::reachedMaxConcurrentQueries() const { - if (d_connectionDied) { - return false; - } - - if (!d_proxyProtocolPayload.empty()) { - return false; - } - - if (reachedMaxStreamID()) { - return false; - } - //cerr<<"Got "< mplexer{nullptr}; }; -void DoHConnectionToBackend::handleIO(std::shared_ptr& conn, const struct timeval& now) -{ -} - void DoHConnectionToBackend::handleReadableIOCallback(int fd, FDMultiplexer::funcparam_t& param) { auto conn = boost::any_cast>(param); @@ -505,7 +474,7 @@ void DoHConnectionToBackend::stopIO() { d_ioState->reset(); - if (d_connectionDied) { + if (!willBeReusable(false)) { /* remove ourselves from the connection cache, this might mean that our reference count drops to zero after that, so we need to be careful */ auto shared = std::dynamic_pointer_cast(shared_from_this()); @@ -547,7 +516,7 @@ void DoHConnectionToBackend::updateIO(IOState newState, FDMultiplexer::callbackf void DoHConnectionToBackend::watchForRemoteHostClosingConnection() { - if (willBeReusable() && !d_healthCheckQuery) { + if (willBeReusable(false) && !d_healthCheckQuery) { updateIO(IOState::NeedRead, handleReadableIOCallback, false); } } @@ -820,9 +789,9 @@ int DoHConnectionToBackend::on_error_callback(nghttp2_session* session, int lib_ } DoHConnectionToBackend::DoHConnectionToBackend(std::shared_ptr ds, std::unique_ptr& mplexer, const struct timeval& now, std::string&& proxyProtocolPayload) : - TCPConnectionToBackend(ds, mplexer, now), d_proxyProtocolPayload(std::move(proxyProtocolPayload)) + ConnectionToBackend(ds, mplexer, now), d_proxyProtocolPayload(std::move(proxyProtocolPayload)) { - // inherit most of the stuff from the TCPConnectionToBackend() + // inherit most of the stuff from the ConnectionToBackend() d_ioState = make_unique(*d_mplexer, d_handler->getDescriptor()); nghttp2_session_callbacks* cbs = nullptr; @@ -973,7 +942,7 @@ std::shared_ptr DownstreamDoHConnectionsManager::getConn for (auto listIt = list.begin(); listIt != list.end();) { auto& entry = *listIt; if (!entry->canBeReused()) { - if (!entry->willBeReusable()) { + if (!entry->willBeReusable(false)) { listIt = list.erase(listIt); } else { @@ -1003,7 +972,7 @@ std::shared_ptr DownstreamDoHConnectionsManager::getConn auto newConnection = std::make_shared(ds, mplexer, now, std::move(proxyProtocolPayload)); if (!haveProxyProtocol) { - t_downstreamConnections[backendId].push_back(newConnection); + t_downstreamConnections[backendId].push_front(newConnection); } return newConnection; } diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index 657c9a44c4..8bb84145e0 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -5,12 +5,8 @@ #include "dnsparser.hh" -TCPConnectionToBackend::~TCPConnectionToBackend() +ConnectionToBackend::~ConnectionToBackend() { - if (d_ds && !d_pendingResponses.empty()) { - d_ds->outstanding -= d_pendingResponses.size(); - } - if (d_ds && d_handler) { --d_ds->tcpCurrentConnections; struct timeval now; @@ -36,6 +32,100 @@ TCPConnectionToBackend::~TCPConnectionToBackend() } } +bool ConnectionToBackend::reconnect() +{ + std::unique_ptr tlsSession{nullptr}; + if (d_handler) { + DEBUGLOG("closing socket "<getDescriptor()); + if (d_handler->isTLS()) { + if (d_handler->hasTLSSessionBeenResumed()) { + ++d_ds->tlsResumptions; + } + try { + auto sessions = d_handler->getTLSSessions(); + if (!sessions.empty()) { + tlsSession = std::move(sessions.back()); + sessions.pop_back(); + if (!sessions.empty()) { + g_sessionCache.putSessions(d_ds->getID(), time(nullptr), std::move(sessions)); + } + } + } + catch (const std::exception& e) { + vinfolog("Unable to get a TLS session to resume: %s", e.what()); + } + } + d_handler->close(); + d_ioState.reset(); + --d_ds->tcpCurrentConnections; + } + + d_fresh = true; + d_highestStreamID = 0; + d_proxyProtocolPayloadSent = false; + + do { + vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures); + DEBUGLOG("Opening TCP connection to backend "<getNameWithAddr()); + ++d_ds->tcpNewConnections; + try { + auto socket = std::make_unique(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0); + DEBUGLOG("result of socket() is "<getHandle()); + + if (!IsAnyAddress(d_ds->sourceAddr)) { + SSetsockopt(socket->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1); +#ifdef IP_BIND_ADDRESS_NO_PORT + if (d_ds->ipBindAddrNoPort) { + SSetsockopt(socket->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); + } +#endif +#ifdef SO_BINDTODEVICE + if (!d_ds->sourceItfName.empty()) { + int res = setsockopt(socket->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length()); + if (res != 0) { + vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror()); + } + } +#endif + socket->bind(d_ds->sourceAddr, false); + } + socket->setNonBlocking(); + + gettimeofday(&d_connectionStartTime, nullptr); + auto handler = std::make_unique(d_ds->d_tlsSubjectName, socket->releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec); + if (!tlsSession && d_ds->d_tlsCtx) { + tlsSession = g_sessionCache.getSession(d_ds->getID(), d_connectionStartTime.tv_sec); + } + if (tlsSession) { + handler->setTLSSession(tlsSession); + } + handler->tryConnect(d_ds->tcpFastOpen && isFastOpenEnabled(), d_ds->remote); + d_queries = 0; + + d_handler = std::move(handler); + d_ds->incCurrentConnectionsCount(); + return true; + } + catch (const std::runtime_error& e) { + vinfolog("Connection to downstream server %s failed: %s", d_ds->getName(), e.what()); + d_downstreamFailures++; + if (d_downstreamFailures >= d_ds->d_retries) { + throw; + } + } + } + while (d_downstreamFailures < d_ds->d_retries); + + return false; +} + +TCPConnectionToBackend::~TCPConnectionToBackend() +{ + if (d_ds && !d_pendingResponses.empty()) { + d_ds->outstanding -= d_pendingResponses.size(); + } +} + void TCPConnectionToBackend::release() { d_ds->outstanding -= d_pendingResponses.size(); @@ -43,7 +133,6 @@ void TCPConnectionToBackend::release() d_pendingResponses.clear(); d_pendingQueries.clear(); - d_sender.reset(); if (d_ioState) { d_ioState.reset(); } @@ -52,6 +141,9 @@ void TCPConnectionToBackend::release() IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr& conn) { conn->d_currentQuery = std::move(conn->d_pendingQueries.front()); + dnsheader* dh = reinterpret_cast(&conn->d_currentQuery.d_query.d_buffer.at(sizeof(uint16_t) + (conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded ? conn->d_currentQuery.d_query.d_proxyProtocolPayload.size() : 0))); + uint16_t id = conn->d_highestStreamID; + dh->id = htons(id); conn->d_pendingQueries.pop_front(); conn->d_state = State::sendingQueryToBackend; conn->d_currentPos = 0; @@ -63,7 +155,7 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptrgetDS()->getName()<<" over FD "<d_handler->getDescriptor()); - IOState state = conn->d_handler->tryWrite(conn->d_currentQuery.d_buffer, conn->d_currentPos, conn->d_currentQuery.d_buffer.size()); + IOState state = conn->d_handler->tryWrite(conn->d_currentQuery.d_query.d_buffer, conn->d_currentPos, conn->d_currentQuery.d_query.d_buffer.size()); if (state != IOState::Done) { return state; @@ -71,20 +163,22 @@ IOState TCPConnectionToBackend::sendQuery(std::shared_ptrd_currentQuery.d_proxyProtocolPayloadAdded) { + if (conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded) { conn->d_proxyProtocolPayloadSent = true; } ++conn->d_queries; conn->d_currentPos = 0; - DEBUGLOG("adding a pending response for ID "<d_currentQuery.d_idstate.origID)<<" and QNAME "<d_currentQuery.d_idstate.qname); - auto res = conn->d_pendingResponses.insert({ntohs(conn->d_currentQuery.d_idstate.origID), std::move(conn->d_currentQuery)}); - /* if there was already a pending response with that ID, the client messed up and we don't expect more + DEBUGLOG("adding a pending response for ID "<d_highestStreamID<<" and QNAME "<d_currentQuery.d_query.d_idstate.qname); + auto res = conn->d_pendingResponses.insert({conn->d_highestStreamID, std::move(conn->d_currentQuery)}); + /* if there was already a pending response with that ID, we messed up and we don't expect more than one response */ if (res.second) { ++conn->d_ds->outstanding; } - conn->d_currentQuery.d_buffer.clear(); + ++conn->d_highestStreamID; + conn->d_currentQuery.d_sender.reset(); + conn->d_currentQuery.d_query.d_buffer.clear(); return state; } @@ -152,7 +246,7 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c iostate = conn->handleResponse(conn, now); } catch (const std::exception& e) { - vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", conn->d_ds ? conn->d_ds->getName() : "unknown", conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what()); + vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", conn->d_ds ? conn->d_ds->getName() : "unknown", conn->d_currentQuery.d_query.d_idstate.origRemote.toStringWithPort(), e.what()); ioGuard.release(); conn->release(); return; @@ -173,7 +267,7 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c but it might also be a real IO error or something else. Let's just drop the connection */ - vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_state == State::sendingQueryToBackend ? "writing to" : "reading from"), conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what()); + vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_state == State::sendingQueryToBackend ? "writing to" : "reading from"), conn->d_currentQuery.d_query.d_idstate.origRemote.toStringWithPort(), e.what()); if (conn->d_state == State::sendingQueryToBackend) { ++conn->d_ds->tcpDiedSendingQuery; @@ -206,14 +300,23 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c conn->d_ioState = make_unique(*conn->d_mplexer, conn->d_handler->getDescriptor()); /* we need to resend the queries that were in flight, if any */ + if (conn->d_state == State::sendingQueryToBackend) { + /* we need to edit this query so it has the correct ID */ + auto query = std::move(conn->d_currentQuery); + dnsheader* dh = reinterpret_cast(&query.d_query.d_buffer.at(sizeof(uint16_t) + (query.d_query.d_proxyProtocolPayloadAdded ? query.d_query.d_proxyProtocolPayload.size() : 0))); + uint16_t id = conn->d_highestStreamID; + dh->id = htons(id); + conn->d_currentQuery = std::move(query); + } + for (auto& pending : conn->d_pendingResponses) { --conn->d_ds->outstanding; - if (pending.second.isXFR() && pending.second.d_xfrStarted) { + if (pending.second.d_query.isXFR() && pending.second.d_query.d_xfrStarted) { /* this one can't be restarted, sorry */ DEBUGLOG("A XFR for which a response has already been sent cannot be restarted"); try { - conn->d_sender->notifyIOError(std::move(pending.second.d_idstate), now); + pending.second.d_sender->notifyIOError(std::move(pending.second.d_query.d_idstate), now); } catch (const std::exception& e) { vinfolog("Got an exception while notifying: %s", e.what()); @@ -241,9 +344,9 @@ void TCPConnectionToBackend::handleIO(std::shared_ptr& c iostate = queueNextQuery(conn); } - if (conn->needProxyProtocolPayload() && !conn->d_currentQuery.d_proxyProtocolPayloadAdded && !conn->d_currentQuery.d_proxyProtocolPayload.empty()) { - conn->d_currentQuery.d_buffer.insert(conn->d_currentQuery.d_buffer.begin(), conn->d_currentQuery.d_proxyProtocolPayload.begin(), conn->d_currentQuery.d_proxyProtocolPayload.end()); - conn->d_currentQuery.d_proxyProtocolPayloadAdded = true; + if (conn->needProxyProtocolPayload() && !conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !conn->d_currentQuery.d_query.d_proxyProtocolPayload.empty()) { + conn->d_currentQuery.d_query.d_buffer.insert(conn->d_currentQuery.d_query.d_buffer.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.begin(), conn->d_currentQuery.d_query.d_proxyProtocolPayload.end()); + conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true; } reconnected = true; @@ -304,125 +407,39 @@ void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t void TCPConnectionToBackend::queueQuery(std::shared_ptr& sender, TCPQuery&& query) { - if (!d_sender) { - d_sender = sender; + if (!d_ioState) { d_ioState = make_unique(*d_mplexer, d_handler->getDescriptor()); } - else if (d_sender != sender) { - throw std::runtime_error("Assigning a query from a different client to an existing backend connection with pending queries"); - } // if we are not already sending a query or in the middle of reading a response (so idle), // start sending the query if (d_state == State::idle || d_state == State::waitingForResponseFromBackend) { - DEBUGLOG("Sending new query to backend right away"); + DEBUGLOG("Sending new query to backend right away, with ID "<(&query.d_buffer.at(sizeof(uint16_t) + (query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0))); + uint16_t id = d_highestStreamID; + dh->id = htons(id); + d_currentQuery = PendingRequest({sender, std::move(query)}); + + if (needProxyProtocolPayload() && !d_currentQuery.d_query.d_proxyProtocolPayloadAdded && !d_currentQuery.d_query.d_proxyProtocolPayload.empty()) { + d_currentQuery.d_query.d_buffer.insert(d_currentQuery.d_query.d_buffer.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.begin(), d_currentQuery.d_query.d_proxyProtocolPayload.end()); + d_currentQuery.d_query.d_proxyProtocolPayloadAdded = true; } struct timeval now; gettimeofday(&now, 0); - auto shared = shared_from_this(); + auto shared = std::dynamic_pointer_cast(shared_from_this()); handleIO(shared, now); } else { DEBUGLOG("Adding new query to the queue because we are in state "<<(int)d_state); // store query in the list of queries to send - d_pendingQueries.push_back(std::move(query)); + d_pendingQueries.push_back(PendingRequest({sender, std::move(query)})); } } -bool TCPConnectionToBackend::reconnect() -{ - std::unique_ptr tlsSession{nullptr}; - if (d_handler) { - DEBUGLOG("closing socket "<getDescriptor()); - if (d_handler->isTLS()) { - if (d_handler->hasTLSSessionBeenResumed()) { - ++d_ds->tlsResumptions; - } - try { - auto sessions = d_handler->getTLSSessions(); - if (!sessions.empty()) { - tlsSession = std::move(sessions.back()); - sessions.pop_back(); - if (!sessions.empty()) { - g_sessionCache.putSessions(d_ds->getID(), time(nullptr), std::move(sessions)); - } - } - } - catch (const std::exception& e) { - vinfolog("Unable to get a TLS session to resume: %s", e.what()); - } - } - d_handler->close(); - d_ioState.reset(); - --d_ds->tcpCurrentConnections; - } - - d_fresh = true; - d_proxyProtocolPayloadSent = false; - - do { - vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures); - DEBUGLOG("Opening TCP connection to backend "<getNameWithAddr()); - ++d_ds->tcpNewConnections; - try { - auto socket = std::make_unique(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0); - DEBUGLOG("result of socket() is "<getHandle()); - - if (!IsAnyAddress(d_ds->sourceAddr)) { - SSetsockopt(socket->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1); -#ifdef IP_BIND_ADDRESS_NO_PORT - if (d_ds->ipBindAddrNoPort) { - SSetsockopt(socket->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); - } -#endif -#ifdef SO_BINDTODEVICE - if (!d_ds->sourceItfName.empty()) { - int res = setsockopt(socket->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length()); - if (res != 0) { - vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror()); - } - } -#endif - socket->bind(d_ds->sourceAddr, false); - } - socket->setNonBlocking(); - - gettimeofday(&d_connectionStartTime, nullptr); - auto handler = std::make_unique(d_ds->d_tlsSubjectName, socket->releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec); - if (!tlsSession && d_ds->d_tlsCtx) { - tlsSession = g_sessionCache.getSession(d_ds->getID(), d_connectionStartTime.tv_sec); - } - if (tlsSession) { - handler->setTLSSession(tlsSession); - } - handler->tryConnect(d_ds->tcpFastOpen && isFastOpenEnabled(), d_ds->remote); - d_queries = 0; - - d_handler = std::move(handler); - d_ds->incCurrentConnectionsCount(); - return true; - } - catch (const std::runtime_error& e) { - vinfolog("Connection to downstream server %s failed: %s", d_ds->getName(), e.what()); - d_downstreamFailures++; - if (d_downstreamFailures >= d_ds->d_retries) { - throw; - } - } - } - while (d_downstreamFailures < d_ds->d_retries); - - return false; -} - void TCPConnectionToBackend::handleTimeout(const struct timeval& now, bool write) { /* in some cases we could retry, here, reconnecting and sending our pending responses again */ @@ -458,37 +475,49 @@ void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, F { d_connectionDied = true; - auto& sender = d_sender; - if (!sender->active()) { - // a client timeout occurred, or something like that */ - d_sender.reset(); - return; - } - - if (reason == FailureReason::timeout) { - const ClientState* cs = sender->getClientState(); - if (cs) { - ++cs->tcpDownstreamTimeouts; + /* we might be terminated while notifying a query sender */ + d_ds->outstanding -= d_pendingResponses.size(); + auto pendingQueries = std::move(d_pendingQueries); + auto pendingResponses = std::move(d_pendingResponses); + + auto increaseCounters = [reason](std::shared_ptr& sender) { + if (reason == FailureReason::timeout) { + const ClientState* cs = sender->getClientState(); + if (cs) { + ++cs->tcpDownstreamTimeouts; + } } - } - else if (reason == FailureReason::gaveUp) { - const ClientState* cs = sender->getClientState(); - if (cs) { - ++cs->tcpGaveUp; + else if (reason == FailureReason::gaveUp) { + const ClientState* cs = sender->getClientState(); + if (cs) { + ++cs->tcpGaveUp; + } } - } + }; try { if (d_state == State::sendingQueryToBackend) { - sender->notifyIOError(std::move(d_currentQuery.d_idstate), now); + auto sender = d_currentQuery.d_sender; + if (sender->active()) { + increaseCounters(sender); + sender->notifyIOError(std::move(d_currentQuery.d_query.d_idstate), now); + } } - for (auto& query : d_pendingQueries) { - sender->notifyIOError(std::move(query.d_idstate), now); + for (auto& query : pendingQueries) { + auto sender = query.d_sender; + if (sender->active()) { + increaseCounters(sender); + sender->notifyIOError(std::move(query.d_query.d_idstate), now); + } } - for (auto& response : d_pendingResponses) { - sender->notifyIOError(std::move(response.second.d_idstate), now); + for (auto& response : pendingResponses) { + auto sender = response.second.d_sender; + if (sender->active()) { + increaseCounters(sender); + sender->notifyIOError(std::move(response.second.d_query.d_idstate), now); + } } } catch (const std::exception& e) { @@ -527,16 +556,6 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptractive()) { - // a client timeout occurred, or something like that */ - d_connectionDied = true; - - release(); - - return IOState::Done; - } - uint16_t queryId = 0; try { queryId = getQueryIdFromResponse(); @@ -554,19 +573,24 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrsecond.isXFR()) { + auto dh = reinterpret_cast(d_responseBuffer.data()); + dh->id = it->second.d_query.d_idstate.origID; + + auto sender = it->second.d_sender; + + if (sender->active() && it->second.d_query.isXFR()) { DEBUGLOG("XFR!"); bool done = false; TCPResponse response; response.d_buffer = std::move(d_responseBuffer); response.d_connection = conn; /* we don't move the whole IDS because we will need for the responses to come */ - response.d_idstate.qtype = it->second.d_idstate.qtype; - response.d_idstate.qname = it->second.d_idstate.qname; + response.d_idstate.qtype = it->second.d_query.d_idstate.qtype; + response.d_idstate.qname = it->second.d_query.d_idstate.qname; DEBUGLOG("passing XFRresponse to client connection for "<second.d_xfrStarted = true; - done = isXFRFinished(response, it->second); + it->second.d_query.d_xfrStarted = true; + done = isXFRFinished(response, it->second.d_query); if (done) { d_pendingResponses.erase(it); @@ -580,7 +604,6 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrhandleXFRResponse(now, std::move(response)); if (done) { d_state = State::idle; - d_sender.reset(); return IOState::Done; } @@ -592,26 +615,23 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptrd_ds->outstanding; - auto ids = std::move(it->second.d_idstate); + auto ids = std::move(it->second.d_query.d_idstate); d_pendingResponses.erase(it); /* marking as idle for now, so we can accept new queries if our queues are empty */ if (d_pendingQueries.empty() && d_pendingResponses.empty()) { d_state = State::idle; } - DEBUGLOG("passing response to client connection for "<releaseConnection(); - sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn)); + auto shared = conn; + if (sender->active()) { + DEBUGLOG("passing response to client connection for "<handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn)); + } if (!d_pendingQueries.empty()) { DEBUGLOG("still have some queries to send"); - d_state = State::sendingQueryToBackend; - d_currentQuery = std::move(d_pendingQueries.front()); - d_currentPos = 0; - d_pendingQueries.pop_front(); - return IOState::NeedWrite; + return queueNextQuery(shared); } else if (!d_pendingResponses.empty()) { DEBUGLOG("still have some responses to read"); @@ -623,10 +643,6 @@ IOState TCPConnectionToBackend::handleResponse(std::shared_ptr DownstreamConnectionsManager::getConnectionToDownstream(std::unique_ptr& mplexer, std::shared_ptr& ds, const struct timeval& now) { - std::shared_ptr result; struct timeval freshCutOff = now; freshCutOff.tv_sec -= 1; @@ -731,53 +746,47 @@ std::shared_ptr DownstreamConnectionsManager::getConnect const auto& it = t_downstreamConnections.find(backendId); if (it != t_downstreamConnections.end()) { auto& list = it->second; - while (!list.empty()) { - result = std::move(list.back()); - list.pop_back(); + for (auto listIt = list.begin(); listIt != list.end(); ) { + auto& entry = *listIt; + if (!entry->canBeReused()) { + if (!entry->willBeReusable(false)) { + listIt = list.erase(listIt); + } + else { + ++listIt; + } + continue; + } - result->setReused(); + entry->setReused(); /* for connections that have not been used very recently, check whether they have been closed in the meantime */ - if (freshCutOff < result->getLastDataReceivedTime()) { + if (freshCutOff < entry->getLastDataReceivedTime()) { /* used recently enough, skip the check */ ++ds->tcpReusedConnections; - return result; + return entry; } - if (isTCPSocketUsable(result->getHandle())) { + if (isTCPSocketUsable(entry->getHandle())) { ++ds->tcpReusedConnections; - return result; + return entry; + } + else { + listIt = list.erase(listIt); + continue; } /* otherwise let's try the next one, if any */ + ++listIt; } } } - return std::make_shared(ds, mplexer, now); -} - -void DownstreamConnectionsManager::releaseDownstreamConnection(std::shared_ptr&& conn) -{ - if (conn == nullptr) { - return; - } - - if (!conn->canBeReused()) { - conn.reset(); - return; - } - - const auto& ds = conn->getDS(); - { - auto& list = t_downstreamConnections[ds->getID()]; - while (list.size() >= s_maxCachedConnectionsPerDownstream) { - /* too many connections queued already */ - list.pop_front(); - } - - list.push_back(std::move(conn)); + auto newConnection = std::make_shared(ds, mplexer, now); + if (!ds->useProxyProtocol) { + t_downstreamConnections[backendId].push_front(newConnection); } + return newConnection; } void DownstreamConnectionsManager::cleanupClosedTCPConnections(struct timeval now) diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh index c05c1264df..0ea546b3df 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh @@ -7,15 +7,15 @@ #include "dnsdist.hh" #include "dnsdist-tcp.hh" -class TCPConnectionToBackend : public std::enable_shared_from_this +class ConnectionToBackend : public std::enable_shared_from_this { public: - TCPConnectionToBackend(std::shared_ptr& ds, std::unique_ptr& mplexer, const struct timeval& now): d_connectionStartTime(now), d_lastDataReceivedTime(now), d_ds(ds), d_responseBuffer(s_maxPacketCacheEntrySize), d_mplexer(mplexer), d_enableFastOpen(ds->tcpFastOpen) + ConnectionToBackend(std::shared_ptr& ds, std::unique_ptr& mplexer, const struct timeval& now): d_connectionStartTime(now), d_lastDataReceivedTime(now), d_ds(ds), d_mplexer(mplexer), d_enableFastOpen(ds->tcpFastOpen) { reconnect(); } - virtual ~TCPConnectionToBackend(); + virtual ~ConnectionToBackend(); int getHandle() const { @@ -61,43 +61,52 @@ public: return d_enableFastOpen; } - /* whether we can accept new queries FOR THE SAME CLIENT */ - bool canAcceptNewQueries() const + /* whether a connection can be used now */ + bool canBeReused(bool sameClient = false) const { if (d_connectionDied) { return false; } - if ((d_pendingQueries.size() + d_pendingResponses.size()) >= d_ds->d_maxInFlightQueriesPerConn) { + /* we can't reuse a connection where a proxy protocol payload has been sent, + since: + - it cannot be reused for a different client + - we might have different TLV values for each query + */ + if (d_ds && d_ds->useProxyProtocol == true && !sameClient) { return false; } - return true; - } + if (reachedMaxStreamID()) { + return false; + } - bool isIdle() const - { - return d_state == State::idle && d_pendingQueries.size() == 0 && d_pendingResponses.size() == 0; + if (reachedMaxConcurrentQueries()) { + return false; + } + + return true; } - /* whether a connection can be reused for a different client */ - virtual bool canBeReused() const + /* full now but will become usable later */ + bool willBeReusable(bool sameClient) const { - if (d_connectionDied) { + if (d_connectionDied || reachedMaxStreamID()) { return false; } - /* we can't reuse a connection where a proxy protocol payload has been sent, - since: - - it cannot be reused for a different client - - we might have different TLV values for each query - */ + if (d_ds && d_ds->useProxyProtocol == true) { - return false; + return sameClient; } + return true; } - bool matchesTLVs(const std::unique_ptr>& tlvs) const; + virtual bool reachedMaxStreamID() const = 0; + virtual bool reachedMaxConcurrentQueries() const = 0; + virtual void release() + { + } bool matches(const std::shared_ptr& ds) const { @@ -107,44 +116,18 @@ public: return ds == d_ds; } - virtual void queueQuery(std::shared_ptr& sender, TCPQuery&& query); - virtual void handleTimeout(const struct timeval& now, bool write); - void release(); - - void setProxyProtocolValuesSent(std::unique_ptr>&& proxyProtocolValuesSent); + virtual void queueQuery(std::shared_ptr& sender, TCPQuery&& query) = 0; + virtual void handleTimeout(const struct timeval& now, bool write) = 0; struct timeval getLastDataReceivedTime() const { return d_lastDataReceivedTime; } - virtual std::string toString() const - { - ostringstream o; - o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<& conn, const struct timeval& now); - static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); - static IOState queueNextQuery(std::shared_ptr& conn); - static IOState sendQuery(std::shared_ptr& conn, const struct timeval& now); - static bool isXFRFinished(const TCPResponse& response, TCPQuery& query); - - IOState handleResponse(std::shared_ptr& conn, const struct timeval& now); - uint16_t getQueryIdFromResponse() const; bool reconnect(); - void notifyAllQueriesFailed(const struct timeval& now, FailureReason reason); - bool needProxyProtocolPayload() const - { - return !d_proxyProtocolPayloadSent && (d_ds && d_ds->useProxyProtocol); - } boost::optional getBackendHealthCheckTTD(const struct timeval& now) const { @@ -206,34 +189,107 @@ protected: return res; } - TCPQuery d_currentQuery; - std::deque d_pendingQueries; - std::unordered_map d_pendingResponses; struct timeval d_connectionStartTime; struct timeval d_lastDataReceivedTime; std::shared_ptr d_ds{nullptr}; std::shared_ptr d_sender{nullptr}; - PacketBuffer d_responseBuffer; std::unique_ptr& d_mplexer; - std::unique_ptr> d_proxyProtocolValuesSent{nullptr}; std::unique_ptr d_handler{nullptr}; std::unique_ptr d_ioState{nullptr}; - size_t d_currentPos{0}; uint64_t d_queries{0}; - uint64_t d_downstreamFailures{0}; - uint16_t d_responseSize{0}; - State d_state{State::idle}; - bool d_fresh{true}; + uint32_t d_highestStreamID{0}; + uint16_t d_downstreamFailures{0}; + bool d_proxyProtocolPayloadSent{false}; bool d_enableFastOpen{false}; bool d_connectionDied{false}; - bool d_proxyProtocolPayloadSent{false}; + bool d_fresh{true}; +}; + +class TCPConnectionToBackend : public ConnectionToBackend +{ +public: + TCPConnectionToBackend(std::shared_ptr& ds, std::unique_ptr& mplexer, const struct timeval& now): ConnectionToBackend(ds, mplexer, now), d_responseBuffer(s_maxPacketCacheEntrySize) + { + } + + virtual ~TCPConnectionToBackend(); + + bool isIdle() const + { + return d_state == State::idle && d_pendingQueries.size() == 0 && d_pendingResponses.size() == 0; + } + + bool reachedMaxStreamID() const override + { + /* TCP/DoT has only 2^16 usable identifiers, DoH has 2^32 */ + const uint32_t maximumStreamID = std::numeric_limits::max() - 1; + return d_highestStreamID == maximumStreamID; + } + + bool reachedMaxConcurrentQueries() const override + { + const size_t concurrent = d_pendingQueries.size() + d_pendingResponses.size(); + if (concurrent > 0 && concurrent >= d_ds->d_maxInFlightQueriesPerConn) { + return true; + } + return false; + } + bool matchesTLVs(const std::unique_ptr>& tlvs) const; + + void queueQuery(std::shared_ptr& sender, TCPQuery&& query) override; + void handleTimeout(const struct timeval& now, bool write) override; + void release() override; + + std::string toString() const override + { + ostringstream o; + o << "TCP connection to backend "<<(d_ds ? d_ds->getName() : "empty")<<" over FD "<<(d_handler ? std::to_string(d_handler->getDescriptor()) : "no socket")<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<>&& proxyProtocolValuesSent); + +private: + /* waitingForResponseFromBackend is a state where we have not yet started reading the size, + so we can still switch to sending instead */ + enum class State : uint8_t { idle, sendingQueryToBackend, waitingForResponseFromBackend, readingResponseSizeFromBackend, readingResponseFromBackend }; + enum class FailureReason : uint8_t { /* too many attempts */ gaveUp, timeout, unexpectedQueryID }; + + static void handleIO(std::shared_ptr& conn, const struct timeval& now); + static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); + static IOState queueNextQuery(std::shared_ptr& conn); + static IOState sendQuery(std::shared_ptr& conn, const struct timeval& now); + static bool isXFRFinished(const TCPResponse& response, TCPQuery& query); + + IOState handleResponse(std::shared_ptr& conn, const struct timeval& now); + uint16_t getQueryIdFromResponse() const; + void notifyAllQueriesFailed(const struct timeval& now, FailureReason reason); + bool needProxyProtocolPayload() const + { + return !d_proxyProtocolPayloadSent && (d_ds && d_ds->useProxyProtocol); + } + + class PendingRequest + { + public: + std::shared_ptr d_sender{nullptr}; + TCPQuery d_query; + }; + + PacketBuffer d_responseBuffer; + std::deque d_pendingQueries; + std::unordered_map d_pendingResponses; + std::unique_ptr> d_proxyProtocolValuesSent{nullptr}; + PendingRequest d_currentQuery; + size_t d_currentPos{0}; + uint16_t d_responseSize{0}; + State d_state{State::idle}; }; class DownstreamConnectionsManager { public: static std::shared_ptr getConnectionToDownstream(std::unique_ptr& mplexer, std::shared_ptr& ds, const struct timeval& now); - static void releaseDownstreamConnection(std::shared_ptr&& conn); static void cleanupClosedTCPConnections(struct timeval now); static size_t clear(); diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 99e933aa2c..9ed8b6b3fa 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -105,9 +105,9 @@ public: return false; } - std::shared_ptr getActiveDownstreamConnection(const std::shared_ptr& ds, const std::unique_ptr>& tlvs); + std::shared_ptr getOwnedDownstreamConnection(const std::shared_ptr& ds, const std::unique_ptr>& tlvs); std::shared_ptr getDownstreamConnection(std::shared_ptr& ds, const std::unique_ptr>& tlvs, const struct timeval& now); - void registerActiveDownstreamConnection(std::shared_ptr& conn); + void registerOwnedDownstreamConnection(std::shared_ptr& conn); static size_t clearAllDownstreamConnections(); @@ -141,14 +141,14 @@ static void handleTimeout(std::shared_ptr& state, bo std::string toString() const { ostringstream o; - o << "Incoming TCP connection from "<getState() : "empty")<<", queries count is "<getState() : "empty")<<", queries count is "<, std::deque>> d_activeConnectionsToBackend; + std::map, std::deque>> d_ownedConnectionsToBackend; std::deque d_queuedResponses; PacketBuffer d_buffer; ConnectionInfo d_ci; diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index b9c86c916f..9154f2f650 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -121,7 +121,7 @@ struct InternalQuery using TCPQuery = InternalQuery; -class TCPConnectionToBackend; +class ConnectionToBackend; struct TCPResponse : public TCPQuery { @@ -131,13 +131,13 @@ struct TCPResponse : public TCPQuery memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); } - TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr conn) : + TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr conn) : TCPQuery(std::move(buffer), std::move(state)), d_connection(conn) { memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); } - std::shared_ptr d_connection{nullptr}; + std::shared_ptr d_connection{nullptr}; dnsheader d_cleartextDH; bool d_selfGenerated{false}; }; diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index b17b1944c7..4d6224149c 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -419,9 +419,25 @@ static ComboAddress getBackendAddress(const std::string& lastDigit, uint16_t por return ComboAddress("192.0.2." + lastDigit, port); } +static void appendPayloadEditingID(PacketBuffer& buffer, const PacketBuffer& payload, uint16_t newID) +{ + PacketBuffer newPayload(payload); + auto dh = reinterpret_cast(&newPayload.at(sizeof(uint16_t))); + dh->id = htons(newID); + buffer.insert(buffer.end(), newPayload.begin(), newPayload.end()); +} + +static void prependPayloadEditingID(PacketBuffer& buffer, const PacketBuffer& payload, uint16_t newID) +{ + PacketBuffer newPayload(payload); + auto dh = reinterpret_cast(&newPayload.at(sizeof(uint16_t))); + dh->id = htons(newID); + buffer.insert(buffer.begin(), newPayload.begin(), newPayload.end()); +} + static void testInit(const std::string& name, TCPClientThreadData& threadData) { -#if 0 +#ifdef DEBUGLOG_ENABLED cerr<(); } @@ -786,8 +803,6 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionWithProxyProtocol_SelfAnswered) s_readBuffer = query; // preprend the proxy protocol payload s_readBuffer.insert(s_readBuffer.begin(), proxyPayload.begin(), proxyPayload.end()); - // append a second query - s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end()); s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, @@ -836,6 +851,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) PacketBuffer query; GenericDNSPacketWriter pwQ(query, DNSName("powerdns.com."), QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; + pwQ.getHeader()->id = 0; auto shortQuery = query; shortQuery.resize(sizeof(dnsheader) - 1); @@ -1083,11 +1099,11 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) TEST_INIT("=> Short read and write to backend"); s_readBuffer = query; // append a second query - s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end()); + appendPayloadEditingID(s_readBuffer, query, 1); s_backendReadBuffer = query; // append a second query - s_backendReadBuffer.insert(s_backendReadBuffer.end(), query.begin(), query.end()); + appendPayloadEditingID(s_backendReadBuffer, query, 1); s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, @@ -1629,8 +1645,8 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnection_BackendNoOOOR) s_readBuffer = query; for (size_t idx = 0; idx < count; idx++) { - s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), query.begin(), query.end()); + appendPayloadEditingID(s_readBuffer, query, idx); + appendPayloadEditingID(s_backendReadBuffer, query, idx); } s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, @@ -1716,7 +1732,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) for (auto& query : queries) { GenericDNSPacketWriter pwQ(query, DNSName("powerdns" + std::to_string(counter) + ".com."), QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; - pwQ.getHeader()->id = counter; + pwQ.getHeader()->id = htons(counter); uint16_t querySize = static_cast(query.size()); const uint8_t sizeBytes[] = { static_cast(querySize / 256), static_cast(querySize % 256) }; query.insert(query.begin(), sizeBytes, sizeBytes + 2); @@ -1732,7 +1748,7 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) pwR.getHeader()->qr = 1; pwR.getHeader()->rd = 1; pwR.getHeader()->ra = 1; - pwR.getHeader()->id = counter; + pwR.getHeader()->id = htons(counter); pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER); pwR.xfr32BitInt(0x01020304); pwR.commit(); @@ -1749,16 +1765,18 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) PacketBuffer expectedWriteBuffer; PacketBuffer expectedBackendWriteBuffer; + uint16_t backendCounter = 0; for (const auto& query : queries) { s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end()); + appendPayloadEditingID(expectedBackendWriteBuffer, query, backendCounter++); } - expectedBackendWriteBuffer = s_readBuffer; + backendCounter = 0; for (const auto& response : responses) { /* reverse order */ - s_backendReadBuffer.insert(s_backendReadBuffer.begin(), response.begin(), response.end()); + prependPayloadEditingID(s_backendReadBuffer, response, backendCounter++); + expectedWriteBuffer.insert(expectedWriteBuffer.begin(), response.begin(), response.end()); } - expectedWriteBuffer = s_backendReadBuffer; s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, @@ -1884,8 +1902,14 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end()); } - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(4).begin(), responses.at(4).end()); + uint16_t backendCounter = 0; + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), backendCounter++); + + appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0); + appendPayloadEditingID(s_backendReadBuffer, responses.at(4), 3); /* self-answered */ expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(3).begin(), responses.at(3).end()); @@ -1893,12 +1917,6 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(0).begin(), responses.at(0).end()); expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(4).begin(), responses.at(4).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(4).begin(), queries.at(4).end()); - - bool timeout = false; s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, @@ -2027,13 +2045,24 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) for (const auto& query : queries) { s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end()); } - expectedBackendWriteBuffer = s_readBuffer; - for (const auto& response : responses) { expectedWriteBuffer.insert(expectedWriteBuffer.end(), response.begin(), response.end()); } - s_backendReadBuffer = expectedWriteBuffer; + uint16_t backendCounter = 0; + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), backendCounter); + appendPayloadEditingID(s_backendReadBuffer, responses.at(0), backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), backendCounter); + appendPayloadEditingID(s_backendReadBuffer, responses.at(1), backendCounter++); + + // new connection + backendCounter = 0; + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), backendCounter); + appendPayloadEditingID(s_backendReadBuffer, responses.at(2), backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(3), backendCounter); + appendPayloadEditingID(s_backendReadBuffer, responses.at(3), backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), backendCounter); + appendPayloadEditingID(s_backendReadBuffer, responses.at(4), backendCounter++); bool timeout = false; int backendDesc; @@ -2341,15 +2370,18 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) s_readBuffer.insert(s_readBuffer.end(), queries.at(1).begin(), queries.at(1).end()); s_readBuffer.insert(s_readBuffer.end(), queries.at(4).begin(), queries.at(4).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(4).begin(), queries.at(4).end()); + uint16_t backendCounter = 0; + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), backendCounter++); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(1).begin(), responses.at(1).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(4).begin(), responses.at(4).end()); + appendPayloadEditingID(s_backendReadBuffer, responses.at(1), 1); + appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0); + appendPayloadEditingID(s_backendReadBuffer, responses.at(4), 2); - expectedWriteBuffer = s_backendReadBuffer; + appendPayloadEditingID(expectedWriteBuffer, responses.at(1), 1); + appendPayloadEditingID(expectedWriteBuffer, responses.at(0), 0); + appendPayloadEditingID(expectedWriteBuffer, responses.at(4), 4); /* make sure that the backend's timeout is longer than the client's */ backend->tcpRecvTimeout = 30; @@ -2713,14 +2745,16 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) s_readBuffer = axfrQuery; s_readBuffer.insert(s_readBuffer.end(), secondQuery.begin(), secondQuery.end()); - expectedBackendWriteBuffer = s_readBuffer; + uint16_t backendCounter = 0; + appendPayloadEditingID(expectedBackendWriteBuffer, axfrQuery, backendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, secondQuery, backendCounter++); for (const auto& response : axfrResponses) { - s_backendReadBuffer.insert(s_backendReadBuffer.end(), response.begin(), response.end()); + appendPayloadEditingID(s_backendReadBuffer, response, 0); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), response.begin(), response.end()); } - s_backendReadBuffer.insert(s_backendReadBuffer.end(), secondResponse.begin(), secondResponse.end()); - - expectedWriteBuffer = s_backendReadBuffer; + appendPayloadEditingID(s_backendReadBuffer, secondResponse, 1); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), secondResponse.begin(), secondResponse.end()); bool timeout = false; s_steps = { @@ -2973,15 +3007,18 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) s_readBuffer.insert(s_readBuffer.end(), ixfrQuery.begin(), ixfrQuery.end()); s_readBuffer.insert(s_readBuffer.end(), secondQuery.begin(), secondQuery.end()); - expectedBackendWriteBuffer = s_readBuffer; + appendPayloadEditingID(expectedBackendWriteBuffer, firstQuery, 0); + appendPayloadEditingID(expectedBackendWriteBuffer, ixfrQuery, 1); + appendPayloadEditingID(expectedBackendWriteBuffer, secondQuery, 2); - s_backendReadBuffer = firstResponse; + appendPayloadEditingID(s_backendReadBuffer, firstResponse, 0); + expectedWriteBuffer.insert(expectedWriteBuffer.begin(), firstResponse.begin(), firstResponse.end()); for (const auto& response : ixfrResponses) { - s_backendReadBuffer.insert(s_backendReadBuffer.end(), response.begin(), response.end()); + appendPayloadEditingID(s_backendReadBuffer, response, 1); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), response.begin(), response.end()); } - s_backendReadBuffer.insert(s_backendReadBuffer.end(), secondResponse.begin(), secondResponse.end()); - - expectedWriteBuffer = s_backendReadBuffer; + appendPayloadEditingID(s_backendReadBuffer, secondResponse, 2); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), secondResponse.begin(), secondResponse.end()); bool timeout = false; s_steps = { @@ -3083,19 +3120,22 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) proxyEnabledBackend->useProxyProtocol = true; expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end()); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 0); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), 1); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), 2); expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end()); /* we are using an unordered_map, so all bets are off here :-/ */ - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end()); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), 0); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), 1); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(1).begin(), responses.at(1).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(2).begin(), responses.at(2).end()); + appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0); + /* after the reconnection */ + appendPayloadEditingID(s_backendReadBuffer, responses.at(1), 1); + appendPayloadEditingID(s_backendReadBuffer, responses.at(2), 0); - expectedWriteBuffer = s_backendReadBuffer; + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(0).begin(), responses.at(0).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(1).begin(), responses.at(1).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(2).begin(), responses.at(2).end()); s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, @@ -3204,16 +3244,12 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) proxyEnabledBackend->d_tlsCtx = tlsCtx; /* enable out-of-order on the backend side as well */ proxyEnabledBackend->d_maxInFlightQueriesPerConn = 65536; - proxyEnabledBackend-> useProxyProtocol = true; - - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(0).begin(), queries.at(0).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(1).begin(), queries.at(1).end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end()); + proxyEnabledBackend->useProxyProtocol = true; expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), proxyPayload.begin(), proxyPayload.end()); - expectedBackendWriteBuffer.insert(expectedBackendWriteBuffer.end(), queries.at(2).begin(), queries.at(2).end()); - //s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(2).begin(), responses.at(2).end()); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 0); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), 1); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), 2); s_steps = { { ExpectedStep::ExpectedRequest::handshakeClient, IOState::Done }, @@ -3245,31 +3281,10 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) }}, /* client closes the connection */ { ExpectedStep::ExpectedRequest::readFromClient, IOState::Done, 0 }, - /* closing the client connection */ - { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 }, - /* try to read response from backend, connection has been closed */ - { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, 0 }, - //{ ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(2).size() }, /* closing the backend connection */ { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done }, - { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done }, - { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done }, - { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done }, - { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done }, - /* sending query (3) to the backend */ - { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, proxyPayload.size() + queries.at(2).size() }, - /* sending query (2) to the backend */ - { ExpectedStep::ExpectedRequest::writeToBackend, IOState::Done, 0 }, - { ExpectedStep::ExpectedRequest::closeBackend, IOState::Done, 0 }, + /* closing the client connection */ + { ExpectedStep::ExpectedRequest::closeClient, IOState::Done, 0 }, }; s_processQuery = [proxyEnabledBackend](DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) -> ProcessQueryResult { @@ -3371,8 +3386,8 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) g_tcpRecvTimeout = 2; /* we need to clear them now, otherwise we end up with dangling pointers to the steps via the TLS context, etc */ - /* we should have nothing to clear since the connection cannot be reused due to the Proxy Protocol payload */ - BOOST_CHECK_EQUAL(IncomingTCPConnectionState::clearAllDownstreamConnections(), 0U); + /* we have one connection to clear, no proxy protocol */ + BOOST_CHECK_EQUAL(IncomingTCPConnectionState::clearAllDownstreamConnections(), 1U); } { @@ -3383,15 +3398,29 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) for (const auto& query : queries) { s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end()); } - expectedBackendWriteBuffer = s_readBuffer; - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(1).begin(), responses.at(1).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(2).begin(), responses.at(2).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(4).begin(), responses.at(4).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(3).begin(), responses.at(3).end()); + /* queries 0, 1 and 4 are sent to the first backend, 2 and 3 to the second */ + uint16_t firstBackendCounter = 0; + uint16_t secondBackendCounter = 0; + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), firstBackendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), firstBackendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), secondBackendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(3), secondBackendCounter++); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), firstBackendCounter++); + + firstBackendCounter = 0; + secondBackendCounter = 0; + appendPayloadEditingID(s_backendReadBuffer, responses.at(0), firstBackendCounter++); + appendPayloadEditingID(s_backendReadBuffer, responses.at(1), firstBackendCounter++); + appendPayloadEditingID(s_backendReadBuffer, responses.at(2), secondBackendCounter++); + appendPayloadEditingID(s_backendReadBuffer, responses.at(4), firstBackendCounter++); + appendPayloadEditingID(s_backendReadBuffer, responses.at(3), secondBackendCounter++); - expectedWriteBuffer = s_backendReadBuffer; + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(0).begin(), responses.at(0).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(1).begin(), responses.at(1).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(2).begin(), responses.at(2).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(4).begin(), responses.at(4).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(3).begin(), responses.at(3).end()); auto backend1 = std::make_shared(getBackendAddress("42", 53), ComboAddress("0.0.0.0:0"), 0, std::string(), 1, false); backend1->d_tlsCtx = tlsCtx; @@ -3539,17 +3568,21 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) } { - TEST_INIT("=> 2 OOOR queries to the backend with duplicated IDs, backend responds to only one of them"); + TEST_INIT("=> 2 OOOR queries to the backend with duplicated IDs"); PacketBuffer expectedWriteBuffer; PacketBuffer expectedBackendWriteBuffer; s_readBuffer.insert(s_readBuffer.end(), queries.at(0).begin(), queries.at(0).end()); s_readBuffer.insert(s_readBuffer.end(), queries.at(0).begin(), queries.at(0).end()); - expectedBackendWriteBuffer = s_readBuffer; + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 0); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 1); - s_backendReadBuffer.insert(s_backendReadBuffer.begin(), responses.at(0).begin(), responses.at(0).end()); - expectedWriteBuffer = s_backendReadBuffer; + appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0); + appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 1); + + appendPayloadEditingID(expectedWriteBuffer, responses.at(0), 0); + appendPayloadEditingID(expectedWriteBuffer, responses.at(0), 0); bool timeout = false; s_steps = { @@ -3575,7 +3608,12 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendOOOR) /* nothing more from the client either */ { ExpectedStep::ExpectedRequest::readFromClient, IOState::NeedRead, 0 }, - /* reading a response from the backend */ + /* reading response (1) from the backend */ + { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(0).size() - 2 }, + { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(0).size()}, + /* sending it to the client */ + { ExpectedStep::ExpectedRequest::writeToClient, IOState::Done, responses.at(0).size()}, + /* reading response (2) from the backend */ { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(0).size() - 2 }, { ExpectedStep::ExpectedRequest::readFromBackend, IOState::Done, responses.at(0).size(), [&threadData](int desc, const ExpectedStep& step) { dynamic_cast(threadData.mplexer.get())->setNotReady(desc); @@ -3691,15 +3729,24 @@ BOOST_AUTO_TEST_CASE(test_IncomingConnectionOOOR_BackendNotOOOR) for (const auto& query : queries) { s_readBuffer.insert(s_readBuffer.end(), query.begin(), query.end()); } - expectedBackendWriteBuffer = s_readBuffer; - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(0).begin(), responses.at(0).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(2).begin(), responses.at(2).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(1).begin(), responses.at(1).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(4).begin(), responses.at(4).end()); - s_backendReadBuffer.insert(s_backendReadBuffer.end(), responses.at(3).begin(), responses.at(3).end()); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(0), 0); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(1), 0); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(2), 0); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(3), 0); + appendPayloadEditingID(expectedBackendWriteBuffer, queries.at(4), 0); - expectedWriteBuffer = s_backendReadBuffer; + appendPayloadEditingID(s_backendReadBuffer, responses.at(0), 0); + appendPayloadEditingID(s_backendReadBuffer, responses.at(2), 0); + appendPayloadEditingID(s_backendReadBuffer, responses.at(1), 0); + appendPayloadEditingID(s_backendReadBuffer, responses.at(4), 0); + appendPayloadEditingID(s_backendReadBuffer, responses.at(3), 0); + + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(0).begin(), responses.at(0).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(2).begin(), responses.at(2).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(1).begin(), responses.at(1).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(4).begin(), responses.at(4).end()); + expectedWriteBuffer.insert(expectedWriteBuffer.end(), responses.at(3).begin(), responses.at(3).end()); std::vector backendDescriptors = { -1, -1, -1, -1, -1 }; diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index b4ea7be4b9..d4897e8667 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -240,12 +240,66 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): sock.close() @classmethod - def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None): + def handleTCPConnection(cls, conn, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None): + ignoreTrailing = trailingDataResponse is True + data = conn.recv(2) + if not data: + conn.close() + return + + (datalen,) = struct.unpack("!H", data) + data = conn.recv(datalen) + forceRcode = None + try: + request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) + except dns.message.TrailingJunk as e: + if trailingDataResponse is False or forceRcode is True: + raise + print("TCP query with trailing data, synthesizing response") + request = dns.message.from_wire(data, ignore_trailing=True) + forceRcode = trailingDataResponse + + if callback: + wire = callback(request) + else: + response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) + if response: + wire = response.to_wire(max_size=65535) + + if not wire: + conn.close() + return + + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + + while multipleResponses: + if fromQueue.empty(): + break + + response = fromQueue.get(True, cls._queueTimeout) + if not response: + break + + response = copy.copy(response) + response.id = request.id + wire = response.to_wire(max_size=65535) + try: + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + except socket.error as e: + # some of the tests are going to close + # the connection on us, just deal with it + break + + conn.close() + + @classmethod + def TCPResponder(cls, port, fromQueue, toQueue, trailingDataResponse=False, multipleResponses=False, callback=None, tlsContext=None, multipleConnections=False): # trailingDataResponse=True means "ignore trailing data". # Other values are either False (meaning "raise an exception") # or are interpreted as a response RCODE for queries with trailing data. # callback is invoked for every -even healthcheck ones- query and should return a raw response - ignoreTrailing = trailingDataResponse is True sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -269,57 +323,14 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): continue conn.settimeout(5.0) - data = conn.recv(2) - if not data: - conn.close() - continue - - (datalen,) = struct.unpack("!H", data) - data = conn.recv(datalen) - forceRcode = None - try: - request = dns.message.from_wire(data, ignore_trailing=ignoreTrailing) - except dns.message.TrailingJunk as e: - if trailingDataResponse is False or forceRcode is True: - raise - print("TCP query with trailing data, synthesizing response") - request = dns.message.from_wire(data, ignore_trailing=True) - forceRcode = trailingDataResponse - - if callback: - wire = callback(request) + if multipleConnections: + thread = threading.Thread(name='TCP Connection Handler', + target=cls.handleTCPConnection, + args=[conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback]) + thread.setDaemon(True) + thread.start() else: - response = cls._getResponse(request, fromQueue, toQueue, synthesize=forceRcode) - if response: - wire = response.to_wire(max_size=65535) - - if not wire: - conn.close() - continue - - conn.send(struct.pack("!H", len(wire))) - conn.send(wire) - - while multipleResponses: - if fromQueue.empty(): - break - - response = fromQueue.get(True, cls._queueTimeout) - if not response: - break - - response = copy.copy(response) - response.id = request.id - wire = response.to_wire(max_size=65535) - try: - conn.send(struct.pack("!H", len(wire))) - conn.send(wire) - except socket.error as e: - # some of the tests are going to close - # the connection on us, just deal with it - break - - conn.close() + cls.handleTCPConnection(conn, fromQueue, toQueue, trailingDataResponse, multipleResponses, callback) sock.close() diff --git a/regression-tests.dnsdist/test_AXFR.py b/regression-tests.dnsdist/test_AXFR.py index c4e43c7bad..8076dd1524 100644 --- a/regression-tests.dnsdist/test_AXFR.py +++ b/regression-tests.dnsdist/test_AXFR.py @@ -13,6 +13,7 @@ class TestAXFR(DNSDistTest): _config_template = """ newServer{address="127.0.0.1:%s"} """ + @classmethod def startResponders(cls): print("Launching responders..") @@ -20,7 +21,7 @@ class TestAXFR(DNSDistTest): cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue]) cls._UDPResponder.setDaemon(True) cls._UDPResponder.start() - cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, True]) + cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, True, None, None, True]) cls._TCPResponder.setDaemon(True) cls._TCPResponder.start() diff --git a/regression-tests.dnsdist/test_DynBlocks.py b/regression-tests.dnsdist/test_DynBlocks.py index 483d555fcd..788a0b17d7 100644 --- a/regression-tests.dnsdist/test_DynBlocks.py +++ b/regression-tests.dnsdist/test_DynBlocks.py @@ -780,6 +780,7 @@ class TestDynBlockQPSActionTruncated(DNSDistTest): # check over TCP, which should not be truncated (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(receivedResponse, response) @@ -798,6 +799,7 @@ class TestDynBlockQPSActionTruncated(DNSDistTest): for _ in range((self._dynBlockQPS * self._dynBlockPeriod) + 1): (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) sent = sent + 1 + receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(receivedResponse, response) receivedQuery.id = query.id diff --git a/regression-tests.dnsdist/test_OutgoingTLS.py b/regression-tests.dnsdist/test_OutgoingTLS.py index 74998e87e4..53677a6612 100644 --- a/regression-tests.dnsdist/test_OutgoingTLS.py +++ b/regression-tests.dnsdist/test_OutgoingTLS.py @@ -52,6 +52,7 @@ class OutgoingTLSTests(object): numberOfUDPQueries = 10 for _ in range(numberOfUDPQueries): (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse) + receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(receivedResponse, expectedResponse) @@ -82,6 +83,7 @@ class OutgoingTLSTests(object): expectedResponse.answer.append(rrset) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse) + receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(receivedResponse, expectedResponse) self.checkOnlyTLSResponderHit() diff --git a/regression-tests.dnsdist/test_TCPOnly.py b/regression-tests.dnsdist/test_TCPOnly.py index 7ada4cacd6..91edd50447 100644 --- a/regression-tests.dnsdist/test_TCPOnly.py +++ b/regression-tests.dnsdist/test_TCPOnly.py @@ -24,6 +24,7 @@ class TestTCPOnly(DNSDistTest): expectedResponse.answer.append(rrset) (receivedQuery, receivedResponse) = self.sendUDPQuery(query, expectedResponse) + receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(receivedResponse, expectedResponse) @@ -47,6 +48,7 @@ class TestTCPOnly(DNSDistTest): expectedResponse.answer.append(rrset) (receivedQuery, receivedResponse) = self.sendTCPQuery(query, expectedResponse) + receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(receivedResponse, expectedResponse) if 'UDP Responder' in self._responsesCounter: