From: Remi Gacogne Date: Fri, 5 Jun 2020 15:58:31 +0000 (+0200) Subject: dnsdist: TCP out-of-order implementation X-Git-Tag: auth-4.5.0-alpha0~14^2~23 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=086a43eb82acdfa35b1d749bde9b2f97734c6d9f;p=thirdparty%2Fpdns.git dnsdist: TCP out-of-order implementation --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 048339929a..3533bccda5 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -19,27 +19,27 @@ * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ + +#include +#include +#include + #include "dnsdist.hh" #include "dnsdist-ecs.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-rings.hh" +#include "dnsdist-tcp-downstream.hh" +#include "dnsdist-tcp-upstream.hh" #include "dnsdist-xpf.hh" - #include "dnsparser.hh" -#include "ednsoptions.hh" #include "dolog.hh" -#include "lock.hh" +#include "ednsoptions.hh" #include "gettime.hh" +#include "lock.hh" +#include "sstuff.hh" #include "tcpiohandler.hh" +#include "tcpiohandler-mplexer.hh" #include "threadname.hh" -#include -#include -#include - -#include "sstuff.hh" - -using std::thread; -using std::atomic; /* TCP: the grand design. We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops. 
@@ -58,7 +58,7 @@ using std::atomic; static std::mutex tcpClientsCountMutex; static std::map tcpClientsCount; -static const size_t g_maxCachedConnectionsPerDownstream = 20; + uint64_t g_maxTCPQueuedConnections{1000}; size_t g_maxTCPQueriesPerConn{0}; size_t g_maxTCPConnectionDuration{0}; @@ -66,243 +66,174 @@ size_t g_maxTCPConnectionsPerClient{0}; uint16_t g_downstreamTCPCleanupInterval{60}; bool g_useTCPSinglePipe{false}; -static std::unique_ptr setupTCPDownstream(shared_ptr& ds, uint16_t& downstreamFailures) +class DownstreamConnectionsManager { - std::unique_ptr result; +public: - do { - vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures); - try { - result = std::unique_ptr(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0)); - if (!IsAnyAddress(ds->sourceAddr)) { - SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1); -#ifdef IP_BIND_ADDRESS_NO_PORT - if (ds->ipBindAddrNoPort) { - SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); - } -#endif -#ifdef SO_BINDTODEVICE - if (!ds->sourceItfName.empty()) { - int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, ds->sourceItfName.c_str(), ds->sourceItfName.length()); - if (res != 0) { - vinfolog("Error setting up the interface on backend TCP socket '%s': %s", ds->getNameWithAddr(), stringerror()); - } - } -#endif - result->bind(ds->sourceAddr, false); - } - result->setNonBlocking(); -#ifdef MSG_FASTOPEN - if (!ds->tcpFastOpen) { - SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0); - } -#else - SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0); -#endif /* MSG_FASTOPEN */ - return result; - } - catch(const std::runtime_error& e) { - vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what()); - downstreamFailures++; - if (downstreamFailures > ds->retries) { - throw; + static std::unique_ptr 
getConnectionToDownstream(std::unique_ptr& mplexer, std::shared_ptr& ds, const struct timeval& now) + { + std::unique_ptr result; + + const auto& it = t_downstreamConnections.find(ds->remote); + if (it != t_downstreamConnections.end()) { + auto& list = it->second; + if (!list.empty()) { + result = std::move(list.front()); + list.pop_front(); + result->setReused(); + return result; } } - } while(downstreamFailures <= ds->retries); - return nullptr; -} - -class TCPConnectionToBackend -{ -public: - TCPConnectionToBackend(std::shared_ptr& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen) - { - d_socket = setupTCPDownstream(d_ds, downstreamFailures); - ++d_ds->tcpCurrentConnections; + return make_unique(ds, now); } - ~TCPConnectionToBackend() + static void releaseDownstreamConnection(std::unique_ptr&& conn) { - if (d_ds && d_socket) { - --d_ds->tcpCurrentConnections; - struct timeval now; - gettimeofday(&now, nullptr); - - auto diff = now - d_connectionStartTime; - d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000); + if (conn == nullptr) { + return; } - } - int getHandle() const - { - if (!d_socket) { - throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection"); + if (!conn->canBeReused()) { + conn.reset(); + return; } - return d_socket->getHandle(); - } - - const ComboAddress& getRemote() const - { - return d_ds->remote; - } - - bool isFresh() const - { - return d_fresh; - } - - void incQueries() - { - ++d_queries; - } - - void setReused() - { - d_fresh = false; - } - - void disableFastOpen() - { - d_enableFastOpen = false; - } - - bool isFastOpenEnabled() - { - return d_enableFastOpen; - } - - bool canBeReused() const - { - /* we can't reuse a connection where a proxy protocol payload has been sent, - since: - - it cannot be reused for a different client - - we might have different TLV values for each query - */ - if 
(d_ds && d_ds->useProxyProtocol) { - return false; + const auto& remote = conn->getRemote(); + const auto& it = t_downstreamConnections.find(remote); + if (it != t_downstreamConnections.end()) { + auto& list = it->second; + if (list.size() >= s_maxCachedConnectionsPerDownstream) { + /* too many connections queued already */ + conn.reset(); + return; + } + list.push_back(std::move(conn)); + } + else { + t_downstreamConnections[remote].push_back(std::move(conn)); } - return true; } - bool matches(const std::shared_ptr& ds) const + static void cleanupClosedTCPConnections() { - if (!ds || !d_ds) { - return false; + for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) { + for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) { + if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) { + ++connIt; + } + else { + connIt = dsIt->second.erase(connIt); + } + } + + if (!dsIt->second.empty()) { + ++dsIt; + } + else { + dsIt = t_downstreamConnections.erase(dsIt); + } } - return ds == d_ds; } private: - std::unique_ptr d_socket{nullptr}; - std::shared_ptr d_ds{nullptr}; - struct timeval d_connectionStartTime; - uint64_t d_queries{0}; - bool d_fresh{true}; - bool d_enableFastOpen{false}; + static thread_local map>> t_downstreamConnections; + static const size_t s_maxCachedConnectionsPerDownstream; }; -static thread_local map>> t_downstreamConnections; +thread_local map>> DownstreamConnectionsManager::t_downstreamConnections; +const size_t DownstreamConnectionsManager::s_maxCachedConnectionsPerDownstream{20}; -static std::unique_ptr getConnectionToDownstream(std::shared_ptr& ds, uint16_t& downstreamFailures, const struct timeval& now) +static void decrementTCPClientCount(const ComboAddress& client) { - std::unique_ptr result; - - const auto& it = t_downstreamConnections.find(ds->remote); - if (it != t_downstreamConnections.end()) { - auto& list = it->second; - if (!list.empty()) { - result = std::move(list.front()); - 
list.pop_front(); - result->setReused(); - return result; + if (g_maxTCPConnectionsPerClient) { + std::lock_guard lock(tcpClientsCountMutex); + tcpClientsCount.at(client)--; + if (tcpClientsCount[client] == 0) { + tcpClientsCount.erase(client); } } - - return std::unique_ptr(new TCPConnectionToBackend(ds, downstreamFailures, now)); } -static void releaseDownstreamConnection(std::unique_ptr&& conn) +IncomingTCPConnectionState::~IncomingTCPConnectionState() { - if (conn == nullptr) { - return; + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0); + // DEBUG: cerr<<"updated tcp metrics"<canBeReused()) { - conn.reset(); - return; +#if 0 + if (d_ds != nullptr) { + + if (d_downstreamConnection) { + try { + if (d_lastIOState == IOState::NeedRead) { + // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover backend read FD "<getHandle()<removeReadFD(d_downstreamConnection->getHandle()); + } + else if (d_lastIOState == IOState::NeedWrite) { + // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover backend write FD "<getHandle()<removeWriteFD(d_downstreamConnection->getHandle()); + } + } + catch(const FDMultiplexerException& e) { + vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what()); + } + catch(const std::runtime_error& e) { + /* might be thrown by getHandle() */ + vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what()); + } + } } +#endif - const auto& remote = conn->getRemote(); - const auto& it = t_downstreamConnections.find(remote); - if (it != t_downstreamConnections.end()) { - auto& list = it->second; - if (list.size() >= g_maxCachedConnectionsPerDownstream) { - /* too many connections queued already */ - conn.reset(); - return; + // DEBUG: cerr<<"about to remove left over FDs"<removeReadFD(d_ci.fd); + } + else if (d_lastIOState == 
IOState::NeedWrite) { + // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover client write FD "<removeWriteFD(d_ci.fd); } - list.push_back(std::move(conn)); } - else { - t_downstreamConnections[remote].push_back(std::move(conn)); + catch (const FDMultiplexerException& e) { + vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what()); + } + catch (...) { + vinfolog("Got an unknown exception when trying to remove a pending IO operation on an incoming TCP connection from %s", d_ci.remote.toStringWithPort()); } + // DEBUG: cerr<<"done"< IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr& ds, const struct timeval& now) { - ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1) - { - } - ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd) - { - rhs.cs = nullptr; - rhs.fd = -1; - } + std::shared_ptr downstream{nullptr}; - ConnectionInfo(const ConnectionInfo& rhs) = delete; - ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete; - - ConnectionInfo& operator=(ConnectionInfo&& rhs) - { - remote = rhs.remote; - cs = rhs.cs; - rhs.cs = nullptr; - fd = rhs.fd; - rhs.fd = -1; - return *this; + if (!ds->useProxyProtocol || !d_proxyProtocolPayloadHasTLV) { + downstream = getActiveDownstreamConnection(ds); } - ~ConnectionInfo() - { - if (fd != -1) { - close(fd); - fd = -1; - } - if (cs) { - --cs->tcpCurrentConnections; - } + if (!downstream) { + /* we don't have a connection to this backend active yet, let's ask one (it might not be a fresh one, though) */ + downstream = DownstreamConnectionsManager::getConnectionToDownstream(d_threadData.mplexer, ds, now); } - ComboAddress remote; - ClientState* cs{nullptr}; - int fd{-1}; -}; - -void tcpClientThread(int pipefd); - -static void decrementTCPClientCount(const ComboAddress& client) -{ - if (g_maxTCPConnectionsPerClient) { - std::lock_guard lock(tcpClientsCountMutex); - 
tcpClientsCount[client]--; - if (tcpClientsCount[client] == 0) { - tcpClientsCount.erase(client); - } - } + return downstream; } +static void tcpClientThread(int pipefd); + void TCPClientCollection::addTCPClientThread() { int pipefds[2] = { -1, -1}; @@ -349,7 +280,7 @@ void TCPClientCollection::addTCPClientThread() } try { - thread t1(tcpClientThread, pipefds[0]); + std::thread t1(tcpClientThread, pipefds[0]); t1.detach(); } catch(const std::runtime_error& e) { @@ -367,27 +298,6 @@ void TCPClientCollection::addTCPClientThread() } } -static void cleanupClosedTCPConnections() -{ - for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) { - for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) { - if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) { - ++connIt; - } - else { - connIt = dsIt->second.erase(connIt); - } - } - - if (!dsIt->second.empty()) { - ++dsIt; - } - else { - dsIt = t_downstreamConnections.erase(dsIt); - } - } -} - /* Tries to read exactly toRead bytes into the buffer, starting at position pos. Updates pos everytime a successful read occurs, throws an std::runtime_error in case of IO error, @@ -395,7 +305,7 @@ static void cleanupClosedTCPConnections() would block. 
*/ // XXX could probably be implemented as a TCPIOHandler -static IOState tryRead(int fd, std::vector& buffer, size_t& pos, size_t toRead) +IOState tryRead(int fd, std::vector& buffer, size_t& pos, size_t toRead) { if (buffer.size() < (pos + toRead)) { throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos)); @@ -426,374 +336,155 @@ static IOState tryRead(int fd, std::vector& buffer, size_t& pos, size_t std::unique_ptr g_tcpclientthreads; -class TCPClientThreadData +static IOState handleResponseSent(std::shared_ptr& state, const struct timeval& now) { -public: - TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr(FDMultiplexer::getMultiplexerSilent())) - { - } - - LocalHolders holders; - LocalStateHolder > localRespRulactions; - std::unique_ptr mplexer{nullptr}; -}; - -static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param); - -class IncomingTCPConnectionState -{ -public: - IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_responseBuffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? 
d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now) - { - d_ids.origDest.reset(); - d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family; - socklen_t socklen = d_ids.origDest.getSocklen(); - if (getsockname(d_ci.fd, reinterpret_cast(&d_ids.origDest), &socklen)) { - d_ids.origDest = d_ci.cs->local; - } - } - - IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete; - IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete; - - ~IncomingTCPConnectionState() - { - decrementTCPClientCount(d_ci.remote); - if (d_ci.cs != nullptr) { - struct timeval now; - gettimeofday(&now, nullptr); - - auto diff = now - d_connectionStartTime; - d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0); - } - - if (d_ds != nullptr) { - if (d_outstanding) { - --d_ds->outstanding; - d_outstanding = false; - } - - if (d_downstreamConnection) { - try { - if (d_lastIOState == IOState::NeedRead) { - cerr<<__func__<<": removing leftover backend read FD "<getHandle()<removeReadFD(d_downstreamConnection->getHandle()); - } - else if (d_lastIOState == IOState::NeedWrite) { - cerr<<__func__<<": removing leftover backend write FD "<getHandle()<removeWriteFD(d_downstreamConnection->getHandle()); - } - } - catch(const FDMultiplexerException& e) { - vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what()); - } - catch(const std::runtime_error& e) { - /* might be thrown by getHandle() */ - vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what()); - } - } - } - - try { - if (d_lastIOState == IOState::NeedRead) { - cerr<<__func__<<": removing leftover client read FD "<removeReadFD(d_ci.fd); - } - else if (d_lastIOState == IOState::NeedWrite) { - cerr<<__func__<<": removing leftover client write FD "<removeWriteFD(d_ci.fd); - } - 
} - catch(const FDMultiplexerException& e) { - vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what()); - } - } - - void resetForNewQuery() - { - d_buffer.resize(sizeof(uint16_t)); - d_currentPos = 0; - d_querySize = 0; - d_responseSize = 0; - d_downstreamFailures = 0; - d_state = State::readingQuerySize; - d_lastIOState = IOState::Done; - d_selfGeneratedResponse = false; - } - - boost::optional getClientReadTTD(struct timeval now) const - { - if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) { - return boost::none; - } - - if (g_maxTCPConnectionDuration > 0) { - auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec; - if (elapsed < 0 || (static_cast(elapsed) >= g_maxTCPConnectionDuration)) { - return now; - } - auto remaining = g_maxTCPConnectionDuration - elapsed; - if (g_tcpRecvTimeout == 0 || remaining <= static_cast(g_tcpRecvTimeout)) { - now.tv_sec += remaining; - return now; - } - } - - now.tv_sec += g_tcpRecvTimeout; - return now; - } - - boost::optional getBackendReadTTD(const struct timeval& now) const - { - if (d_ds == nullptr) { - throw std::runtime_error("getBackendReadTTD() without any backend selected"); - } - if (d_ds->tcpRecvTimeout == 0) { - return boost::none; - } - - struct timeval res = now; - res.tv_sec += d_ds->tcpRecvTimeout; - - return res; - } - - boost::optional getClientWriteTTD(const struct timeval& now) const - { - if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) { - return boost::none; + if (!state->d_isXFR) { + const auto& currentResponse = state->d_currentResponse; + if (state->d_selfGeneratedResponse == false && currentResponse.d_ds) { + struct timespec answertime; + gettime(&answertime); + const auto& ids = currentResponse.d_idstate; + double udiff = ids.sentTime.udiff(); + g_rings.insertResponse(answertime, state->d_ci.remote, ids.qname, ids.qtype, static_cast(udiff), 
static_cast(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, currentResponse.d_ds->remote); + vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", currentResponse.d_ds->remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff); + } + + switch (currentResponse.d_cleartextDH.rcode) { + case RCode::NXDomain: + ++g_stats.frontendNXDomain; + break; + case RCode::ServFail: + ++g_stats.servfailResponses; + ++g_stats.frontendServFail; + break; + case RCode::NoError: + ++g_stats.frontendNoError; + break; } - struct timeval res = now; - - if (g_maxTCPConnectionDuration > 0) { - auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec; - if (elapsed < 0 || static_cast(elapsed) >= g_maxTCPConnectionDuration) { - return res; - } - auto remaining = g_maxTCPConnectionDuration - elapsed; - if (g_tcpSendTimeout == 0 || remaining <= static_cast(g_tcpSendTimeout)) { - res.tv_sec += remaining; - return res; - } + if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) { + vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn); + return IOState::Done; } - res.tv_sec += g_tcpSendTimeout; - return res; - } - - boost::optional getBackendWriteTTD(const struct timeval& now) const - { - if (d_ds == nullptr) { - throw std::runtime_error("getBackendReadTTD() called without any backend selected"); - } - if (d_ds->tcpSendTimeout == 0) { - return boost::none; + if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { + vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); + return IOState::Done; } - - struct timeval res = now; - res.tv_sec += d_ds->tcpSendTimeout; - - return res; } - bool maxConnectionDurationReached(unsigned int 
maxConnectionDuration, const struct timeval& now) - { - if (maxConnectionDuration) { - time_t curtime = now.tv_sec; - unsigned int elapsed = 0; - if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward - elapsed = curtime - d_connectionStartTime.tv_sec; - } - if (elapsed >= maxConnectionDuration) { - return true; - } - d_remainingTime = maxConnectionDuration - elapsed; + if (state->d_queuedResponses.empty()) { + // DEBUG: cerr<<"no response remaining"<d_isXFR) { + /* we should still be reading from the backend, and we don't want to read from the client */ + state->d_state = IncomingTCPConnectionState::State::idle; + state->d_currentPos = 0; + // DEBUG: cerr<<"idling for XFR completion"<resetForNewQuery(); + return IOState::NeedRead; } - - return false; } - - void dump() const - { - static std::mutex s_mutex; - - struct timeval now; - gettimeofday(&now, 0); - - { - std::lock_guard lock(s_mutex); - fprintf(stderr, "State is %p\n", this); - cerr << "Current state is " << static_cast(d_state) << ", got "< State::doingHandshake) { - cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl; - } - if (d_state > State::readingQuerySize) { - cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl; - } - if (d_state > State::readingQuerySize) { - cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl; - } - if (d_state > State::readingQuery) { - cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl; - } - if (d_state > State::sendingQueryToBackend) { - cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl; - } - if (d_state > State::readingResponseFromBackend) { - cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl; - } - } + else { + // DEBUG: 
cerr<<"queue size is "<d_queuedResponses.size()<d_queuedResponses.front()); + state->d_queuedResponses.pop_front(); + state->d_state = IncomingTCPConnectionState::State::idle; + state->sendResponse(state, now, std::move(resp)); + return IOState::NeedWrite; } +} - enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse }; - - std::vector d_buffer; - std::vector d_responseBuffer; - TCPClientThreadData& d_threadData; - IDState d_ids; - ConnectionInfo d_ci; - TCPIOHandler d_handler; - std::unique_ptr d_downstreamConnection{nullptr}; - std::shared_ptr d_ds{nullptr}; - dnsheader d_cleartextDH; - struct timeval d_connectionStartTime; - struct timeval d_handshakeDoneTime; - struct timeval d_firstQuerySizeReadTime; - struct timeval d_querySizeReadTime; - struct timeval d_queryReadTime; - struct timeval d_querySentTime; - struct timeval d_responseReadTime; - size_t d_currentPos{0}; - size_t d_queriesCount{0}; - unsigned int d_remainingTime{0}; - uint16_t d_querySize{0}; - uint16_t d_responseSize{0}; - uint16_t d_downstreamFailures{0}; - State d_state{State::doingHandshake}; - IOState d_lastIOState{IOState::Done}; - bool d_readingFirstQuery{true}; - bool d_outstanding{false}; - bool d_firstResponsePacket{true}; - bool d_isXFR{false}; - bool d_xfrStarted{false}; - bool d_selfGeneratedResponse{false}; - bool d_proxyProtocolPayloadAdded{false}; - bool d_proxyProtocolPayloadHasTLV{false}; -}; - -static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); -static void handleNewIOState(std::shared_ptr& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional ttd=boost::none); -static void handleIO(std::shared_ptr& state, struct timeval& now); -static void handleDownstreamIO(std::shared_ptr& state, struct timeval& now); - -static void handleResponseSent(std::shared_ptr& state, struct timeval& now) +void 
IncomingTCPConnectionState::resetForNewQuery() { - handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback); + d_buffer.resize(sizeof(uint16_t)); + d_currentPos = 0; + d_querySize = 0; + //d_responseSize = 0; + d_downstreamFailures = 0; + d_state = State::readingQuerySize; + d_lastIOState = IOState::Done; + d_selfGeneratedResponse = false; +} - if (state->d_isXFR && state->d_downstreamConnection) { - /* we need to resume reading from the backend! */ - state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend; +/* this version is called when the buffer has been set and the rules have been processed */ +void IncomingTCPConnectionState::sendResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response) +{ + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<d_state == IncomingTCPConnectionState::State::idle || + state->d_state == IncomingTCPConnectionState::State::readingQuerySize) { + + state->d_state = IncomingTCPConnectionState::State::sendingResponse; + + uint16_t responseSize = static_cast(response.d_buffer.size()); + const uint8_t sizeBytes[] = { static_cast(responseSize / 256), static_cast(responseSize % 256) }; + /* prepend the size. 
Yes, this is not the most efficient way but it prevents mistakes + that could occur if we had to deal with the size during the processing, + especially alignment issues */ + response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2); state->d_currentPos = 0; - handleDownstreamIO(state, now); - return; - } - - if (state->d_selfGeneratedResponse == false && state->d_ds) { - /* if we have no downstream server selected, this was a self-answered response - but cache hits have a selected server as well, so be careful */ - struct timespec answertime; - gettime(&answertime); - double udiff = state->d_ids.sentTime.udiff(); - g_rings.insertResponse(answertime, state->d_ci.remote, state->d_ids.qname, state->d_ids.qtype, static_cast(udiff), static_cast(state->d_responseBuffer.size()), state->d_cleartextDH, state->d_ds->remote); - vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", state->d_ds->remote.toStringWithPort(), state->d_ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? 
"DoT" : "TCP"), udiff); - } + state->d_currentResponse = std::move(response); - switch (state->d_cleartextDH.rcode) { - case RCode::NXDomain: - ++g_stats.frontendNXDomain; - break; - case RCode::ServFail: - ++g_stats.servfailResponses; - ++g_stats.frontendServFail; - break; - case RCode::NoError: - ++g_stats.frontendNoError; - break; - } - - if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) { - vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn); - return; + //IncomingTCPConnectionState::handleIO(state, now); + state->d_ioState->update(IOState::NeedWrite, handleIOCallback, state, getClientWriteTTD(now)); + // DEBUG: cerr<<"updated IO state"<maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { - vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); - return; + else { + // queue response + state->d_queuedResponses.push_back(std::move(response)); + // DEBUG: cerr<<"queueing response because state is "<<(int)state->d_state<<", queue size is now "<d_queuedResponses.size()<resetForNewQuery(); - - handleIO(state, now); -} - -static void sendResponse(std::shared_ptr& state, struct timeval& now) -{ - state->d_state = IncomingTCPConnectionState::State::sendingResponse; - const uint8_t sizeBytes[] = { static_cast(state->d_responseSize / 256), static_cast(state->d_responseSize % 256) }; - /* prepend the size. 
Yes, this is not the most efficient way but it prevents mistakes - that could occur if we had to deal with the size during the processing, - especially alignment issues */ - state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2); - - state->d_currentPos = 0; - - handleIO(state, now); } -static void handleResponse(std::shared_ptr& state, struct timeval& now) +/* this version is called from the backend code when a new response has been received */ +void IncomingTCPConnectionState::handleResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response) { - if (state->d_responseSize < sizeof(dnsheader) || !state->d_ds) { + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<(&state->d_responseBuffer.at(0)); + uint16_t responseSize = response.d_buffer.size(); + response.d_buffer.resize(responseSize + static_cast(512)); + size_t responseCapacity = response.d_buffer.size(); + auto responseAsCharArray = reinterpret_cast(&response.d_buffer.at(0)); + + auto& ids = response.d_idstate; + // DEBUG: cerr<<"IDS has "<<(ids.qTag?" 
TAGS ": "NO TAGS")<d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) { + // DEBUG: cerr<<"about to match response for "<remote, consumed)) { + // DEBUG: cerr<<"content does not match"<d_firstResponsePacket = false; - - if (state->d_outstanding) { - --state->d_ds->outstanding; - state->d_outstanding = false; - } - auto dh = reinterpret_cast(response); + auto dh = reinterpret_cast(responseAsCharArray); uint16_t addRoom = 0; - DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true); + DNSResponse dr = makeDNSResponseFromIDState(ids, dh, responseCapacity, responseSize, true); if (dr.dnsCryptQuery) { addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; } - memcpy(&state->d_cleartextDH, dr.dh, sizeof(state->d_cleartextDH)); + memcpy(&response.d_cleartextDH, dr.dh, sizeof(response.d_cleartextDH)); std::vector rewrittenResponse; - size_t responseSize = state->d_responseBuffer.size(); - if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) { + if (!processResponse(&responseAsCharArray, &responseSize, &responseCapacity, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) { + // DEBUG: cerr<<"process said to drop it"<d_responseBuffer = std::move(rewrittenResponse); - state->d_responseSize = state->d_responseBuffer.size(); + response.d_buffer = std::move(rewrittenResponse); } else { /* the size might have been updated (shrinked) if we removed the whole OPT RR, for example) */ - state->d_responseBuffer.resize(state->d_responseSize); + response.d_buffer.resize(responseSize); } if (state->d_isXFR && !state->d_xfrStarted) { @@ -801,67 +492,23 @@ static void handleResponse(std::shared_ptr& state, s state->d_xfrStarted = true; ++g_stats.responses; 
++state->d_ci.cs->responses; - ++state->d_ds->responses; + ++response.d_ds->responses; } if (!state->d_isXFR) { ++g_stats.responses; ++state->d_ci.cs->responses; - ++state->d_ds->responses; + ++response.d_ds->responses; } - sendResponse(state, now); + sendResponse(state, now, std::move(response)); } -static void sendQueryToBackend(std::shared_ptr& state, struct timeval& now) -{ - auto ds = state->d_ds; - state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend; - state->d_currentPos = 0; - state->d_firstResponsePacket = true; - - if (state->d_xfrStarted) { - /* sorry, but we are not going to resume a XFR if we have already sent some packets - to the client */ - return; - } - - if (!state->d_downstreamConnection) { - if (state->d_downstreamFailures < state->d_ds->retries) { - try { - state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now); - } - catch (const std::runtime_error& e) { - state->d_downstreamConnection.reset(); - } - } - - if (!state->d_downstreamConnection) { - ++ds->tcpGaveUp; - ++state->d_ci.cs->tcpGaveUp; - vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures); - return; - } - - if (ds->useProxyProtocol && !state->d_proxyProtocolPayloadAdded) { - /* we know there is no TLV values to add, otherwise we would not have tried - to reuse the connection and d_proxyProtocolPayloadAdded would be true already */ - addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, std::vector()); - state->d_proxyProtocolPayloadAdded = true; - } - } - - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", state->d_ids.qname.toLogString(), QType(state->d_ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? 
"DoT" : "TCP"), state->d_buffer.size(), ds->getName()); - - handleDownstreamIO(state, now); - return; -} - -static void handleQuery(std::shared_ptr& state, struct timeval& now) +static bool handleQuery(std::shared_ptr& state, const struct timeval& now) { if (state->d_querySize < sizeof(dnsheader)) { ++g_stats.nonCompliantQueries; - return; + return true; } state->d_readingFirstQuery = false; @@ -900,21 +547,24 @@ static void handleQuery(std::shared_ptr& state, stru std::shared_ptr dnsCryptQuery{nullptr}; auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true); if (dnsCryptResponse) { - state->d_responseBuffer = std::move(*dnsCryptResponse); - state->d_responseSize = state->d_responseBuffer.size(); - sendResponse(state, now); - return; + //state->d_responseBuffer = std::move(*dnsCryptResponse); + //state->d_responseSize = state->d_responseBuffer.size(); + TCPResponse response; + response.d_buffer = std::move(*dnsCryptResponse); + state->d_state = IncomingTCPConnectionState::State::idle; + state->sendResponse(state, now, std::move(response)); + return false; } const auto& dh = reinterpret_cast(query); if (!checkQueryHeaders(dh)) { - return; + return true; } uint16_t qtype, qclass; unsigned int consumed = 0; DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime); + DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_origDest, &state->d_ci.remote, reinterpret_cast(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime); dq.dnsCryptQuery = std::move(dnsCryptQuery); dq.sni = state->d_handler.getServerNameIndication(); @@ -923,27 +573,39 @@ static void handleQuery(std::shared_ptr& state, stru dq.skipCache = true; } - state->d_ds.reset(); - auto result = 
processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds); + std::shared_ptr ds; + auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, ds); if (result == ProcessQueryResult::Drop) { - return; + return true; } if (result == ProcessQueryResult::SendAnswer) { state->d_selfGeneratedResponse = true; state->d_buffer.resize(dq.len); - state->d_responseBuffer = std::move(state->d_buffer); - state->d_responseSize = state->d_responseBuffer.size(); - sendResponse(state, now); - return; + TCPResponse response; + response.d_buffer = std::move(state->d_buffer); + state->d_state = IncomingTCPConnectionState::State::idle; + state->sendResponse(state, now, std::move(response)); + return false; } - if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) { - return; + if (result != ProcessQueryResult::PassToBackend || ds == nullptr) { + return true; + } + +#warning move this, we should just never read another question again on this client connection + if (state->d_xfrStarted) { + /* sorry, but we are not going to resume a XFR if we have already sent some packets + to the client */ + return true; } - setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname)); + IDState ids; + // DEBUG: cerr<<"DQ has "<<(dq.qTag?" TAGS ": "NO TAGS")<id); const uint8_t sizeBytes[] = { static_cast(dq.len / 256), static_cast(dq.len % 256) }; /* prepend the size. 
Yes, this is not the most efficient way but it prevents mistakes @@ -955,350 +617,271 @@ static void handleQuery(std::shared_ptr& state, stru dq.size = state->d_buffer.size(); state->d_buffer.resize(dq.len); - if (state->d_ds->useProxyProtocol) { + bool proxyProtocolPayloadAdded = false; + std::string proxyProtocolPayload; + + if (ds->useProxyProtocol) { /* if we ever sent a TLV over a connection, we can never go back */ if (!state->d_proxyProtocolPayloadHasTLV) { state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty(); } - if (state->d_downstreamConnection && !state->d_proxyProtocolPayloadHasTLV && state->d_downstreamConnection->matches(state->d_ds)) { - /* we have an existing connection, on which we already sent a Proxy Protocol header with no values - (in the previous query had TLV values we would have reset the connection afterwards), - so let's reuse it as long as we still don't have any values */ - state->d_proxyProtocolPayloadAdded = false; - } - else { - state->d_downstreamConnection.reset(); - addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, dq.proxyProtocolValues ? 
*dq.proxyProtocolValues : std::vector()); - state->d_proxyProtocolPayloadAdded = true; + proxyProtocolPayload = getProxyProtocolPayload(dq); + + if (state->d_proxyProtocolPayloadHasTLV) { + /* we will not be able to reuse an existing connection anyway so let's add the payload right now */ + addProxyProtocol(state->d_buffer, proxyProtocolPayload); + proxyProtocolPayloadAdded = true; } } - sendQueryToBackend(state, now); -} - -static void handleNewIOState(std::shared_ptr& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional ttd) -{ - //cerr<<"in "<<__func__<<" for fd "<d_lastIOState<<", new state is "<<(int)iostate<getDownstreamConnection(ds, now); + downstreamConnection->assignToClientConnection(state, state->d_isXFR); - if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) { - state->d_threadData.mplexer->removeReadFD(fd); - //cerr<<__func__<<": remove read FD "<d_lastIOState = IOState::Done; + if (proxyProtocolPayloadAdded) { + downstreamConnection->setProxyProtocolPayloadAdded(true); } - else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) { - state->d_threadData.mplexer->removeWriteFD(fd); - //cerr<<__func__<<": remove write FD "<d_lastIOState = IOState::Done; + else { + downstreamConnection->setProxyProtocolPayload(std::move(proxyProtocolPayload)); } - if (iostate == IOState::NeedRead) { - if (state->d_lastIOState == IOState::NeedRead) { - if (ttd) { - /* let's update the TTD ! */ - state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0); - } - return; - } + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? 
"DoT" : "TCP"), state->d_buffer.size(), ds->getName()); - state->d_lastIOState = IOState::NeedRead; - //cerr<<__func__<<": add read FD "<d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr); - } - else if (iostate == IOState::NeedWrite) { - if (state->d_lastIOState == IOState::NeedWrite) { - return; - } +// DEBUG: cerr<<"about to be queued query IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<queueQuery(TCPQuery(std::move(state->d_buffer), std::move(ids)), downstreamConnection); - state->d_lastIOState = IOState::NeedWrite; - //cerr<<__func__<<": add write FD "<d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr); - } - else if (iostate == IOState::Done) { - state->d_lastIOState = IOState::Done; - } + //sendQueryToBackend(state, now); + // DEBUG: cerr<<"out of "<<__PRETTY_FUNCTION__<& state, struct timeval& now) +void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param) { - if (state->d_downstreamConnection == nullptr) { - throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!"); + auto conn = boost::any_cast>(param); + if (fd != conn->d_ci.fd) { + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_ci.fd)); } - int fd = state->d_downstreamConnection->getHandle(); - IOState iostate = IOState::Done; - bool connectionDied = false; + struct timeval now; + gettimeofday(&now, 0); + handleIO(conn, now); +} - try { - if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) { - int socketFlags = 0; -#ifdef MSG_FASTOPEN - if (state->d_downstreamConnection->isFastOpenEnabled()) { - socketFlags |= MSG_FASTOPEN; - } -#endif /* MSG_FASTOPEN */ - - size_t sent = sendMsgWithOptions(fd, reinterpret_cast(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, &state->d_ds->remote, &state->d_ds->sourceAddr, 
state->d_ds->sourceItf, socketFlags); - if (sent == state->d_buffer.size()) { - /* request sent ! */ - state->d_downstreamConnection->incQueries(); - state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend; - state->d_currentPos = 0; - state->d_querySentTime = now; - iostate = IOState::NeedRead; - if (!state->d_isXFR && !state->d_outstanding) { - /* don't bother with the outstanding count for XFR queries */ - ++state->d_ds->outstanding; - state->d_outstanding = true; - } - } - else { - state->d_currentPos += sent; - iostate = IOState::NeedWrite; - /* disable fast open on partial write */ - state->d_downstreamConnection->disableFastOpen(); - } - } +void IncomingTCPConnectionState::handleIO(std::shared_ptr& state, const struct timeval& now) +{ + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<d_ioState); - if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) { - // then we need to allocate a new buffer (new because we might need to re-send the query if the - // backend dies on us - // We also might need to read and send to the client more than one response in case of XFR (yeah!) - // should very likely be a TCPIOHandler d_downstreamHandler - iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos); - if (iostate == IOState::Done) { - state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend; - state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1); - state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? 
state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize); - state->d_currentPos = 0; - } + if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { + vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); + // will be handled by the ioGuard + //handleNewIOState(state, IOState::Done, fd, handleIOCallback); + return; } - if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) { - iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos); - if (iostate == IOState::Done) { - handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback); - - if (state->d_isXFR) { - /* Don't reuse the TCP connection after an {A,I}XFR */ - /* but don't reset it either, we will need to read more messages */ - } - else { - /* if we did not send a Proxy Protocol header, let's pool the connection */ - if (state->d_ds && state->d_ds->useProxyProtocol == false) { - releaseDownstreamConnection(std::move(state->d_downstreamConnection)); - } - else { - if (state->d_proxyProtocolPayloadHasTLV) { - /* sent a Proxy Protocol header with TLV values, we can't reuse it */ - state->d_downstreamConnection.reset(); + try { + if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) { + // DEBUG: cerr<<"doing handshake"<d_handler.tryHandshake(); + if (iostate == IOState::Done) { + // DEBUG: cerr<<"handshake done"<d_handler.isTLS()) { + if (!state->d_handler.hasTLSSessionBeenResumed()) { + ++state->d_ci.cs->tlsNewSessions; } else { - /* if we did but there was no TLV values, let's try to reuse it but only - for this incoming connection */ + ++state->d_ci.cs->tlsResumptions; + } + if (state->d_handler.getResumedFromInactiveTicketKey()) { + ++state->d_ci.cs->tlsInactiveTicketKey; + } + if (state->d_handler.getUnknownTicketKey()) { + 
++state->d_ci.cs->tlsUnknownTicketKey; } } - } - fd = -1; - state->d_responseReadTime = now; - try { - handleResponse(state, now); + state->d_handshakeDoneTime = now; + state->d_state = IncomingTCPConnectionState::State::readingQuerySize; } - catch (const std::exception& e) { - vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", state->d_ds ? state->d_ds->getName() : "unknown", state->d_ci.remote.toStringWithPort(), e.what()); + else { + wouldBlock = true; } - return; } - } - if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend && - state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend && - state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) { - vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast(state->d_state)); - } - } - catch(const std::exception& e) { - /* most likely an EOF because the other end closed the connection, - but it might also be a real IO error or something else. - Let's just drop the connection - */ - vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? 
"reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what()); - if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) { - ++state->d_ds->tcpDiedSendingQuery; - } - else { - ++state->d_ds->tcpDiedReadingResponse; - } - - /* don't increase this counter when reusing connections */ - if (state->d_downstreamConnection && state->d_downstreamConnection->isFresh()) { - ++state->d_downstreamFailures; - } - - if (state->d_outstanding) { - state->d_outstanding = false; + if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) { + // DEBUG: cerr<<"reading query size"<d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t)); + if (iostate == IOState::Done) { + // DEBUG: cerr<<"query size received"<d_state = IncomingTCPConnectionState::State::readingQuery; + state->d_querySizeReadTime = now; + if (state->d_queriesCount == 0) { + state->d_firstQuerySizeReadTime = now; + } + state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1); + if (state->d_querySize < sizeof(dnsheader)) { + /* go away */ + // will be handled by the guard + //handleNewIOState(state, IOState::Done, fd, handleIOCallback); + return; + } - if (state->d_ds != nullptr) { - --state->d_ds->outstanding; + /* allocate a bit more memory to be able to spoof the content, get an answer from the cache + or to add ECS without allocating a new buffer */ + state->d_buffer.resize(std::max(state->d_querySize + static_cast(512), s_maxPacketCacheEntrySize)); + state->d_currentPos = 0; + } + else { + wouldBlock = true; + } } - } - /* remove this FD from the IO multiplexer */ - iostate = IOState::Done; - connectionDied = true; - } - if (iostate == IOState::Done) { - handleNewIOState(state, iostate, fd, handleDownstreamIOCallback); - } - else { - handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? 
state->getBackendReadTTD(now) : state->getBackendWriteTTD(now)); - } - - if (connectionDied) { - state->d_downstreamConnection.reset(); - sendQueryToBackend(state, now); - } -} - -static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param) -{ - auto state = boost::any_cast>(param); - if (state->d_downstreamConnection == nullptr) { - throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!"); - } - if (fd != state->d_downstreamConnection->getHandle()) { - throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle())); - } - - struct timeval now; - gettimeofday(&now, 0); - handleDownstreamIO(state, now); -} - -static void handleIO(std::shared_ptr& state, struct timeval& now) -{ - int fd = state->d_ci.fd; - IOState iostate = IOState::Done; - - if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { - vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); - handleNewIOState(state, IOState::Done, fd, handleIOCallback); - return; - } - - try { - if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) { - iostate = state->d_handler.tryHandshake(); - if (iostate == IOState::Done) { - if (state->d_handler.isTLS()) { - if (!state->d_handler.hasTLSSessionBeenResumed()) { - ++state->d_ci.cs->tlsNewSessions; + if (state->d_state == IncomingTCPConnectionState::State::readingQuery) { + // DEBUG: cerr<<"reading query"<d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize); + if (iostate == IOState::Done) { + // DEBUG: cerr<<"query received"<d_queuedResponses.empty()) { + state->resetForNewQuery(); + // DEBUG: cerr<<__LINE__<d_ioState->update(IOState::NeedRead, handleIOCallback, state, state->getClientReadTTD(now)); + // DEBUG: 
cerr<<__LINE__<d_queuedResponses.front()); + state->d_queuedResponses.pop_front(); + ioGuard.release(); + state->sendResponse(state, now, std::move(resp)); + return; + } } else { - ++state->d_ci.cs->tlsResumptions; - } - if (state->d_handler.getResumedFromInactiveTicketKey()) { - ++state->d_ci.cs->tlsInactiveTicketKey; - } - if (state->d_handler.getUnknownTicketKey()) { - ++state->d_ci.cs->tlsUnknownTicketKey; + /* otherwise the state should already be waiting for + the socket to be writable */ + // DEBUG: cerr<<"should be waiting for writable socket"<d_handshakeDoneTime = now; - state->d_state = IncomingTCPConnectionState::State::readingQuerySize; + else { + wouldBlock = true; + } } - } - if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) { - iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t)); - if (iostate == IOState::Done) { - state->d_state = IncomingTCPConnectionState::State::readingQuery; - state->d_querySizeReadTime = now; - if (state->d_queriesCount == 0) { - state->d_firstQuerySizeReadTime = now; - } - state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1); - if (state->d_querySize < sizeof(dnsheader)) { - /* go away */ - handleNewIOState(state, IOState::Done, fd, handleIOCallback); - return; + if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) { + // DEBUG: cerr<<"sending response"<d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size()); + if (iostate == IOState::Done) { + // DEBUG: cerr<<"response sent"<d_buffer.resize(std::max(state->d_querySize + static_cast(512), s_maxPacketCacheEntrySize)); - state->d_currentPos = 0; + // DEBUG: cerr<<__LINE__<d_ioState->update(IOState::NeedRead, handleIOCallback, state, state->getClientReadTTD(now)); + //// DEBUG: cerr<<__LINE__<d_state == IncomingTCPConnectionState::State::readingQuery) { - iostate = state->d_handler.tryRead(state->d_buffer, 
state->d_currentPos, state->d_querySize); - if (iostate == IOState::Done) { - handleNewIOState(state, IOState::Done, fd, handleIOCallback); - handleQuery(state, now); - return; + if (state->d_state != IncomingTCPConnectionState::State::idle && + state->d_state != IncomingTCPConnectionState::State::doingHandshake && + state->d_state != IncomingTCPConnectionState::State::readingQuerySize && + state->d_state != IncomingTCPConnectionState::State::readingQuery && + state->d_state != IncomingTCPConnectionState::State::sendingResponse) { + vinfolog("Unexpected state %d in handleIOCallback", static_cast(state->d_state)); } } - - if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) { - iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size()); - if (iostate == IOState::Done) { - handleResponseSent(state, now); - return; + catch(const std::exception& e) { + /* most likely an EOF because the other end closed the connection, + but it might also be a real IO error or something else. 
+ Let's just drop the connection + */ + if (state->d_state == IncomingTCPConnectionState::State::idle || + state->d_state == IncomingTCPConnectionState::State::doingHandshake || + state->d_state == IncomingTCPConnectionState::State::readingQuerySize || + state->d_state == IncomingTCPConnectionState::State::readingQuery) { + ++state->d_ci.cs->tcpDiedReadingQuery; + } + else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) { + ++state->d_ci.cs->tcpDiedSendingResponse; } - } - if (state->d_state != IncomingTCPConnectionState::State::doingHandshake && - state->d_state != IncomingTCPConnectionState::State::readingQuerySize && - state->d_state != IncomingTCPConnectionState::State::readingQuery && - state->d_state != IncomingTCPConnectionState::State::sendingResponse) { - vinfolog("Unexpected state %d in handleIOCallback", static_cast(state->d_state)); - } - } - catch(const std::exception& e) { - /* most likely an EOF because the other end closed the connection, - but it might also be a real IO error or something else. - Let's just drop the connection - */ - if (state->d_state == IncomingTCPConnectionState::State::doingHandshake || - state->d_state == IncomingTCPConnectionState::State::readingQuerySize || - state->d_state == IncomingTCPConnectionState::State::readingQuery) { - ++state->d_ci.cs->tcpDiedReadingQuery; - } - else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) { - ++state->d_ci.cs->tcpDiedSendingResponse; + if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) { + // DEBUG: cerr<<"Got an exception while handling TCP query: "<d_lastIOState == IOState::NeedRead ? 
"reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what()); + } + else { + vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort()); + // DEBUG: cerr<<"Closing TCP client connection: "<d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) { - vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what()); + if (iostate == IOState::Done) { + // DEBUG: cerr<<__LINE__<d_ioState->update(iostate, handleIOCallback, state); + // DEBUG: cerr<<__LINE__<d_ci.remote.toStringWithPort()); + // DEBUG: cerr<<__LINE__<d_ioState->update(iostate, handleIOCallback, state, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now)); + // DEBUG: cerr<<__LINE__<d_state == IncomingTCPConnectionState::State::readingQuerySize && iostate == IOState::NeedRead && !wouldBlock); +} + +void IncomingTCPConnectionState::notifyIOError(std::shared_ptr& state, IDState&& query, const struct timeval& now) +{ + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<getClientReadTTD(now) : state->getClientWriteTTD(now)); + // the backend code already tried to reconnect if it was possible + d_lastIOState = IOState::Done; + d_ioState->reset(); } + + // DEBUG: cerr<<"out "<<__PRETTY_FUNCTION__<& state, const struct timeval& now, TCPResponse&& response) { - auto state = boost::any_cast>(param); - if (fd != state->d_ci.fd) { - throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd)); - } - struct timeval now; - gettimeofday(&now, 0); + sendResponse(state, now, std::move(response)); +} - handleIO(state, now); +void IncomingTCPConnectionState::handleTimeout(bool write) +{ + // DEBUG: cerr<<"client timeout"<tcpClientTimeouts; + d_lastIOState = IOState::Done; + d_ioState->reset(); } static void 
handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param) @@ -1333,7 +916,7 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param /* let's update the remaining time */ state->d_remainingTime = g_maxTCPConnectionDuration; - handleIO(state, now); + IncomingTCPConnectionState::handleIO(state, now); } catch(...) { delete citmp; @@ -1342,7 +925,7 @@ static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param } } -void tcpClientThread(int pipefd) +static void tcpClientThread(int pipefd) { /* we get launched with a pipe on which we receive file descriptors from clients that we own from that point on */ @@ -1361,42 +944,50 @@ void tcpClientThread(int pipefd) data.mplexer->run(&now); if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) { - cleanupClosedTCPConnections(); + DownstreamConnectionsManager::cleanupClosedTCPConnections(); lastTCPCleanup = now.tv_sec; } if (now.tv_sec > lastTimeoutScan) { lastTimeoutScan = now.tv_sec; auto expiredReadConns = data.mplexer->getTimeouts(now, false); - for(const auto& conn : expiredReadConns) { - auto state = boost::any_cast>(conn.second); - if (conn.first == state->d_ci.fd) { - vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); - ++state->d_ci.cs->tcpClientTimeouts; + for (const auto& cbData : expiredReadConns) { + if (cbData.second.type() == typeid(std::shared_ptr)) { + auto state = boost::any_cast>(cbData.second); + if (cbData.first == state->d_ci.fd) { + vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + state->handleTimeout(false); + } } - else if (state->d_ds) { - vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName()); - ++state->d_ci.cs->tcpDownstreamTimeouts; - ++state->d_ds->tcpReadTimeouts; + else if (cbData.second.type() == typeid(std::shared_ptr)) { + auto conn = boost::any_cast>(cbData.second); + 
vinfolog("Timeout (read) from remote backend %s", conn->getBackendName()); + conn->handleTimeout(now, false); } - data.mplexer->removeReadFD(conn.first); - state->d_lastIOState = IOState::Done; } auto expiredWriteConns = data.mplexer->getTimeouts(now, true); - for(const auto& conn : expiredWriteConns) { - auto state = boost::any_cast>(conn.second); - if (conn.first == state->d_ci.fd) { - vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); - ++state->d_ci.cs->tcpClientTimeouts; + for (const auto& cbData : expiredWriteConns) { + if (cbData.second.type() == typeid(std::shared_ptr)) { + auto state = boost::any_cast>(cbData.second); + if (cbData.first == state->d_ci.fd) { + vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + state->handleTimeout(true); + } + } + else if (cbData.second.type() == typeid(std::shared_ptr)) { + auto conn = boost::any_cast>(cbData.second); + vinfolog("Timeout (write) from remote backend %s", conn->getBackendName()); + conn->handleTimeout(now, true); } - else if (state->d_ds) { - vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName()); - ++state->d_ci.cs->tcpDownstreamTimeouts; - ++state->d_ds->tcpWriteTimeouts; +#if 0 + try { + data.mplexer->removeWriteFD(cbData.first); + } + catch (const FDMultiplexerException& fde) { + warnlog("Exception while removing a socket (%d) after a write timeout: %s", cbData.first, fde.what()); } - data.mplexer->removeWriteFD(conn.first); - state->d_lastIOState = IOState::Done; +#endif } } } diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 55c0868647..0dee6cc8df 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -465,13 +465,60 @@ struct ClientState; struct IDState { IDState(): sentTime(true), delayMsec(0), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0;} - IDState(const IDState& orig): origRemote(orig.origRemote), origDest(orig.origDest), age(orig.age) + IDState(const IDState& orig) = delete; 
+ IDState(IDState&& rhs): origRemote(rhs.origRemote), origDest(rhs.origDest), sentTime(rhs.sentTime), qname(std::move(rhs.qname)), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), subnet(rhs.subnet), packetCache(std::move(rhs.packetCache)), qTag(std::move(rhs.qTag)), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), age(rhs.age), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), origFD(rhs.origFD), delayMsec(rhs.delayMsec), tempFailureTTL(rhs.tempFailureTTL), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), destHarvested(rhs.destHarvested), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope) { - usageIndicator.store(orig.usageIndicator.load()); - origFD = orig.origFD; - origID = orig.origID; - delayMsec = orig.delayMsec; - tempFailureTTL = orig.tempFailureTTL; + if (rhs.isInUse()) { + throw std::runtime_error("Trying to move an in-use IDState"); + } + +#ifdef HAVE_PROTOBUF + uniqueId = std::move(rhs.uniqueId); +#endif + } + + IDState& operator=(IDState&& rhs) + { + if (isInUse()) { + throw std::runtime_error("Trying to overwrite an in-use IDState"); + } + + if (rhs.isInUse()) { + throw std::runtime_error("Trying to move an in-use IDState"); + } + + origRemote = rhs.origRemote; + origDest = rhs.origDest; + sentTime = rhs.sentTime; + qname = std::move(rhs.qname); + dnsCryptQuery = std::move(rhs.dnsCryptQuery); + subnet = rhs.subnet; + packetCache = std::move(rhs.packetCache); + qTag = std::move(rhs.qTag); + cs = rhs.cs; + du = std::move(rhs.du); + cacheKey = rhs.cacheKey; + cacheKeyNoECS = rhs.cacheKeyNoECS; + age = rhs.age; + qtype = rhs.qtype; + qclass = rhs.qclass; + origID = rhs.origID; + origFlags = rhs.origFlags; + origFD = rhs.origFD; + delayMsec = rhs.delayMsec; + tempFailureTTL = rhs.tempFailureTTL; + ednsAdded = rhs.ednsAdded; + ecsAdded = rhs.ecsAdded; + skipCache = rhs.skipCache; + destHarvested = rhs.destHarvested; + dnssecOK = rhs.dnssecOK; + 
useZeroScope = rhs.useZeroScope; + +#ifdef HAVE_PROTOBUF + uniqueId = std::move(rhs.uniqueId); +#endif + + return *this; } static const int64_t unusedIndicator = -1; @@ -563,14 +610,14 @@ struct IDState uint16_t origID; // 2 uint16_t origFlags; // 2 int origFD{-1}; - int delayMsec; + int delayMsec{0}; boost::optional tempFailureTTL; bool ednsAdded{false}; bool ecsAdded{false}; bool skipCache{false}; bool destHarvested{false}; // if true, origDest holds the original dest addr, otherwise the listening addr bool dnssecOK{false}; - bool useZeroScope; + bool useZeroScope{false}; }; typedef std::unordered_map QueryCountRecords; diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 1bea3b3330..ca5a079395 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -159,6 +159,8 @@ dnsdist_SOURCES = \ dnsdist-snmp.cc dnsdist-snmp.hh \ dnsdist-systemd.cc dnsdist-systemd.hh \ dnsdist-tcp.cc \ + dnsdist-tcp-downstream.cc dnsdist-tcp-downstream.hh \ + dnsdist-tcp-upstream.hh \ dnsdist-web.cc dnsdist-web.hh \ dnsdist-xpf.cc dnsdist-xpf.hh \ dnsdist.cc dnsdist.hh \ @@ -200,6 +202,7 @@ dnsdist_SOURCES = \ statnode.cc statnode.hh \ svc-records.cc svc-records.hh \ tcpiohandler.cc tcpiohandler.hh \ + tcpiohandler-mplexer.hh \ threadname.hh threadname.cc \ uuid-utils.hh uuid-utils.cc \ views.hh \ diff --git a/pdns/dnsdistdist/dnsdist-backend.cc b/pdns/dnsdistdist/dnsdist-backend.cc index 38398e28eb..2eb19a8c59 100644 --- a/pdns/dnsdistdist/dnsdist-backend.cc +++ b/pdns/dnsdistdist/dnsdist-backend.cc @@ -150,7 +150,7 @@ void DownstreamState::setWeight(int newWeight) } } -DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, const std::string& sourceItfName_, size_t numberOfSockets, bool connect=true): sourceItfName(sourceItfName_), remote(remote_), sourceAddr(sourceAddr_), sourceItf(sourceItf_), name(remote_.toStringWithPort()), nameWithAddr(remote_.toStringWithPort()) 
+DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, const std::string& sourceItfName_, size_t numberOfSockets, bool connect=true): sourceItfName(sourceItfName_), remote(remote_), idStates(g_maxOutstanding), sourceAddr(sourceAddr_), sourceItf(sourceItf_), name(remote_.toStringWithPort()), nameWithAddr(remote_.toStringWithPort()) { id = getUniqueID(); threadStarted.clear(); @@ -164,7 +164,6 @@ DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress if (connect && !IsAnyAddress(remote)) { reconnect(); - idStates.resize(g_maxOutstanding); sw.start(); } } diff --git a/pdns/dnsdistdist/dnsdist-proxy-protocol.cc b/pdns/dnsdistdist/dnsdist-proxy-protocol.cc index 083b0d345a..e7773a2003 100644 --- a/pdns/dnsdistdist/dnsdist-proxy-protocol.cc +++ b/pdns/dnsdistdist/dnsdist-proxy-protocol.cc @@ -22,9 +22,13 @@ #include "dnsdist-proxy-protocol.hh" -bool addProxyProtocol(DNSQuestion& dq) +std::string getProxyProtocolPayload(const DNSQuestion& dq) +{ + return makeProxyHeader(dq.tcp, *dq.remote, *dq.local, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector()); +} + +bool addProxyProtocol(DNSQuestion& dq, const std::string& payload) { - auto payload = makeProxyHeader(dq.tcp, *dq.remote, *dq.local, dq.proxyProtocolValues ? 
*dq.proxyProtocolValues : std::vector()); if ((dq.size - dq.len) < payload.size()) { return false; } @@ -36,10 +40,14 @@ bool addProxyProtocol(DNSQuestion& dq) return true; } -bool addProxyProtocol(std::vector& buffer, bool tcp, const ComboAddress& source, const ComboAddress& destination, const std::vector& values) +bool addProxyProtocol(DNSQuestion& dq) { - auto payload = makeProxyHeader(tcp, source, destination, values); + auto payload = getProxyProtocolPayload(dq); + return addProxyProtocol(dq, payload); +} +bool addProxyProtocol(std::vector& buffer, const std::string& payload) +{ auto previousSize = buffer.size(); if (payload.size() > (std::numeric_limits::max() - previousSize)) { return false; @@ -51,3 +59,9 @@ bool addProxyProtocol(std::vector& buffer, bool tcp, const ComboAddress return true; } + +bool addProxyProtocol(std::vector& buffer, bool tcp, const ComboAddress& source, const ComboAddress& destination, const std::vector& values) +{ + auto payload = makeProxyHeader(tcp, source, destination, values); + return addProxyProtocol(buffer, payload); +} diff --git a/pdns/dnsdistdist/dnsdist-proxy-protocol.hh b/pdns/dnsdistdist/dnsdist-proxy-protocol.hh index 433a7d2394..a218a403d4 100644 --- a/pdns/dnsdistdist/dnsdist-proxy-protocol.hh +++ b/pdns/dnsdistdist/dnsdist-proxy-protocol.hh @@ -23,5 +23,9 @@ #include "dnsdist.hh" +std::string getProxyProtocolPayload(const DNSQuestion& dq); + bool addProxyProtocol(DNSQuestion& dq); +bool addProxyProtocol(DNSQuestion& dq, const std::string& payload); +bool addProxyProtocol(std::vector& buffer, const std::string& payload); bool addProxyProtocol(std::vector& buffer, bool tcp, const ComboAddress& source, const ComboAddress& destination, const std::vector& values); diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc new file mode 100644 index 0000000000..86dee17204 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -0,0 +1,469 @@ + +#include 
"dnsdist-tcp-downstream.hh" +#include "dnsdist-tcp-upstream.hh" + +const uint16_t TCPConnectionToBackend::s_xfrID = 0; + +void TCPConnectionToBackend::assignToClientConnection(std::shared_ptr& clientConn, bool isXFR) +{ + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<(clientConn->getIOMPlexer(), d_socket->getHandle()); +} + +IOState TCPConnectionToBackend::sendNextQuery(std::shared_ptr& conn) +{ + conn->d_currentQuery = std::move(conn->d_pendingQueries.front()); + conn->d_pendingQueries.pop_front(); + conn->d_state = State::sendingQueryToBackend; + return IOState::NeedWrite; +} + +void TCPConnectionToBackend::handleIO(std::shared_ptr& conn, const struct timeval& now) +{ + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<d_socket == nullptr) { + throw std::runtime_error("No downstream socket in " + std::string(__PRETTY_FUNCTION__) + "!"); + } + + bool connectionDied = false; + IOState iostate = IOState::Done; + IOStateGuard ioGuard(conn->d_ioState); + int fd = conn->d_socket->getHandle(); + + try { + if (conn->d_state == State::sendingQueryToBackend) { + // DEBUG: cerr<<"sending query to backend over FD "<isFastOpenEnabled()) { + socketFlags |= MSG_FASTOPEN; + } +#endif /* MSG_FASTOPEN */ + + size_t sent = sendMsgWithOptions(fd, reinterpret_cast(&conn->d_currentQuery.d_buffer.at(conn->d_currentPos)), conn->d_currentQuery.d_buffer.size() - conn->d_currentPos, &conn->d_ds->remote, &conn->d_ds->sourceAddr, conn->d_ds->sourceItf, socketFlags); + if (sent == conn->d_currentQuery.d_buffer.size()) { + // DEBUG: cerr<<"query sent to backend"<incQueries(); + conn->d_currentPos = 0; + //conn->d_currentQuery.d_querySentTime = now; + // DEBUG: cerr<<"adding a pending response for ID "<d_currentQuery.d_idstate.origID<<" and QNAME "<d_currentQuery.d_idstate.qname<d_currentQuery.d_idstate.qTag?"tags":"no tags")<d_pendingResponses[conn->d_currentQuery.d_idstate.origID] = std::move(conn->d_currentQuery); + conn->d_currentQuery.d_buffer.clear(); +#if 0 + if (!conn->d_usedForXFR) { + /* don't 
bother with the outstanding count for XFR queries */ + ++conn->d_ds->outstanding; + ++conn->d_outstanding; + } +#endif + + if (conn->d_pendingQueries.empty()) { + conn->d_state = State::readingResponseSizeFromBackend; + conn->d_currentPos = 0; + conn->d_responseBuffer.resize(sizeof(uint16_t)); + iostate = IOState::NeedRead; + } + else { + iostate = sendNextQuery(conn); + } + } + else { + conn->d_currentPos += sent; + iostate = IOState::NeedWrite; + /* disable fast open on partial write */ + conn->disableFastOpen(); + } + } + + if (conn->d_state == State::readingResponseSizeFromBackend) { + // DEBUG: cerr<<"reading response size from backend"<d_responseBuffer.resize(sizeof(uint16_t)); + iostate = tryRead(fd, conn->d_responseBuffer, conn->d_currentPos, sizeof(uint16_t) - conn->d_currentPos); + if (iostate == IOState::Done) { + // DEBUG: cerr<<"got response size from backend"<d_state = State::readingResponseFromBackend; + conn->d_responseSize = conn->d_responseBuffer.at(0) * 256 + conn->d_responseBuffer.at(1); + conn->d_responseBuffer.reserve(conn->d_responseSize + /* we will need to prepend the size later */ 2); + conn->d_responseBuffer.resize(conn->d_responseSize); + conn->d_currentPos = 0; + } + } + + if (conn->d_state == State::readingResponseFromBackend) { + // DEBUG: cerr<<"reading response from backend"<d_responseBuffer, conn->d_currentPos, conn->d_responseSize - conn->d_currentPos); + if (iostate == IOState::Done) { + // DEBUG: cerr<<"got response from backend"<d_responseReadTime = now; + try { + iostate = conn->handleResponse(now); + } + catch (const std::exception& e) { + vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", conn->d_ds ? 
conn->d_ds->getName() : "unknown", conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what()); + } + //return; + } + } + + if (conn->d_state != State::idle && + conn->d_state != State::sendingQueryToBackend && + conn->d_state != State::readingResponseSizeFromBackend && + conn->d_state != State::readingResponseFromBackend) { + vinfolog("Unexpected state %d in TCPConnectionToBackend::handleIO", static_cast(conn->d_state)); + } + } + catch (const std::exception& e) { + /* most likely an EOF because the other end closed the connection, + but it might also be a real IO error or something else. + Let's just drop the connection + */ + vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_ioState->getState() == IOState::NeedRead ? "reading from" : "writing to"), conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what()); + if (conn->d_state == State::sendingQueryToBackend) { + ++conn->d_ds->tcpDiedSendingQuery; + } + else { + ++conn->d_ds->tcpDiedReadingResponse; + } + + /* don't increase this counter when reusing connections */ + if (conn->d_fresh) { + ++conn->d_downstreamFailures; + } + +#if 0 + if (conn->d_outstanding) { + conn->d_outstanding = false; + + if (conn->d_ds != nullptr) { + --conn->d_ds->outstanding; + } + } +#endif + /* remove this FD from the IO multiplexer */ + iostate = IOState::Done; + connectionDied = true; + } + + if (connectionDied) { + bool reconnected = false; + // DEBUG: cerr<<"connection died, number of failures is "<d_downstreamFailures<<", retries is "<d_ds->retries<d_usedForXFR || conn->d_queries == 0) && conn->d_downstreamFailures < conn->d_ds->retries) { + // DEBUG: cerr<<"reconnecting"<d_ioState->reset(); + ioGuard.release(); + + if (conn->reconnect()) { + // DEBUG: cerr<<"reconnected"<d_ioState = make_unique(conn->d_clientConn->getIOMPlexer(), conn->d_socket->getHandle()); + // DEBUG: cerr<<"new state"<d_pendingResponses) { + 
conn->d_pendingQueries.push_back(std::move(pending.second)); + } + conn->d_pendingResponses.clear(); + conn->d_currentPos = 0; + + if (conn->d_state == State::doingHandshake || + conn->d_state == State::sendingQueryToBackend) { + iostate = IOState::NeedWrite; + // resume sending query + } + else { + // DEBUG: cerr<<"sending next query"<d_proxyProtocolPayloadAdded && !conn->d_proxyProtocolPayload.empty()) { + conn->d_currentQuery.d_buffer.insert(conn->d_currentQuery.d_buffer.begin(), conn->d_proxyProtocolPayload.begin(), conn->d_proxyProtocolPayload.end()); + conn->d_proxyProtocolPayloadAdded = true; + } + + reconnected = true; + } + } + + if (!reconnected) { + /* reconnect failed, we give up */ + conn->d_connectionDied = true; + conn->notifyAllQueriesFailed(now); + } + } + + if (iostate == IOState::Done) { + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<", done"<d_ioState->update(iostate, handleIOCallback, conn); + } + else { + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<", updating to "<<(int)iostate<d_ioState->update(iostate, handleIOCallback, conn, iostate == IOState::NeedRead ? 
conn->getBackendReadTTD(now) : conn->getBackendWriteTTD(now)); + } + ioGuard.release(); + +} + +void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param) +{ + auto conn = boost::any_cast>(param); + if (fd != conn->getHandle()) { + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle())); + } + + struct timeval now; + gettimeofday(&now, 0); + handleIO(conn, now); +} + +void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptr& sharedSelf) +{ + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<update(IOState::NeedWrite, handleIOCallback, sharedSelf, getBackendWriteTTD(now)); + } + else { + // store query in the list of queries to send + d_pendingQueries.push_back(std::move(query)); + } + // DEBUG: cerr<<"out of "<<__PRETTY_FUNCTION__< result; + + if (d_socket) { + // DEBUG: cerr<<"closing socket "<getHandle()<getHandle(), SHUT_RDWR); + d_socket.reset(); + d_ioState.reset(); + --d_ds->tcpCurrentConnections; + } + + do { + vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures); + try { + result = std::unique_ptr(new Socket(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0)); + // DEBUG: cerr<<"result of connect is "<getHandle()<sourceAddr)) { + SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1); +#ifdef IP_BIND_ADDRESS_NO_PORT + if (d_ds->ipBindAddrNoPort) { + SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); + } +#endif +#ifdef SO_BINDTODEVICE + if (!d_ds->sourceItfName.empty()) { + int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length()); + if (res != 0) { + vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror()); + } + } +#endif + result->bind(d_ds->sourceAddr, false); + } + result->setNonBlocking(); +#ifdef 
MSG_FASTOPEN + if (!d_ds->tcpFastOpen || !isFastOpenEnabled()) { + SConnectWithTimeout(result->getHandle(), d_ds->remote, /* no timeout, we will handle it ourselves */ 0); + } +#else + SConnectWithTimeout(result->getHandle(), d_ds->remote, /* no timeout, we will handle it ourselves */ 0); +#endif /* MSG_FASTOPEN */ + + d_socket = std::move(result); + // DEBUG: cerr<<"connected new socket "<getHandle()<tcpCurrentConnections; + return true; + } + catch(const std::runtime_error& e) { + vinfolog("Connection to downstream server %s failed: %s", d_ds->getName(), e.what()); + d_downstreamFailures++; + if (d_downstreamFailures > d_ds->retries) { + throw; + } + } + } + while (d_downstreamFailures <= d_ds->retries); + + return false; +} + +void TCPConnectionToBackend::handleTimeout(const struct timeval& now, bool write) +{ + if (write) { + ++d_ds->tcpWriteTimeouts; + } + else { + ++d_ds->tcpReadTimeouts; + } + + if (d_ioState) { + d_ioState->reset(); + } + + notifyAllQueriesFailed(now, true); +} + +void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, bool timeout) +{ + d_connectionDied = true; + //auto clientConn = d_clientConn.lock(); + //if (!clientConn) { + // d_clientConn.reset(); + // return; + //} + auto& clientConn = d_clientConn; + if (!clientConn->active()) { + // a client timeout occurred, or something like that + d_connectionDied = true; + d_clientConn.reset(); + return; + } + + if (timeout) { + ++clientConn->d_ci.cs->tcpDownstreamTimeouts; + } + + if (d_state == State::doingHandshake || d_state == State::sendingQueryToBackend) { + clientConn->notifyIOError(clientConn, std::move(d_currentQuery.d_idstate), now); + } + + for (auto& query : d_pendingQueries) { + clientConn->notifyIOError(clientConn, std::move(query.d_idstate), now); + } + + for (auto& response : d_pendingResponses) { + clientConn->notifyIOError(clientConn, std::move(response.second.d_idstate), now); + } + + d_pendingQueries.clear(); + d_pendingResponses.clear(); + 
d_clientConn.reset(); +} + +IOState TCPConnectionToBackend::handleResponse(const struct timeval& now) +{ + // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<active()) { + // DEBUG: cerr<<"client is not active"<handleXFRResponse(clientConn, now, std::move(response)); + d_state = State::readingResponseSizeFromBackend; + d_currentPos = 0; + d_responseBuffer.resize(sizeof(uint16_t)); + return IOState::NeedRead; + // get ready to read the next packet, if any + } + else { + // DEBUG: cerr<<"not XFR, phew"<second.d_idstate); + // DEBUG: cerr<<"IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<handleResponse(clientConn, now, TCPResponse(std::move(d_responseBuffer), std::move(ids), d_ds)); + d_pendingResponses.erase(it); + + if (!d_pendingQueries.empty()) { + // DEBUG: cerr<<"still have some queries to send"<getNameWithAddr()); + } + + dnsheader dh; + memcpy(&dh, &d_responseBuffer.at(0), sizeof(dh)); + return ntohs(dh.id); +} + +void TCPConnectionToBackend::setProxyProtocolPayload(std::string&& payload) +{ + d_proxyProtocolPayload = std::move(payload); +} + +void TCPConnectionToBackend::setProxyProtocolPayloadAdded(bool added) +{ + d_proxyProtocolPayloadAdded = added; +} diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.hh b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh new file mode 100644 index 0000000000..9d24de5f17 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.hh @@ -0,0 +1,212 @@ +#pragma once + +#include + +#include "sstuff.hh" +#include "tcpiohandler-mplexer.hh" +#include "dnsdist.hh" + +struct TCPQuery +{ + TCPQuery() + { + } + + TCPQuery(std::vector&& buffer, IDState&& state): d_idstate(std::move(state)), d_buffer(std::move(buffer)) + { + } + + IDState d_idstate; + std::vector d_buffer; +}; + +struct TCPResponse : public TCPQuery +{ + TCPResponse() + { + } + + TCPResponse(std::vector&& buffer, IDState&& state, std::shared_ptr ds): TCPQuery(std::move(buffer), std::move(state)), d_ds(ds) + { + } + + std::shared_ptr d_ds{nullptr}; + dnsheader d_cleartextDH; + bool 
d_selfGenerated{false}; +}; + +class IncomingTCPConnectionState; + +class TCPConnectionToBackend +{ +public: + TCPConnectionToBackend(std::shared_ptr& ds, const struct timeval& now): d_responseBuffer(s_maxPacketCacheEntrySize), d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen) + { + reconnect(); + } + + ~TCPConnectionToBackend() + { + if (d_ds && d_socket) { + --d_ds->tcpCurrentConnections; + struct timeval now; + gettimeofday(&now, nullptr); + + auto diff = now - d_connectionStartTime; + d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000); + } + } + + void assignToClientConnection(std::shared_ptr& clientConn, bool isXFR); + + int getHandle() const + { + if (!d_socket) { + throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection"); + } + + return d_socket->getHandle(); + } + + const ComboAddress& getRemote() const + { + return d_ds->remote; + } + + const std::string& getBackendName() const + { + return d_ds->getName(); + } + + bool isFresh() const + { + return d_fresh; + } + + void incQueries() + { + ++d_queries; + } + + void setReused() + { + d_fresh = false; + } + + void disableFastOpen() + { + d_enableFastOpen = false; + } + + bool isFastOpenEnabled() + { + return d_enableFastOpen; + } + + bool canAcceptNewQueries() const + { + if (d_usedForXFR || d_connectionDied) { + return false; + /* Don't reuse the TCP connection after an {A,I}XFR */ + /* but don't reset it either, we will need to read more messages */ + } +#warning FIXME: maximum number of pending queries + return true; + } + + bool canBeReused() const + { + if (d_usedForXFR || d_connectionDied) { + return false; + } + /* we can't reuse a connection where a proxy protocol payload has been sent, + since: + - it cannot be reused for a different client + - we might have different TLV values for each query + */ + if (d_ds && d_ds->useProxyProtocol == true) { + return false; + } + return true; + } + + bool matches(const 
std::shared_ptr& ds) const + { + if (!ds || !d_ds) { + return false; + } + return ds == d_ds; + } + + static void handleIO(std::shared_ptr& conn, const struct timeval& now); + static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); + static IOState sendNextQuery(std::shared_ptr& conn); + + void queueQuery(TCPQuery&& query, std::shared_ptr& sharedSelf); + void handleTimeout(const struct timeval& now, bool write); + IOState handleResponse(const struct timeval& now); + void setProxyProtocolPayload(std::string&& payload); + void setProxyProtocolPayloadAdded(bool added); + +private: + uint16_t getQueryIdFromResponse(); + bool reconnect(); + void notifyAllQueriesFailed(const struct timeval& now, bool timeout = false); + + boost::optional getBackendReadTTD(const struct timeval& now) const + { + if (d_ds == nullptr) { + throw std::runtime_error("getBackendReadTTD() without any backend selected"); + } + if (d_ds->tcpRecvTimeout == 0) { + return boost::none; + } + + struct timeval res = now; + res.tv_sec += d_ds->tcpRecvTimeout; + + return res; + } + + boost::optional getBackendWriteTTD(const struct timeval& now) const + { + if (d_ds == nullptr) { + throw std::runtime_error("getBackendWriteTTD() called without any backend selected"); + } + if (d_ds->tcpSendTimeout == 0) { + return boost::none; + } + + struct timeval res = now; + res.tv_sec += d_ds->tcpSendTimeout; + + return res; + } + + /* waitingForResponseFromBackend is a state where we have not yet started reading the size, + so we can still switch to sending instead */ + enum class State { idle, doingHandshake, sendingQueryToBackend, waitingForResponseFromBackend, readingResponseSizeFromBackend, readingResponseFromBackend }; + static const uint16_t s_xfrID; + + std::vector d_responseBuffer; + std::deque d_pendingQueries; + std::unordered_map d_pendingResponses; + std::unique_ptr d_socket{nullptr}; + std::unique_ptr d_ioState{nullptr}; + std::shared_ptr d_ds{nullptr}; + //std::weak_ptr d_clientConn; + 
std::shared_ptr d_clientConn; + std::string d_proxyProtocolPayload; + TCPQuery d_currentQuery; + struct timeval d_connectionStartTime; + size_t d_currentPos{0}; + uint64_t d_queries{0}; + uint64_t d_downstreamFailures{0}; + uint16_t d_responseSize{0}; + State d_state{State::idle}; + bool d_fresh{true}; + bool d_enableFastOpen{false}; + bool d_connectionDied{true}; + bool d_usedForXFR{false}; + bool d_proxyProtocolPayloadAdded{false}; +}; diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh new file mode 100644 index 0000000000..9479f0a403 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -0,0 +1,228 @@ +#pragma once + +#include "dolog.hh" + +class TCPClientThreadData +{ +public: + TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr(FDMultiplexer::getMultiplexerSilent())) + { + } + + LocalHolders holders; + LocalStateHolder > localRespRulactions; + std::unique_ptr mplexer{nullptr}; +}; + +struct ConnectionInfo +{ + ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1) + { + } + ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd) + { + rhs.cs = nullptr; + rhs.fd = -1; + } + + ConnectionInfo(const ConnectionInfo& rhs) = delete; + ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete; + + ConnectionInfo& operator=(ConnectionInfo&& rhs) + { + remote = rhs.remote; + cs = rhs.cs; + rhs.cs = nullptr; + fd = rhs.fd; + rhs.fd = -1; + return *this; + } + + ~ConnectionInfo() + { + if (fd != -1) { + close(fd); + fd = -1; + } + if (cs) { + --cs->tcpCurrentConnections; + } + } + + ComboAddress remote; + ClientState* cs{nullptr}; + int fd{-1}; +}; + +class IncomingTCPConnectionState +{ +public: + //IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_responseBuffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), 
d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_ioState(threadData.mplexer, d_ci.fd), _connectionStartTime(now) + IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_ioState(make_unique(threadData.mplexer, d_ci.fd)), d_connectionStartTime(now) + { + d_origDest.reset(); + d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family; + socklen_t socklen = d_origDest.getSocklen(); + if (getsockname(d_ci.fd, reinterpret_cast(&d_origDest), &socklen)) { + d_origDest = d_ci.cs->local; + } + } + + IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete; + IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete; + + ~IncomingTCPConnectionState(); + + void resetForNewQuery(); + + boost::optional getClientReadTTD(struct timeval now) const + { + if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) { + return boost::none; + } + + if (g_maxTCPConnectionDuration > 0) { + auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec; + if (elapsed < 0 || (static_cast(elapsed) >= g_maxTCPConnectionDuration)) { + return now; + } + auto remaining = g_maxTCPConnectionDuration - elapsed; + if (g_tcpRecvTimeout == 0 || remaining <= static_cast(g_tcpRecvTimeout)) { + now.tv_sec += remaining; + return now; + } + } + + now.tv_sec += g_tcpRecvTimeout; + return now; + } + + boost::optional getClientWriteTTD(const struct timeval& now) const + { + if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) { + return boost::none; + } + + struct timeval res = now; + + if (g_maxTCPConnectionDuration > 0) { + auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec; + if (elapsed < 0 || static_cast(elapsed) >= 
g_maxTCPConnectionDuration) { + return res; + } + auto remaining = g_maxTCPConnectionDuration - elapsed; + if (g_tcpSendTimeout == 0 || remaining <= static_cast(g_tcpSendTimeout)) { + res.tv_sec += remaining; + return res; + } + } + + res.tv_sec += g_tcpSendTimeout; + return res; + } + + bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now) + { + if (maxConnectionDuration) { + time_t curtime = now.tv_sec; + unsigned int elapsed = 0; + if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward + elapsed = curtime - d_connectionStartTime.tv_sec; + } + if (elapsed >= maxConnectionDuration) { + return true; + } + d_remainingTime = maxConnectionDuration - elapsed; + } + + return false; + } + + void dump() const + { + static std::mutex s_mutex; + + struct timeval now; + gettimeofday(&now, 0); + + { + std::lock_guard lock(s_mutex); + fprintf(stderr, "State is %p\n", this); + cerr << "Current state is " << static_cast(d_state) << ", got "< State::doingHandshake) { + cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl; + } + if (d_state > State::readingQuerySize) { + cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl; + } + if (d_state > State::readingQuerySize) { + cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl; + } + if (d_state > State::readingQuery) { + cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl; + } + } + } + + std::shared_ptr getActiveDownstreamConnection(const std::shared_ptr& ds) + { +#warning TODO: we need to find a connection to this DS, usable (no TLV values sent) and supporting OOR + return nullptr; + } + + std::shared_ptr getDownstreamConnection(std::shared_ptr& ds, const struct timeval& now); + + std::unique_ptr& getIOMPlexer() const + { + return 
d_threadData.mplexer; + } + + static void handleIO(std::shared_ptr& conn, const struct timeval& now); + static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); + + void queueQuery(TCPQuery&& query); + void notifyIOError(std::shared_ptr& state, IDState&& query, const struct timeval& now); + void sendResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response); + void handleResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response); + void handleXFRResponse(std::shared_ptr& state, const struct timeval& now, TCPResponse&& response); + void handleTimeout(bool write); + + bool active() const + { + return d_ioState != nullptr; + } + + enum class State { doingHandshake, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ }; + + std::vector d_buffer; + std::deque d_queuedResponses; + TCPClientThreadData& d_threadData; + TCPResponse d_currentResponse; + ConnectionInfo d_ci; + ComboAddress d_origDest; + TCPIOHandler d_handler; + std::unique_ptr d_ioState{nullptr}; + struct timeval d_connectionStartTime; + struct timeval d_handshakeDoneTime; + struct timeval d_firstQuerySizeReadTime; + struct timeval d_querySizeReadTime; + struct timeval d_queryReadTime; + size_t d_currentPos{0}; + size_t d_queriesCount{0}; + unsigned int d_remainingTime{0}; + uint16_t d_querySize{0}; + uint16_t d_downstreamFailures{0}; + State d_state{State::doingHandshake}; + IOState d_lastIOState{IOState::Done}; + bool d_readingFirstQuery{true}; + bool d_isXFR{false}; + bool d_xfrStarted{false}; + bool d_xfrDone{false}; + bool d_selfGeneratedResponse{false}; + bool d_proxyProtocolPayloadAdded{false}; + bool d_proxyProtocolPayloadHasTLV{false}; +}; + +IOState tryRead(int fd, std::vector& buffer, size_t& pos, size_t toRead); diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index 080c95e3b8..57c21606e0 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -866,7 
+866,7 @@ try } string decoded; - /* rough estimate so we hopefully don't need a need allocation later */ + /* rough estimate so we hopefully don't need a new allocation later */ /* We reserve at least 512 additional bytes to be able to add EDNS, but we also want at least s_maxPacketCacheEntrySize bytes to be able to fill the answer from the packet cache */ const size_t estimate = ((sdns.size() * 3) / 4); diff --git a/pdns/dnsdistdist/tcpiohandler-mplexer.hh b/pdns/dnsdistdist/tcpiohandler-mplexer.hh new file mode 100644 index 0000000000..fdb1c5ed97 --- /dev/null +++ b/pdns/dnsdistdist/tcpiohandler-mplexer.hh @@ -0,0 +1,121 @@ + +#pragma once + +#include "mplexer.hh" +#include "tcpiohandler.hh" + +class IOStateHandler +{ +public: + IOStateHandler(std::unique_ptr& mplexer, const int fd): d_mplexer(mplexer), d_fd(fd), d_currentState(IOState::Done) + { + } + + IOStateHandler(std::unique_ptr& mplexer): d_mplexer(mplexer), d_fd(-1), d_currentState(IOState::Done) + { + } + + ~IOStateHandler() + { + /* be careful that this won't save us if the callback is still registered to the multiplexer, + because in that case the shared pointer count will never reach zero so this destructor won't + be called */ + reset(); + } + + IOState getState() const + { + return d_currentState; + } + + void setSocket(int fd) + { + if (d_fd != -1) { + throw std::runtime_error("Trying to set the socket descriptor on an already initialized IOStateHandler"); + } + d_fd = fd; + } + + void reset() + { + update(IOState::Done); + } + + void update(IOState iostate, FDMultiplexer::callbackfunc_t callback = FDMultiplexer::callbackfunc_t(), FDMultiplexer::funcparam_t callbackData = boost::any(), boost::optional ttd = boost::none) + { + cerr<<"in "<<__PRETTY_FUNCTION__<<" for fd "<removeReadFD(d_fd); + d_currentState = IOState::Done; + } + else if (d_currentState == IOState::NeedWrite && iostate != IOState::NeedWrite) { + cerr<<__PRETTY_FUNCTION__<<": remove write FD "<removeWriteFD(d_fd); + d_currentState = 
IOState::Done; + } + + if (iostate == IOState::NeedRead) { + if (d_currentState == IOState::NeedRead) { + if (ttd) { + /* let's update the TTD ! */ + d_mplexer->setReadTTD(d_fd, *ttd, /* we pass 0 here because we already have a TTD */0); + } + return; + } + + d_currentState = IOState::NeedRead; + cerr<<__PRETTY_FUNCTION__<<": add read FD "<addReadFD(d_fd, callback, callbackData, ttd ? &*ttd : nullptr); + } + else if (iostate == IOState::NeedWrite) { + if (d_currentState == IOState::NeedWrite) { + return; + } + + d_currentState = IOState::NeedWrite; + cerr<<__PRETTY_FUNCTION__<<": add write FD "<addWriteFD(d_fd, callback, callbackData, ttd ? &*ttd : nullptr); + } + else if (iostate == IOState::Done) { + d_currentState = IOState::Done; + cerr<<__PRETTY_FUNCTION__<<": done"<& d_mplexer; + int d_fd; + IOState d_currentState; +}; + +class IOStateGuard +{ +public: + /* this class is using RAII to make sure we don't forget to release an IOStateHandler + from the IO multiplexer in case of exception / error handling */ + IOStateGuard(std::unique_ptr& handler): d_handler(handler), d_enabled(true) + { + } + + ~IOStateGuard() + { + /* if we are still owning the state when we go out of scope, + let's reset the state so it's not registered to the IO multiplexer anymore + and its reference count goes to zero */ + if (d_enabled && d_handler) { + cerr<<"IOStateGuard destroyed while holding a state, let's reset it"<reset(); + d_enabled = false; + } + } + + void release() + { + d_enabled = false; + } + +private: + std::unique_ptr& d_handler; + bool d_enabled; +}; diff --git a/regression-tests.dnsdist/test_TCPShort.py b/regression-tests.dnsdist/test_TCPShort.py index de5aa4cdd3..55a0f5c93d 100644 --- a/regression-tests.dnsdist/test_TCPShort.py +++ b/regression-tests.dnsdist/test_TCPShort.py @@ -71,8 +71,8 @@ class TestTCPShort(DNSDistTest): # send announcedSize bytes minus 1 so we get a second read conn.send(wire) time.sleep(1) - # send 1024 bytes - conn.send(b'A' * 1024) + # send the 
remaining byte + conn.send(b'A') (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, True) conn.close() @@ -112,8 +112,8 @@ class TestTCPShort(DNSDistTest): # send announcedSize bytes minus 1 so we get a second read conn.send(wire) time.sleep(1) - # send 1024 bytes - conn.send(b'A' * 1024) + # send the remaining byte + conn.send(b'A') (receivedQuery, receivedResponse) = self.recvTCPResponseOverConnection(conn, True) conn.close() diff --git a/regression-tests.dnsdist/test_Tags.py b/regression-tests.dnsdist/test_Tags.py index 9a50c2b299..246416f620 100644 --- a/regression-tests.dnsdist/test_Tags.py +++ b/regression-tests.dnsdist/test_Tags.py @@ -3,7 +3,7 @@ import dns import clientsubnetoption from dnsdisttests import DNSDistTest -class TestBasics(DNSDistTest): +class TestTags(DNSDistTest): _config_template = """ newServer{address="127.0.0.1:%s"}