#include <atomic>
#include <netinet/tcp.h>
+#include "sstuff.hh"
+
using std::thread;
using std::atomic;
Let's start naively.
*/
-static int setupTCPDownstream(shared_ptr<DownstreamState> ds, uint16_t& downstreamFailures)
+/* Per-thread cache of idle connections to the backends, keyed by the backend address. */
+static thread_local map<ComboAddress, std::deque<std::unique_ptr<Socket>>> t_downstreamSockets;
+static std::mutex tcpClientsCountMutex;
+static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
+uint64_t g_maxTCPQueuedConnections{1000};
+/* Maximum number of queries on one incoming connection, 0 disables the check. */
+size_t g_maxTCPQueriesPerConn{0};
+/* Maximum duration of an incoming connection (compared to tv_sec deltas), 0 disables the check. */
+size_t g_maxTCPConnectionDuration{0};
+size_t g_maxTCPConnectionsPerClient{0};
+bool g_useTCPSinglePipe{false};
+std::atomic<uint16_t> g_downstreamTCPCleanupInterval{60};
+
+/* Creates a new TCP socket to the backend ds and returns it.
+   The socket is bound to ds->sourceAddr when that address is not the "any" address,
+   switched to non-blocking mode and, unless TCP Fast Open is enabled for this backend,
+   connected using the supplied timeout.
+   Every std::runtime_error raised during the setup increments downstreamFailures and
+   triggers a new attempt, the exception is re-thrown once downstreamFailures exceeds
+   ds->retries.
+*/
+static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState> ds, uint16_t& downstreamFailures, int timeout)
{
+ std::unique_ptr<Socket> result;
+
do {
vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
- int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
+ /* assigning a new Socket on each attempt releases the one from the previous, failed attempt */
+ result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
try {
if (!IsAnyAddress(ds->sourceAddr)) {
- SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
+ SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
#ifdef IP_BIND_ADDRESS_NO_PORT
if (ds->ipBindAddrNoPort) {
- SSetsockopt(sock, SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+ SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
}
#endif
- SBind(sock, ds->sourceAddr);
+ result->bind(ds->sourceAddr, false);
}
- setNonBlocking(sock);
+ result->setNonBlocking();
#ifdef MSG_FASTOPEN
if (!ds->tcpFastOpen) {
- SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout);
+ SConnectWithTimeout(result->getHandle(), ds->remote, timeout);
}
#else
- SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout);
+ SConnectWithTimeout(result->getHandle(), ds->remote, timeout);
#endif /* MSG_FASTOPEN */
- return sock;
+ return result;
}
catch(const std::runtime_error& e) {
- /* don't leak our file descriptor if SConnect() (for example) throws */
+ vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
downstreamFailures++;
- close(sock);
if (downstreamFailures > ds->retries) {
throw;
}
}
} while(downstreamFailures <= ds->retries);
- return -1;
+ /* fallback after the retry loop */
+ return nullptr;
+}
+
+/* Returns a connection to the backend ds: the oldest idle one cached by this thread
+   when there is one (isFresh is then set to false), a newly set up one otherwise
+   (isFresh is then set to true). */
+static std::unique_ptr<Socket> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, bool& isFresh)
+{
+  const auto cached = t_downstreamSockets.find(ds->remote);
+  if (cached == t_downstreamSockets.end() || cached->second.empty()) {
+    /* nothing to reuse for this backend */
+    isFresh = true;
+    return setupTCPDownstream(ds, downstreamFailures, 0);
+  }
+
+  auto& idle = cached->second;
+  std::unique_ptr<Socket> reused = std::move(idle.front());
+  idle.pop_front();
+  isFresh = false;
+  return reused;
+}
+
+/* Gives a connection to the backend ds back to the per-thread cache so that it can be
+   reused later, unless 20 connections to that backend are already cached, in which
+   case the connection is destroyed right away. */
+static void releaseDownstreamConnection(std::shared_ptr<DownstreamState>& ds, std::unique_ptr<Socket>&& socket)
+{
+  /* operator[] creates an empty queue the first time a backend is seen */
+  auto& idle = t_downstreamSockets[ds->remote];
+  if (idle.size() < 20) {
+    idle.push_back(std::move(socket));
+    return;
+  }
+
+  /* too many connections queued already */
+  socket.reset();
}
struct ConnectionInfo
ConnectionInfo(): cs(nullptr), fd(-1)
{
}
+ /* Move constructor: takes over the client address, the frontend pointer and the
+    descriptor of rhs, leaving rhs with a null frontend and an invalid (-1) descriptor.
+    Only plain assignments are performed, hence the noexcept guarantee, which lets
+    standard containers move these objects instead of refusing or copying them. */
+ ConnectionInfo(ConnectionInfo&& rhs) noexcept
+ {
+   remote = rhs.remote;
+   cs = rhs.cs;
+   rhs.cs = nullptr;
+   fd = rhs.fd;
+   rhs.fd = -1;
+ }
ConnectionInfo(const ConnectionInfo& rhs) = delete;
ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
int fd{-1};
};
-uint64_t g_maxTCPQueuedConnections{1000};
-size_t g_maxTCPQueriesPerConn{0};
-size_t g_maxTCPConnectionDuration{0};
-size_t g_maxTCPConnectionsPerClient{0};
-static std::mutex tcpClientsCountMutex;
-static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
-bool g_useTCPSinglePipe{false};
-std::atomic<uint16_t> g_downstreamTCPCleanupInterval{60};
-
void tcpClientThread(int pipefd);
static void decrementTCPClientCount(const ComboAddress& client)
++d_numthreads;
}
-static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
-try
+/* Walks the per-thread cache of idle backend connections, drops every entry that is
+   empty or for which isTCPSocketUsable() returns false, then removes the backends
+   left without any cached connection. */
+static void cleanupClosedTCPConnections()
{
- uint16_t raw;
- size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
- if(ret != sizeof raw)
- return false;
- *len = ntohs(raw);
- return true;
-}
-catch(...) {
- return false;
-}
+ for(auto dsIt = t_downstreamSockets.begin(); dsIt != t_downstreamSockets.end(); ) {
+ for (auto socketIt = dsIt->second.begin(); socketIt != dsIt->second.end(); ) {
+ if (*socketIt && isTCPSocketUsable((*socketIt)->getHandle())) {
+ ++socketIt;
+ }
+ else {
+ /* erase() hands back the next valid iterator */
+ socketIt = dsIt->second.erase(socketIt);
+ }
+ }
-static bool getNonBlockingMsgLenFromClient(TCPIOHandler& handler, uint16_t* len)
-try
-{
- uint16_t raw;
- size_t ret = handler.read(&raw, sizeof raw, g_tcpRecvTimeout);
- if(ret != sizeof raw)
- return false;
- *len = ntohs(raw);
- return true;
-}
-catch(...) {
- return false;
+ if (!dsIt->second.empty()) {
+ ++dsIt;
+ }
+ else {
+ dsIt = t_downstreamSockets.erase(dsIt);
+ }
+ }
}
-static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
+/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
+   Updates pos every time a successful read occurs,
+   throws an std::runtime_error in case of IO error (EOF included),
+   throws an std::out_of_range if the buffer cannot hold toRead bytes starting at pos,
+   returns Done when toRead bytes have been read (right away when toRead is 0),
+   NeedRead if the read operation would block.
+*/
+// XXX could probably be implemented as a TCPIOHandler
+IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
{
- if (maxConnectionDuration) {
- time_t curtime = time(nullptr);
- unsigned int elapsed = 0;
- if (curtime > start) { // To prevent issues when time goes backward
- elapsed = curtime - start;
+  if (toRead == 0) {
+    /* nothing to do, and a zero-length read() would be mistaken for an EOF below */
+    return IOState::Done;
+  }
+
+  /* at() below only validates the first byte, so check the whole range up front to make
+     sure read() can never write past the end of the buffer. The comparison is written
+     this way to be safe against integer overflow. */
+  if (pos > buffer.size() || toRead > (buffer.size() - pos)) {
+    throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
+  }
+
+  size_t got = 0;
+  do {
+    ssize_t res = ::read(fd, reinterpret_cast<char*>(&buffer.at(pos)), toRead - got);
+    if (res == 0) {
+      throw runtime_error("EOF while reading message");
}
- if (elapsed >= maxConnectionDuration) {
- return true;
+    if (res < 0) {
+      if (errno == EAGAIN || errno == EWOULDBLOCK) {
+        /* pos already accounts for what has been read so far, the caller can resume later */
+        return IOState::NeedRead;
+      }
+      else {
+        throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno));
+      }
}
- remainingTime = maxConnectionDuration - elapsed;
+
+    pos += static_cast<size_t>(res);
+    got += static_cast<size_t>(res);
}
- return false;
+  while (got < toRead);
+
+  return IOState::Done;
}
-static void cleanupClosedTCPConnections(std::map<ComboAddress,int>& sockets)
+std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
+
+/* State owned by one TCP worker thread and shared by all the connections it handles:
+   local views of the rules and the multiplexer used to wait for IO events. */
+class TCPClientThreadData
+{
+public:
+  TCPClientThreadData():
+    localRespRulactions(g_resprulactions.getLocal()),
+    mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+  {
+  }
+
+  LocalHolders holders;
+  LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
+  /* always set by the constructor */
+  std::unique_ptr<FDMultiplexer> mplexer;
+};
+
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+
+class IncomingTCPConnectionState
{
- for(auto it = sockets.begin(); it != sockets.end(); ) {
- if (isTCPSocketUsable(it->second)) {
- ++it;
+public:
+ /* Takes ownership of the incoming connection ci and records the address the client
+    connected to: the one reported by getsockname(), or the address of the frontend
+    when that call fails. */
+ IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, time_t now):
+   d_buffer(4096),
+   d_responseBuffer(4096),
+   d_threadData(threadData),
+   d_ci(std::move(ci)),
+   d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now),
+   d_connectionStartTime(now)
+ {
+   auto& dest = d_ids.origDest;
+   dest.reset();
+   dest.sin4.sin_family = d_ci.remote.sin4.sin_family;
+
+   socklen_t destLen = dest.getSocklen();
+   const int rc = getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&dest), &destLen);
+   if (rc != 0) {
+     dest = d_ci.cs->local;
+   }
+ }
+
+ IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
+ IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
+
+ /* Releases what the connection still holds: the per-client connection count, the
+    outstanding counter of the backend if a query is still in flight, and any pending
+    registration of the backend or client descriptor in the multiplexer.
+    The multiplexer exceptions are caught and logged so that nothing escapes the destructor. */
+ ~IncomingTCPConnectionState()
+ {
+   decrementTCPClientCount(d_ci.remote);
+
+   if (d_ds != nullptr) {
+     if (d_outstanding) {
+       --d_ds->outstanding;
+     }
+
+     if (d_downstreamSocket) {
+       try {
+         if (d_lastIOState == IOState::NeedRead) {
+           //cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamSocket->getHandle()<<endl;
+           d_threadData.mplexer->removeReadFD(d_downstreamSocket->getHandle());
+         }
+         else if (d_lastIOState == IOState::NeedWrite) {
+           //cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamSocket->getHandle()<<endl;
+           d_threadData.mplexer->removeWriteFD(d_downstreamSocket->getHandle());
+         }
+       }
+       catch(const FDMultiplexerException& e) {
+         vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
+       }
+     }
+   }
+
+   /* a single d_lastIOState is tracked for both descriptors, so the removal is attempted
+      on the client descriptor as well, a failure is only logged */
+   try {
+     if (d_lastIOState == IOState::NeedRead) {
+       //cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
+       d_threadData.mplexer->removeReadFD(d_ci.fd);
+     }
+     else if (d_lastIOState == IOState::NeedWrite) {
+       //cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
+       d_threadData.mplexer->removeWriteFD(d_ci.fd);
+     }
+   }
+   catch(const FDMultiplexerException& e) {
+     vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+   }
+ }
+
+ /* Puts the connection back into the state where it waits for the two bytes holding
+    the size of the next query. */
+ void resetForNewQuery()
+ {
+   /* the next thing to read is a 16-bit query size */
+   d_buffer.resize(sizeof(uint16_t));
+
+   d_state = State::readingQuerySize;
+   d_lastIOState = IOState::Done;
+
+   d_currentPos = 0;
+   d_querySize = 0;
+   d_responseSize = 0;
+   d_downstreamFailures = 0;
+ }
+
+ /* Computes the point in time at which a read from the client should be given up on.
+    Returns boost::none when neither a maximum connection duration nor a receive timeout
+    is configured, now itself when the maximum duration has been reached (or the clock
+    went backward), and otherwise now plus whichever is smaller between the remaining
+    connection time and the receive timeout. */
+ boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
+ {
+   const bool durationLimited = g_maxTCPConnectionDuration > 0;
+   if (!durationLimited && g_tcpRecvTimeout == 0) {
+     return boost::none;
+   }
+
+   if (durationLimited) {
+     const auto sinceStart = now.tv_sec - d_connectionStartTime;
+     if (sinceStart < 0 || static_cast<size_t>(sinceStart) >= g_maxTCPConnectionDuration) {
+       return now;
+     }
+
+     const auto left = g_maxTCPConnectionDuration - sinceStart;
+     const bool recvTimeoutIsSooner = g_tcpRecvTimeout != 0 && left > static_cast<size_t>(g_tcpRecvTimeout);
+     if (!recvTimeoutIsSooner) {
+       now.tv_sec += left;
+       return now;
+     }
+   }
+
+   now.tv_sec += g_tcpRecvTimeout;
+   return now;
+ }
+
+ /* Computes the point in time at which a read from the selected backend should be given
+    up on: the current time plus the receive timeout of the backend, or boost::none when
+    that timeout is 0. Throws an std::runtime_error when no backend has been selected. */
+ boost::optional<struct timeval> getBackendReadTTD() const
+ {
+   if (!d_ds) {
+     throw std::runtime_error("getBackendReadTTD() without any backend selected");
+   }
+
+   const auto recvTimeout = d_ds->tcpRecvTimeout;
+   if (recvTimeout == 0) {
+     return boost::none;
+   }
+
+   struct timeval deadline;
+   gettimeofday(&deadline, nullptr);
+   deadline.tv_sec += recvTimeout;
+   return deadline;
+ }
+
+ /* Computes the point in time at which a write to the client should be given up on,
+    starting from now when it is supplied and from gettimeofday() otherwise.
+    Returns boost::none when neither a maximum connection duration nor a send timeout is
+    configured, the starting time itself when the maximum duration has been reached (or
+    the clock went backward), and otherwise the starting time plus whichever is smaller
+    between the remaining connection time and the send timeout. */
+ boost::optional<struct timeval> getClientWriteTTD(boost::optional<struct timeval> now=boost::none) const
+ {
+ if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
+ return boost::none;
+ }
+
+ struct timeval res;
+ if (now) {
+ res = *now;
}
else {
- close(it->second);
- it = sockets.erase(it);
+ gettimeofday(&res, 0);
+ }
+
+ if (g_maxTCPConnectionDuration > 0) {
+ auto elapsed = res.tv_sec - d_connectionStartTime;
+ if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
+ return res;
+ }
+ auto remaining = g_maxTCPConnectionDuration - elapsed;
+ if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
+ res.tv_sec += remaining;
+ return res;
+ }
+ }
+
+ res.tv_sec += g_tcpSendTimeout;
+ return res;
+ }
+
+ /* Computes the point in time at which a write to the selected backend should be given
+    up on: the current time plus the send timeout of the backend, or boost::none when
+    that timeout is 0. Throws an std::runtime_error when no backend has been selected. */
+ boost::optional<struct timeval> getBackendWriteTTD() const
+ {
+   if (d_ds == nullptr) {
+     /* name this function, not its read counterpart, so that the error can be traced */
+     throw std::runtime_error("getBackendWriteTTD() called without any backend selected");
+   }
+   if (d_ds->tcpSendTimeout == 0) {
+     return boost::none;
}
+
+   struct timeval res;
+   gettimeofday(&res, 0);
+
+   res.tv_sec += d_ds->tcpSendTimeout;
+
+   return res;
}
+
+ /* Tells whether the connection has been open for at least maxConnectionDuration seconds
+    at the time now. A duration of 0 means no limit. When the limit has not been reached,
+    the number of seconds left is stored into d_remainingTime. */
+ bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval now)
+ {
+   if (maxConnectionDuration == 0) {
+     return false;
+   }
+
+   const time_t current = now.tv_sec;
+   /* stays at 0 if the clock went backward since the connection started */
+   const unsigned int elapsed = (current > d_connectionStartTime) ? (current - d_connectionStartTime) : 0;
+   if (elapsed >= maxConnectionDuration) {
+     return true;
+   }
+
+   d_remainingTime = maxConnectionDuration - elapsed;
+   return false;
+ }
+
+ /* Step of the processing the connection is currently in. */
+ enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };
+
+ /* holds the query, prefixed with its size once it is ready to be sent to a backend */
+ std::vector<uint8_t> d_buffer;
+ /* holds the response, prefixed with its size once it is ready to be sent to the client */
+ std::vector<uint8_t> d_responseBuffer;
+ TCPClientThreadData& d_threadData;
+ IDState d_ids;
+ ConnectionInfo d_ci;
+ TCPIOHandler d_handler;
+ std::unique_ptr<Socket> d_downstreamSocket{nullptr};
+ std::shared_ptr<DownstreamState> d_ds{nullptr};
+ /* position reached in the buffer currently being read or written */
+ size_t d_currentPos{0};
+ size_t d_queriesCount{0};
+ time_t d_connectionStartTime;
+ /* updated by maxConnectionDurationReached() */
+ unsigned int d_remainingTime{0};
+ uint16_t d_querySize{0};
+ uint16_t d_responseSize{0};
+ uint16_t d_downstreamFailures{0};
+ State d_state{State::doingHandshake};
+ /* last IO state recorded by handleNewIOState(), also used by the destructor to clean up */
+ IOState d_lastIOState{IOState::Done};
+ bool d_freshDownstreamConnection{false};
+ bool d_readingFirstQuery{true};
+ /* true while the outstanding counter of the backend has been incremented for this query */
+ bool d_outstanding{false};
+ bool d_firstResponsePacket{true};
+ bool d_isXFR{false};
+ bool d_xfrStarted{false};
+};
+
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
+
+/* Called once a response has been completely written to the client.
+   Marks the pending IO on the client descriptor as done, then:
+   - for a XFR with a backend socket still attached, goes back to reading the size of
+     the next response from the backend,
+   - if the maximum number of queries per connection or the maximum connection duration
+     has been reached, returns without registering any new IO event,
+     NOTE(review): presumably the connection goes away once the last reference to state
+     is dropped, confirm with the callers,
+   - otherwise resets the state and waits for the next query from the client. */
+static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state)
+{
+ handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);
+
+ if (state->d_isXFR && state->d_downstreamSocket) {
+ /* we need to resume reading from the backend! */
+ state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+ state->d_currentPos = 0;
+ //cerr<<__func__<<": add read client FD "<<state->d_ci.fd<<endl;
+ handleNewIOState(state, IOState::NeedRead, state->d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendReadTTD());
+ return;
+ }
+
+ if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
+ return;
+ }
+
+ struct timeval now;
+ gettimeofday(&now, 0);
+ if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+ return;
+ }
+
+ state->resetForNewQuery();
+ //cerr<<__func__<<": add read client FD "<<state->d_ci.fd<<endl;
+ handleNewIOState(state, IOState::NeedRead, state->d_ci.fd, handleIOCallback, state->getClientReadTTD(now));
}
-std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
+/* Sends the content of d_responseBuffer to the client, prefixed with d_responseSize
+   encoded as two bytes, most significant byte first.
+   If the write completes right away the post-response handling is done immediately,
+   otherwise the client descriptor is registered for writing. */
+static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
+{
+ state->d_state = IncomingTCPConnectionState::State::sendingResponse;
+ const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
+ /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+ that could occur if we had to deal with the size during the processing,
+ especially alignment issues */
+ state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);
-void tcpClientThread(int pipefd)
+ state->d_currentPos = 0;
+
+ auto iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
+ if (iostate == IOState::Done) {
+
+ handleResponseSent(state);
+ return;
+ }
+ else {
+ //cerr<<__func__<<": adding client write FD "<<state->d_ci.fd<<endl;
+ handleNewIOState(state, IOState::NeedWrite, state->d_ci.fd, handleIOCallback, state->getClientWriteTTD());
+ }
+}
+
+/* Processes a response received from the backend into d_responseBuffer and forwards it
+   to the client, then updates the statistics and the rings.
+   Returns without sending anything when the response is smaller than a DNS header, when
+   the first packet of the response does not match the query, or when processResponse()
+   returns false. */
+static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state)
{
- /* we get launched with a pipe on which we receive file descriptors from clients that we own
- from that point on */
+ if (state->d_responseSize < sizeof(dnsheader)) {
+ return;
+ }
- setThreadName("dnsdist/tcpClie");
+ auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
+ unsigned int consumed;
+ if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
+ return;
+ }
+ state->d_firstResponsePacket = false;
+
+ if (state->d_outstanding) {
+ --state->d_ds->outstanding;
+ state->d_outstanding = false;
+ }
+
+ auto dh = reinterpret_cast<struct dnsheader*>(response);
+ uint16_t addRoom = 0;
+ DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
+ if (dr.dnsCryptQuery) {
+ addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+ }
- bool outstanding = false;
- time_t lastTCPCleanup = time(nullptr);
-
- LocalHolders holders;
- auto localRespRulactions = g_resprulactions.getLocal();
- /* when the answer is encrypted in place, we need to get a copy
- of the original header before encryption to fill the ring buffer */
dnsheader cleartextDH;
+ /* copy of the header taken before processResponse() runs, used for the rings below */
+ memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
- map<ComboAddress,int> sockets;
- for(;;) {
- ConnectionInfo* citmp, ci;
+ std::vector<uint8_t> rewrittenResponse;
+ size_t responseSize = state->d_responseBuffer.size();
+ if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
+ return;
+ }
- try {
- readn2(pipefd, &citmp, sizeof(citmp));
- }
- catch(const std::runtime_error& e) {
- throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
- }
+ if (!rewrittenResponse.empty()) {
+ /* responseSize has been updated as well but we don't really care since it will match
+ the capacity of rewrittenResponse anyway */
+ state->d_responseBuffer = std::move(rewrittenResponse);
+ state->d_responseSize = state->d_responseBuffer.size();
+ } else {
+ /* the size might have been updated (shrinked) if we removed the whole OPT RR, for example) */
+ state->d_responseBuffer.resize(state->d_responseSize);
+ }
+
+ if (state->d_isXFR && !state->d_xfrStarted) {
+ /* don't bother parsing the content of the response for now */
+ state->d_xfrStarted = true;
+ }
- g_tcpclientthreads->decrementQueuedCount();
- ci=std::move(*citmp);
- delete citmp;
+ sendResponse(state);
+
+ ++g_stats.responses;
+ struct timespec answertime;
+ gettime(&answertime);
+ double udiff = state->d_ids.sentTime.udiff();
+ g_rings.insertResponse(answertime, state->d_ci.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), cleartextDH, state->d_ds->remote);
+}
+
+/* Gets a connection to the selected backend and registers it for writing, the query
+   itself is written by handleDownstreamIOCallback().
+   Any downstream socket previously attached to the state is released first.
+   Nothing is done if a XFR has already started being relayed to the client.
+   Note that every path through the body of the loop returns, so it runs at most once. */
+static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state)
+{
+ auto ds = state->d_ds;
+ state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
+ state->d_currentPos = 0;
+ state->d_firstResponsePacket = true;
+ state->d_downstreamSocket.reset();
+
+ if (state->d_xfrStarted) {
+ /* sorry, but we are not going to resume a XFR if we have already sent some packets
+ to the client */
+ return;
+ }
- uint16_t qlen, rlen;
- vector<uint8_t> rewrittenResponse;
- shared_ptr<DownstreamState> ds;
- size_t queriesCount = 0;
- time_t connectionStartTime = time(nullptr);
- std::vector<char> queryBuffer;
- std::vector<char> answerBuffer;
+ while (state->d_downstreamFailures < state->d_ds->retries)
+ {
+ state->d_downstreamSocket = getConnectionToDownstream(ds, state->d_downstreamFailures, state->d_freshDownstreamConnection);
- ComboAddress dest;
- dest.reset();
- dest.sin4.sin_family = ci.remote.sin4.sin_family;
- socklen_t socklen = dest.getSocklen();
- if (getsockname(ci.fd, (sockaddr*)&dest, &socklen)) {
- dest = ci.cs->local;
+ if (!state->d_downstreamSocket) {
+ vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
+ return;
}
- try {
- TCPIOHandler handler(ci.fd, g_tcpRecvTimeout, ci.cs->tlsFrontend ? ci.cs->tlsFrontend->getContext() : nullptr, connectionStartTime);
+ //cerr<<__func__<<": add write backend FD "<<state->d_downstreamSocket->getHandle()<<endl;
+ handleNewIOState(state, IOState::NeedWrite, state->d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendWriteTTD());
+ return;
+ }
- for(;;) {
- unsigned int remainingTime = 0;
- ds = nullptr;
- outstanding = false;
+ vinfolog("Downstream connection to %s failed %u times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
+}
- if(!getNonBlockingMsgLenFromClient(handler, &qlen)) {
- break;
- }
+/* Processes a query that has been completely read from the client into d_buffer.
+   Depending on the outcome the query is answered right away (DNSCrypt response or
+   answer produced by processQuery()), passed to the selected backend after having been
+   prefixed with its size, or ignored, in which case the function returns without
+   registering any new IO event. */
+static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state)
+{
+ if (state->d_querySize < sizeof(dnsheader)) {
+ ++g_stats.nonCompliantQueries;
+ return;
+ }
- queriesCount++;
+ state->d_readingFirstQuery = false;
+ ++state->d_queriesCount;
+ ++state->d_ci.cs->queries;
+ ++g_stats.queries;
+
+ /* we need an accurate ("real") value for the response and
+ to store into the IDS, but not for insertion into the
+ rings for example */
+ /* NOTE(review): now is filled below but not read anywhere else in this function */
+ struct timespec now;
+ struct timespec queryRealTime;
+ gettime(&now);
+ gettime(&queryRealTime, true);
+
+ auto query = reinterpret_cast<char*>(&state->d_buffer.at(0));
+ std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
+ auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
+ if (dnsCryptResponse) {
+ state->d_responseBuffer = std::move(*dnsCryptResponse);
+ state->d_responseSize = state->d_responseBuffer.size();
+ sendResponse(state);
+ return;
+ }
- if (qlen < sizeof(dnsheader)) {
- ++g_stats.nonCompliantQueries;
- break;
- }
+ const auto& dh = reinterpret_cast<dnsheader*>(query);
+ if (!checkQueryHeaders(dh)) {
+ return;
+ }
- ci.cs->queries++;
- ++g_stats.queries;
+ uint16_t qtype, qclass;
+ unsigned int consumed = 0;
+ DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
+ DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
+ dq.dnsCryptQuery = std::move(dnsCryptQuery);
- if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
- vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
- break;
- }
+ state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
+ if (state->d_isXFR) {
+ dq.skipCache = true;
+ }
- if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) {
- vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort());
- break;
- }
+ state->d_ds.reset();
+ auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);
- /* allocate a bit more memory to be able to spoof the content,
- or to add ECS without allocating a new buffer */
- queryBuffer.resize((static_cast<size_t>(qlen) + 512) < 4096 ? (static_cast<size_t>(qlen) + 512) : 4096);
-
- char* query = &queryBuffer[0];
- handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
-
- /* we need an accurate ("real") value for the response and
- to store into the IDS, but not for insertion into the
- rings for example */
- struct timespec now;
- struct timespec queryRealTime;
- gettime(&now);
- gettime(&queryRealTime, true);
-
- std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
- auto dnsCryptResponse = checkDNSCryptQuery(*ci.cs, query, qlen, dnsCryptQuery, queryRealTime.tv_sec, true);
- if (dnsCryptResponse) {
- handler.writeSizeAndMsg(reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), g_tcpSendTimeout);
- continue;
- }
+ if (result == ProcessQueryResult::Drop) {
+ return;
+ }
- struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
- if (!checkQueryHeaders(dh)) {
- break;
- }
+ if (result == ProcessQueryResult::SendAnswer) {
+ state->d_buffer.resize(dq.len);
+ state->d_responseBuffer = std::move(state->d_buffer);
+ state->d_responseSize = state->d_responseBuffer.size();
+ sendResponse(state);
+ return;
+ }
- uint16_t qtype, qclass;
- unsigned int consumed = 0;
- DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
- DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime);
- dq.dnsCryptQuery = std::move(dnsCryptQuery);
+ if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
+ return;
+ }
- std::shared_ptr<DownstreamState> ds{nullptr};
- auto result = processQuery(dq, *ci.cs, holders, ds);
+ state->d_buffer.resize(dq.len);
+ setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));
- if (result == ProcessQueryResult::Drop) {
- break;
- }
+ const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
+ /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+ that could occur if we had to deal with the size during the processing,
+ especially alignment issues */
+ state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2);
+ sendQueryToBackend(state);
+}
- if (result == ProcessQueryResult::SendAnswer) {
- handler.writeSizeAndMsg(reinterpret_cast<char*>(dq.dh), dq.len, g_tcpSendTimeout);
- continue;
- }
+/* Brings the registration of fd in the multiplexer of the thread in line with iostate.
+   A read or write registration recorded in d_lastIOState that no longer matches iostate
+   is removed first. Then, for NeedRead or NeedWrite, fd is registered with callback
+   and the optional ttd, unless the same kind of registration is already recorded, in
+   which case only the read TTD is refreshed. Done simply records that no IO is pending. */
+static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
+{
+ //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;
- if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
- break;
- }
+ if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
+ state->d_threadData.mplexer->removeReadFD(fd);
+ //cerr<<__func__<<": remove read FD "<<fd<<endl;
+ state->d_lastIOState = IOState::Done;
+ }
+ else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
+ state->d_threadData.mplexer->removeWriteFD(fd);
+ //cerr<<__func__<<": remove write FD "<<fd<<endl;
+ state->d_lastIOState = IOState::Done;
+ }
- int dsock = -1;
- uint16_t downstreamFailures=0;
-#ifdef MSG_FASTOPEN
- bool freshConn = true;
-#endif /* MSG_FASTOPEN */
- if(sockets.count(ds->remote) == 0) {
- dsock=setupTCPDownstream(ds, downstreamFailures);
- sockets[ds->remote]=dsock;
- }
- else {
- dsock=sockets[ds->remote];
-#ifdef MSG_FASTOPEN
- freshConn = false;
-#endif /* MSG_FASTOPEN */
- }
+ if (iostate == IOState::NeedRead) {
+ if (state->d_lastIOState == IOState::NeedRead) {
+ if (ttd) {
+ /* let's update the TTD ! */
+ state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
+ }
+ return;
+ }
- ds->outstanding++;
- outstanding = true;
+ state->d_lastIOState = IOState::NeedRead;
+ //cerr<<__func__<<": add read FD "<<fd<<endl;
+ state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
+ }
+ else if (iostate == IOState::NeedWrite) {
+ if (state->d_lastIOState == IOState::NeedWrite) {
+ return;
+ }
- retry:;
- if (dsock < 0) {
- sockets.erase(ds->remote);
- break;
- }
+ state->d_lastIOState = IOState::NeedWrite;
+ //cerr<<__func__<<": add write FD "<<fd<<endl;
+ state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
+ }
+ else if (iostate == IOState::Done) {
+ state->d_lastIOState = IOState::Done;
+ }
+}
- if (ds->retries > 0 && downstreamFailures > ds->retries) {
- vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstreamFailures);
- close(dsock);
- dsock=-1;
- sockets.erase(ds->remote);
- break;
- }
+static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+ auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+ if (state->d_downstreamSocket == nullptr) {
+ throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+ }
+ if (fd != state->d_downstreamSocket->getHandle()) {
+ throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamSocket->getHandle()));
+ }
- try {
- int socketFlags = 0;
-#ifdef MSG_FASTOPEN
- if (ds->tcpFastOpen && freshConn) {
- socketFlags |= MSG_FASTOPEN;
- }
-#endif /* MSG_FASTOPEN */
- sendSizeAndMsgWithTimeout(dsock, dq.len, query, ds->tcpSendTimeout, &ds->remote, &ds->sourceAddr, ds->sourceItf, 0, socketFlags);
- }
- catch(const runtime_error& e) {
- vinfolog("Downstream connection to %s died on us (%s), getting a new one!", ds->getName(), e.what());
- close(dsock);
- dsock=-1;
- sockets.erase(ds->remote);
- downstreamFailures++;
- dsock=setupTCPDownstream(ds, downstreamFailures);
- sockets[ds->remote]=dsock;
-#ifdef MSG_FASTOPEN
- freshConn=true;
-#endif /* MSG_FASTOPEN */
- goto retry;
- }
+ IOState iostate = IOState::Done;
+ bool connectionDied = false;
- bool xfrStarted = false;
- bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
- if (isXFR) {
- dq.skipCache = true;
- }
- bool firstPacket=true;
- getpacket:;
-
- if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
- vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
- close(dsock);
- dsock=-1;
- sockets.erase(ds->remote);
- downstreamFailures++;
- dsock=setupTCPDownstream(ds, downstreamFailures);
- sockets[ds->remote]=dsock;
+ try {
+ if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
+ int socketFlags = 0;
#ifdef MSG_FASTOPEN
- freshConn=true;
+ if (state->d_ds->tcpFastOpen && state->d_freshDownstreamConnection) {
+ socketFlags |= MSG_FASTOPEN;
+ }
#endif /* MSG_FASTOPEN */
- if(xfrStarted) {
- break;
- }
- goto retry;
- }
- size_t responseSize = rlen;
- uint16_t addRoom = 0;
- if (dq.dnsCryptQuery && (UINT16_MAX - rlen) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) {
- addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+ size_t sent = sendMsgWithTimeout(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, 0, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, 0, socketFlags);
+ if (sent == state->d_buffer.size()) {
+ /* request sent ! */
+ state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+ state->d_currentPos = 0;
+ iostate = IOState::NeedRead;
+ if (!state->d_isXFR) {
+ /* don't bother with the outstanding count for XFR queries */
+ ++state->d_ds->outstanding;
+ state->d_outstanding = true;
}
+ }
+ else {
+ state->d_currentPos += sent;
+ iostate = IOState::NeedWrite;
+ /* disable fast open on partial write */
+ state->d_freshDownstreamConnection = false;
+ }
+ }
- responseSize += addRoom;
- answerBuffer.resize(responseSize);
- char* response = answerBuffer.data();
- readn2WithTimeout(dsock, response, rlen, ds->tcpRecvTimeout);
- uint16_t responseLen = rlen;
- if (outstanding) {
- /* might be false for {A,I}XFR */
- --ds->outstanding;
- outstanding = false;
- }
+ if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
+ // then we need to allocate a new buffer (new because we might need to re-send the query if the
+ // backend dies on us)
+ // We also might need to read and send to the client more than one response in case of XFR (yeah!)
+ // should very likely be a TCPIOHandler d_downstreamHandler
+ iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
+ if (iostate == IOState::Done) {
+ state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
+ state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
+ state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
+ state->d_currentPos = 0;
+ }
+ }
- if (rlen < sizeof(dnsheader)) {
- break;
- }
+ if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
+ iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
+ if (iostate == IOState::Done) {
+ handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);
- consumed = 0;
- if (firstPacket && !responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote, consumed)) {
- break;
+ if (state->d_isXFR) {
+ /* Don't reuse the TCP connection after an {A,I}XFR */
+ /* but don't reset it either, we will need to read more messages */
}
- firstPacket=false;
-
- dh = reinterpret_cast<struct dnsheader*>(response);
- DNSResponse dr(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
- dr.origFlags = dq.origFlags;
- dr.ecsAdded = dq.ecsAdded;
- dr.ednsAdded = dq.ednsAdded;
- dr.useZeroScope = dq.useZeroScope;
- dr.packetCache = std::move(dq.packetCache);
- dr.delayMsec = dq.delayMsec;
- dr.skipCache = dq.skipCache;
- dr.cacheKey = dq.cacheKey;
- dr.cacheKeyNoECS = dq.cacheKeyNoECS;
- dr.dnssecOK = dq.dnssecOK;
- dr.tempFailureTTL = dq.tempFailureTTL;
- dr.qTag = std::move(dq.qTag);
- dr.subnet = std::move(dq.subnet);
-#ifdef HAVE_PROTOBUF
- dr.uniqueId = std::move(dq.uniqueId);
-#endif
- if (dq.dnsCryptQuery) {
- dr.dnsCryptQuery = std::move(dq.dnsCryptQuery);
+ else {
+ releaseDownstreamConnection(state->d_ds, std::move(state->d_downstreamSocket));
}
+ fd = -1;
- memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
- if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
- break;
- }
+ handleResponse(state);
+ return;
+ }
+ }
- if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) {
- break;
- }
+ if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
+ state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
+ state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
+ vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
+ }
+ }
+ catch(const std::exception& e) {
+ /* most likely an EOF because the other end closed the connection,
+ but it might also be a real IO error or something else.
+ Let's just drop the connection
+ */
+ vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
+ /* remove this FD from the IO multiplexer */
+ ++state->d_downstreamFailures;
+ if (state->d_outstanding && state->d_ds != nullptr) {
+ --state->d_ds->outstanding;
+ }
+ iostate = IOState::Done;
+ connectionDied = true;
+ }
- if (isXFR) {
- if (dh->rcode == 0 && dh->ancount != 0) {
- if (xfrStarted == false) {
- xfrStarted = true;
- if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
- goto getpacket;
- }
- }
- else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
- goto getpacket;
- }
- }
- /* Don't reuse the TCP connection after an {A,I}XFR */
- close(dsock);
- dsock=-1;
- sockets.erase(ds->remote);
- }
+ if (iostate == IOState::Done) {
+ handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
+ }
+ else {
+ handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD() : state->getBackendWriteTTD());
+ }
+
+ if (connectionDied) {
+ sendQueryToBackend(state);
+ }
+}
+
+static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+ auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+ if (fd != state->d_ci.fd) {
+ throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
+ }
+
+ IOState iostate = IOState::Done;
+
+ struct timeval now;
+ gettimeofday(&now, 0);
+ if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+ handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+ return;
+ }
- ++g_stats.responses;
- switch (dr.dh->rcode) {
- case RCode::NXDomain:
- ++g_stats.frontendNXDomain;
- break;
- case RCode::ServFail:
- ++g_stats.frontendServFail;
- break;
- case RCode::NoError:
- ++g_stats.frontendNoError;
- break;
+ try {
+ if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
+ iostate = state->d_handler.tryHandshake();
+ if (iostate == IOState::Done) {
+ state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+ }
+ }
+
+ if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
+ iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
+ if (iostate == IOState::Done) {
+ state->d_state = IncomingTCPConnectionState::State::readingQuery;
+ state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
+ if (state->d_querySize < sizeof(dnsheader)) {
+ /* go away */
+ handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+ return;
}
- struct timespec answertime;
- gettime(&answertime);
- unsigned int udiff = 1000000.0*DiffTime(now,answertime);
- g_rings.insertResponse(answertime, ci.remote, qname, dq.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(responseLen), cleartextDH, ds->remote);
- rewrittenResponse.clear();
+ /* allocate a bit more memory to be able to spoof the content,
+ or to add ECS without allocating a new buffer */
+ state->d_buffer.resize(state->d_querySize + 512);
+ state->d_currentPos = 0;
+ }
+ }
+
+ if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+ iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
+ if (iostate == IOState::Done) {
+ handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+ handleQuery(state);
+ return;
}
}
- catch(const std::exception& e) {
- vinfolog("Got exception while handling TCP query: %s", e.what());
+
+ if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+ iostate = state->d_handler.tryWrite(state->d_buffer, state->d_currentPos, state->d_buffer.size());
+ if (iostate == IOState::Done) {
+ handleResponseSent(state);
+ return;
+ }
+ }
+
+ if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
+ state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
+ state->d_state != IncomingTCPConnectionState::State::readingQuery &&
+ state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
+ vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
}
- catch(...) {
+ }
+ catch(const std::exception& e) {
+ /* most likely an EOF because the other end closed the connection,
+ but it might also be a real IO error or something else.
+ Let's just drop the connection
+ */
+ if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
+ vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
}
+ else {
+ vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
+ }
+ /* remove this FD from the IO multiplexer */
+ iostate = IOState::Done;
+ }
+
+ if (iostate == IOState::Done) {
+ handleNewIOState(state, iostate, fd, handleIOCallback);
+ }
+ else {
+ handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
+ }
+}
+
+static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
+{
+ auto threadData = boost::any_cast<TCPClientThreadData*>(param);
+
+ ConnectionInfo* citmp{nullptr};
+
+ try {
+ readn2(pipefd, &citmp, sizeof(citmp));
+ }
+ catch(const std::runtime_error& e) {
+ throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
+ }
+
+ g_tcpclientthreads->decrementQueuedCount();
+ auto ci = std::move(*citmp);
+ delete citmp;
+ citmp = nullptr;
+
+ struct timeval now;
+ gettimeofday(&now, 0);
+ auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now.tv_sec);
- vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
+ /* let's update the remaining time */
+ state->d_remainingTime = g_maxTCPConnectionDuration;
- if (ds && outstanding) {
- outstanding = false;
- --ds->outstanding;
+ /* we could try reading right away, but let's not for now */
+ handleNewIOState(state, IOState::NeedRead, state->d_ci.fd, handleIOCallback, state->getClientReadTTD(now));
+}
+
+void tcpClientThread(int pipefd)
+{
+ /* we get launched with a pipe on which we receive pointers to ConnectionInfo objects (each holding
+ the file descriptor of a client connection) that we own from that point on */
+
+ setThreadName("dnsdist/tcpClie");
+
+ TCPClientThreadData data;
+
+ data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data);
+ time_t lastTCPCleanup = time(nullptr);
+ struct timeval now;
+ gettimeofday(&now, 0);
+
+ for (;;) {
+ data.mplexer->run(&now);
+
+ if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
+ cleanupClosedTCPConnections();
+ lastTCPCleanup = now.tv_sec;
+ }
+
+ auto expiredReadConns = data.mplexer->getTimeouts(now, false);
+ for(const auto& conn : expiredReadConns) {
+ auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
+ if (conn.first == state->d_ci.fd) {
+ vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+ }
+ else if (state->d_ds) {
+ vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
+ }
+ data.mplexer->removeReadFD(conn.first);
+ state->d_lastIOState = IOState::Done;
}
- decrementTCPClientCount(ci.remote);
- if (g_downstreamTCPCleanupInterval > 0 && (connectionStartTime > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
- cleanupClosedTCPConnections(sockets);
- lastTCPCleanup = time(nullptr);
+ auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
+ for(const auto& conn : expiredWriteConns) {
+ auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
+ if (conn.first == state->d_ci.fd) {
+ vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+ }
+ else if (state->d_ds) {
+ vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
+ }
+ data.mplexer->removeWriteFD(conn.first);
+ state->d_lastIOState = IOState::Done;
}
}
}
-/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
+/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
they will hand off to worker threads & spawn more of them if required
*/
void tcpAcceptorThread(void* p)
bool tcpClientCountIncremented = false;
ComboAddress remote;
remote.sin4.sin_family = cs->local.sin4.sin_family;
-
+
g_tcpclientthreads->addTCPClientThread();
auto acl = g_ACL.getLocal();