From: Remi Gacogne Date: Thu, 28 Feb 2019 14:39:40 +0000 (+0100) Subject: dnsdist: TCP refactoring using an event-based logic X-Git-Tag: dnsdist-1.4.0-alpha1~25^2~24 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=d0ae6360966e65c0711d4a98b1b45498989856cb;p=thirdparty%2Fpdns.git dnsdist: TCP refactoring using an event-based logic --- diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index 9390a55df8..5e8974d698 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -257,10 +257,10 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload dh.d_class = htons(udpPayloadSize); static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)"); memcpy(&dh.d_ttl, &edns0, sizeof edns0); - dh.d_clen = htons((uint16_t) optRData.length()); + dh.d_clen = htons(static_cast(optRData.length())); res.reserve(sizeof(name) + sizeof(dh) + optRData.length()); - res.assign((const char *) &name, sizeof name); - res.append((const char *) &dh, sizeof dh); + res.assign(reinterpret_cast(&name), sizeof name); + res.append(reinterpret_cast(&dh), sizeof(dh)); res.append(optRData.c_str(), optRData.length()); } diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index bf66072fff..b8e8d591d6 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -35,6 +35,8 @@ #include #include +#include "sstuff.hh" + using std::thread; using std::atomic; @@ -53,42 +55,89 @@ using std::atomic; Let's start naively. */ -static int setupTCPDownstream(shared_ptr ds, uint16_t& downstreamFailures) +static thread_local map>> t_downstreamSockets; +static std::mutex tcpClientsCountMutex; +static std::map tcpClientsCount; +uint64_t g_maxTCPQueuedConnections{1000}; +size_t g_maxTCPQueriesPerConn{0}; +size_t g_maxTCPConnectionDuration{0}; +size_t g_maxTCPConnectionsPerClient{0}; +bool g_useTCPSinglePipe{false}; +std::atomic g_downstreamTCPCleanupInterval{60}; + +static std::unique_ptr setupTCPDownstream(shared_ptr ds, uint16_t& downstreamFailures, int timeout) { + std::unique_ptr result; + do { vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures); - int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0); + result = std::unique_ptr(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0)); try { if (!IsAnyAddress(ds->sourceAddr)) { - SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1); + SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1); #ifdef IP_BIND_ADDRESS_NO_PORT if (ds->ipBindAddrNoPort) { - SSetsockopt(sock, SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); + SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); } #endif - SBind(sock, ds->sourceAddr); + result->bind(ds->sourceAddr, false); } - setNonBlocking(sock); + result->setNonBlocking(); #ifdef MSG_FASTOPEN if (!ds->tcpFastOpen) { - SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout); + SConnectWithTimeout(result->getHandle(), ds->remote, timeout); } #else - SConnectWithTimeout(sock, ds->remote, ds->tcpConnectTimeout); + SConnectWithTimeout(result->getHandle(), ds->remote, timeout); #endif /* MSG_FASTOPEN */ - return sock; + return result; } catch(const std::runtime_error& e) { - /* don't leak our file descriptor if SConnect() (for example) throws */ + vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what()); downstreamFailures++; - close(sock); if (downstreamFailures > ds->retries) { throw; } } } while(downstreamFailures <= ds->retries); - return -1; + return nullptr; +} + +static std::unique_ptr getConnectionToDownstream(std::shared_ptr& ds, uint16_t& downstreamFailures, bool& isFresh) +{ + std::unique_ptr result; + + const auto& it = t_downstreamSockets.find(ds->remote); + if (it != t_downstreamSockets.end()) { + auto& list = it->second; + if (!list.empty()) { + result = std::move(list.front()); + list.pop_front(); + isFresh = false; + return result; + } + } + + isFresh = true; + return setupTCPDownstream(ds, downstreamFailures, 0); +} + +static void releaseDownstreamConnection(std::shared_ptr& ds, std::unique_ptr&& socket) +{ + const auto& it = t_downstreamSockets.find(ds->remote); + if (it != t_downstreamSockets.end()) { + auto& list = it->second; + if (list.size() >= 20) { + /* too many connections queued already */ + socket.reset(); + return; + } + list.push_back(std::move(socket)); + } + else { + t_downstreamSockets[ds->remote].push_back(std::move(socket)); + } } struct ConnectionInfo @@ -96,6 +145,14 @@ struct ConnectionInfo ConnectionInfo(): cs(nullptr), fd(-1) { } + ConnectionInfo(ConnectionInfo&& rhs) + { + remote = rhs.remote; + cs = rhs.cs; + rhs.cs = nullptr; + fd = rhs.fd; + rhs.fd = -1; + } ConnectionInfo(const ConnectionInfo& rhs) = delete; ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete; @@ -123,15 +180,6 @@ struct ConnectionInfo int fd{-1}; }; -uint64_t g_maxTCPQueuedConnections{1000}; -size_t g_maxTCPQueriesPerConn{0}; -size_t g_maxTCPConnectionDuration{0}; -size_t g_maxTCPConnectionsPerClient{0}; -static std::mutex tcpClientsCountMutex; -static std::map tcpClientsCount; -bool g_useTCPSinglePipe{false}; -std::atomic g_downstreamTCPCleanupInterval{60}; - void tcpClientThread(int pipefd); static void decrementTCPClientCount(const ComboAddress& client) @@ -201,392 +249,814 @@ void TCPClientCollection::addTCPClientThread() ++d_numthreads; } -static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout) -try +static void cleanupClosedTCPConnections() { - uint16_t raw; - size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout); - if(ret != sizeof raw) - return false; - *len = ntohs(raw); - return true; -} -catch(...) { - return false; -} + for(auto dsIt = t_downstreamSockets.begin(); dsIt != t_downstreamSockets.end(); ) { + for (auto socketIt = dsIt->second.begin(); socketIt != dsIt->second.end(); ) { + if (*socketIt && isTCPSocketUsable((*socketIt)->getHandle())) { + ++socketIt; + } + else { + socketIt = dsIt->second.erase(socketIt); + } + } -static bool getNonBlockingMsgLenFromClient(TCPIOHandler& handler, uint16_t* len) -try -{ - uint16_t raw; - size_t ret = handler.read(&raw, sizeof raw, g_tcpRecvTimeout); - if(ret != sizeof raw) - return false; - *len = ntohs(raw); - return true; -} -catch(...) { - return false; + if (!dsIt->second.empty()) { + ++dsIt; + } + else { + dsIt = t_downstreamSockets.erase(dsIt); + } + } } -static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime) +/* Tries to read exactly toRead bytes into the buffer, starting at position pos. + Updates pos everytime a successful read occurs, + throws an std::runtime_error in case of IO error, + return Done when toRead bytes have been read, needRead or needWrite if the IO operation + would block. +*/ +// XXX could probably be implemented as a TCPIOHandler +IOState tryRead(int fd, std::vector& buffer, size_t& pos, size_t toRead) { - if (maxConnectionDuration) { - time_t curtime = time(nullptr); - unsigned int elapsed = 0; - if (curtime > start) { // To prevent issues when time goes backward - elapsed = curtime - start; + size_t got = 0; + do { + ssize_t res = ::read(fd, reinterpret_cast(&buffer.at(pos)), toRead - got); + if (res == 0) { + throw runtime_error("EOF while reading message"); } - if (elapsed >= maxConnectionDuration) { - return true; + if (res < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return IOState::NeedRead; + } + else { + throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno)); + } } - remainingTime = maxConnectionDuration - elapsed; + + pos += static_cast(res); + got += static_cast(res); } - return false; + while (got < toRead); + + return IOState::Done; } -static void cleanupClosedTCPConnections(std::map& sockets) +std::shared_ptr g_tcpclientthreads; + +class TCPClientThreadData +{ +public: + TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr(FDMultiplexer::getMultiplexerSilent())) + { + } + + LocalHolders holders; + LocalStateHolder > localRespRulactions; + std::unique_ptr mplexer{nullptr}; +}; + +static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param); + +class IncomingTCPConnectionState { - for(auto it = sockets.begin(); it != sockets.end(); ) { - if (isTCPSocketUsable(it->second)) { - ++it; +public: + IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, time_t now): d_buffer(4096), d_responseBuffer(4096), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now), d_connectionStartTime(now) + { + d_ids.origDest.reset(); + d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family; + socklen_t socklen = d_ids.origDest.getSocklen(); + if (getsockname(d_ci.fd, reinterpret_cast(&d_ids.origDest), &socklen)) { + d_ids.origDest = d_ci.cs->local; + } + } + + IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete; + IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete; + + ~IncomingTCPConnectionState() + { + decrementTCPClientCount(d_ci.remote); + + if (d_ds != nullptr) { + if (d_outstanding) { + --d_ds->outstanding; + } + + if (d_downstreamSocket) { + try { + if (d_lastIOState == IOState::NeedRead) { + cerr<<__func__<<": removing leftover backend read FD "<getHandle()<removeReadFD(d_downstreamSocket->getHandle()); + } + else if (d_lastIOState == IOState::NeedWrite) { + cerr<<__func__<<": removing leftover backend write FD "<getHandle()<removeWriteFD(d_downstreamSocket->getHandle()); + } + } + catch(const FDMultiplexerException& e) { + vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what()); + } + } + } + + try { + if (d_lastIOState == IOState::NeedRead) { + cerr<<__func__<<": removing leftover client read FD "<removeReadFD(d_ci.fd); + } + else if (d_lastIOState == IOState::NeedWrite) { + cerr<<__func__<<": removing leftover client write FD "<removeWriteFD(d_ci.fd); + } + } + catch(const FDMultiplexerException& e) { + vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what()); + } + } + + void resetForNewQuery() + { + d_buffer.resize(sizeof(uint16_t)); + d_currentPos = 0; + d_querySize = 0; + d_responseSize = 0; + d_downstreamFailures = 0; + d_state = State::readingQuerySize; + d_lastIOState = IOState::Done; + } + + boost::optional getClientReadTTD(struct timeval now) const + { + if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) { + return boost::none; + } + + if (g_maxTCPConnectionDuration > 0) { + auto elapsed = now.tv_sec - d_connectionStartTime; + if (elapsed < 0 || (static_cast(elapsed) >= g_maxTCPConnectionDuration)) { + return now; + } + auto remaining = g_maxTCPConnectionDuration - elapsed; + if (g_tcpRecvTimeout == 0 || remaining <= static_cast(g_tcpRecvTimeout)) { + now.tv_sec += remaining; + return now; + } + } + + now.tv_sec += g_tcpRecvTimeout; + return now; + } + + boost::optional getBackendReadTTD() const + { + if (d_ds == nullptr) { + throw std::runtime_error("getBackendReadTTD() without any backend selected"); + } + if (d_ds->tcpRecvTimeout == 0) { + return boost::none; + } + + struct timeval res; + gettimeofday(&res, 0); + + res.tv_sec += d_ds->tcpRecvTimeout; + + return res; + } + + boost::optional getClientWriteTTD(boost::optional now=boost::none) const + { + if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) { + return boost::none; + } + + struct timeval res; + if (now) { + res = *now; } else { - close(it->second); - it = sockets.erase(it); + gettimeofday(&res, 0); + } + + if (g_maxTCPConnectionDuration > 0) { + auto elapsed = res.tv_sec - d_connectionStartTime; + if (elapsed < 0 || static_cast(elapsed) >= g_maxTCPConnectionDuration) { + return res; + } + auto remaining = g_maxTCPConnectionDuration - elapsed; + if (g_tcpSendTimeout == 0 || remaining <= static_cast(g_tcpSendTimeout)) { + res.tv_sec += remaining; + return res; + } + } + + res.tv_sec += g_tcpSendTimeout; + return res; + } + + boost::optional getBackendWriteTTD() const + { + if (d_ds == nullptr) { + throw std::runtime_error("getBackendReadTTD() called without any backend selected"); + } + if (d_ds->tcpSendTimeout == 0) { + return boost::none; } + + struct timeval res; + gettimeofday(&res, 0); + + res.tv_sec += d_ds->tcpSendTimeout; + + return res; } + + bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval now) + { + if (maxConnectionDuration) { + time_t curtime = now.tv_sec; + unsigned int elapsed = 0; + if (curtime > d_connectionStartTime) { // To prevent issues when time goes backward + elapsed = curtime - d_connectionStartTime; + } + if (elapsed >= maxConnectionDuration) { + return true; + } + d_remainingTime = maxConnectionDuration - elapsed; + } + + return false; + } + + enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse }; + + std::vector d_buffer; + std::vector d_responseBuffer; + TCPClientThreadData& d_threadData; + IDState d_ids; + ConnectionInfo d_ci; + TCPIOHandler d_handler; + std::unique_ptr d_downstreamSocket{nullptr}; + std::shared_ptr d_ds{nullptr}; + size_t d_currentPos{0}; + size_t d_queriesCount{0}; + time_t d_connectionStartTime; + unsigned int d_remainingTime{0}; + uint16_t d_querySize{0}; + uint16_t d_responseSize{0}; + uint16_t d_downstreamFailures{0}; + State d_state{State::doingHandshake}; + IOState d_lastIOState{IOState::Done}; + bool d_freshDownstreamConnection{false}; + bool d_readingFirstQuery{true}; + bool d_outstanding{false}; + bool d_firstResponsePacket{true}; + bool d_isXFR{false}; + bool d_xfrStarted{false}; +}; + +static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); +static void handleNewIOState(std::shared_ptr& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional ttd=boost::none); + +static void handleResponseSent(std::shared_ptr& state) +{ + handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback); + + if (state->d_isXFR && state->d_downstreamSocket) { + /* we need to resume reading from the backend! */ + state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend; + state->d_currentPos = 0; + //cerr<<__func__<<": add read client FD "<d_ci.fd<d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendReadTTD()); + return; + } + + if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) { + vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn); + return; + } + + struct timeval now; + gettimeofday(&now, 0); + if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { + vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); + return; + } + + state->resetForNewQuery(); + //cerr<<__func__<<": add read client FD "<d_ci.fd<d_ci.fd, handleIOCallback, state->getClientReadTTD(now)); } -std::shared_ptr g_tcpclientthreads; +static void sendResponse(std::shared_ptr& state) +{ + state->d_state = IncomingTCPConnectionState::State::sendingResponse; + const uint8_t sizeBytes[] = { static_cast(state->d_responseSize / 256), static_cast(state->d_responseSize % 256) }; + /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes + that could occur if we had to deal with the size during the processing, + especially alignment issues */ + state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2); -void tcpClientThread(int pipefd) + state->d_currentPos = 0; + + auto iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size()); + if (iostate == IOState::Done) { + + handleResponseSent(state); + return; + } + else { + //cerr<<__func__<<": adding client write FD "<d_ci.fd<d_ci.fd, handleIOCallback, state->getClientWriteTTD()); + } +} + +static void handleResponse(std::shared_ptr& state) { - /* we get launched with a pipe on which we receive file descriptors from clients that we own - from that point on */ + if (state->d_responseSize < sizeof(dnsheader)) { + return; + } - setThreadName("dnsdist/tcpClie"); + auto response = reinterpret_cast(&state->d_responseBuffer.at(0)); + unsigned int consumed; + if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) { + return; + } + state->d_firstResponsePacket = false; + + if (state->d_outstanding) { + --state->d_ds->outstanding; + state->d_outstanding = false; + } + + auto dh = reinterpret_cast(response); + uint16_t addRoom = 0; + DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true); + if (dr.dnsCryptQuery) { + addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; + } - bool outstanding = false; - time_t lastTCPCleanup = time(nullptr); - - LocalHolders holders; - auto localRespRulactions = g_resprulactions.getLocal(); - /* when the answer is encrypted in place, we need to get a copy - of the original header before encryption to fill the ring buffer */ dnsheader cleartextDH; + memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH)); - map sockets; - for(;;) { - ConnectionInfo* citmp, ci; + std::vector rewrittenResponse; + size_t responseSize = state->d_responseBuffer.size(); + if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) { + return; + } - try { - readn2(pipefd, &citmp, sizeof(citmp)); - } - catch(const std::runtime_error& e) { - throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what()); - } + if (!rewrittenResponse.empty()) { + /* responseSize has been updated as well but we don't really care since it will match + the capacity of rewrittenResponse anyway */ + state->d_responseBuffer = std::move(rewrittenResponse); + state->d_responseSize = state->d_responseBuffer.size(); + } else { + /* the size might have been updated (shrinked) if we removed the whole OPT RR, for example) */ + state->d_responseBuffer.resize(state->d_responseSize); + } + + if (state->d_isXFR && !state->d_xfrStarted) { + /* don't bother parsing the content of the response for now */ + state->d_xfrStarted = true; + } - g_tcpclientthreads->decrementQueuedCount(); - ci=std::move(*citmp); - delete citmp; + sendResponse(state); + + ++g_stats.responses; + struct timespec answertime; + gettime(&answertime); + double udiff = state->d_ids.sentTime.udiff(); + g_rings.insertResponse(answertime, state->d_ci.remote, *dr.qname, dr.qtype, static_cast(udiff), static_cast(state->d_responseBuffer.size()), cleartextDH, state->d_ds->remote); +} + +static void sendQueryToBackend(std::shared_ptr& state) +{ + auto ds = state->d_ds; + state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend; + state->d_currentPos = 0; + state->d_firstResponsePacket = true; + state->d_downstreamSocket.reset(); + + if (state->d_xfrStarted) { + /* sorry, but we are not going to resume a XFR if we have already sent some packets + to the client */ + return; + } - uint16_t qlen, rlen; - vector rewrittenResponse; - shared_ptr ds; - size_t queriesCount = 0; - time_t connectionStartTime = time(nullptr); - std::vector queryBuffer; - std::vector answerBuffer; + while (state->d_downstreamFailures < state->d_ds->retries) + { + state->d_downstreamSocket = getConnectionToDownstream(ds, state->d_downstreamFailures, state->d_freshDownstreamConnection); - ComboAddress dest; - dest.reset(); - dest.sin4.sin_family = ci.remote.sin4.sin_family; - socklen_t socklen = dest.getSocklen(); - if (getsockname(ci.fd, (sockaddr*)&dest, &socklen)) { - dest = ci.cs->local; + if (!state->d_downstreamSocket) { + vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures); + return; } - try { - TCPIOHandler handler(ci.fd, g_tcpRecvTimeout, ci.cs->tlsFrontend ? ci.cs->tlsFrontend->getContext() : nullptr, connectionStartTime); + //cerr<<__func__<<": add write backend FD "<d_downstreamSocket->getHandle()<d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendWriteTTD()); + return; + } - for(;;) { - unsigned int remainingTime = 0; - ds = nullptr; - outstanding = false; + vinfolog("Downstream connection to %s failed %u times in a row, giving up.", ds->getName(), state->d_downstreamFailures); +} - if(!getNonBlockingMsgLenFromClient(handler, &qlen)) { - break; - } +static void handleQuery(std::shared_ptr& state) +{ + if (state->d_querySize < sizeof(dnsheader)) { + ++g_stats.nonCompliantQueries; + return; + } - queriesCount++; + state->d_readingFirstQuery = false; + ++state->d_queriesCount; + ++state->d_ci.cs->queries; + ++g_stats.queries; + + /* we need an accurate ("real") value for the response and + to store into the IDS, but not for insertion into the + rings for example */ + struct timespec now; + struct timespec queryRealTime; + gettime(&now); + gettime(&queryRealTime, true); + + auto query = reinterpret_cast(&state->d_buffer.at(0)); + std::shared_ptr dnsCryptQuery{nullptr}; + auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true); + if (dnsCryptResponse) { + state->d_responseBuffer = std::move(*dnsCryptResponse); + state->d_responseSize = state->d_responseBuffer.size(); + sendResponse(state); + return; + } - if (qlen < sizeof(dnsheader)) { - ++g_stats.nonCompliantQueries; - break; - } + const auto& dh = reinterpret_cast(query); + if (!checkQueryHeaders(dh)) { + return; + } - ci.cs->queries++; - ++g_stats.queries; + uint16_t qtype, qclass; + unsigned int consumed = 0; + DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime); + dq.dnsCryptQuery = std::move(dnsCryptQuery); - if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) { - vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn); - break; - } + state->d_isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR); + if (state->d_isXFR) { + dq.skipCache = true; + } - if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) { - vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort()); - break; - } + state->d_ds.reset(); + auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds); - /* allocate a bit more memory to be able to spoof the content, - or to add ECS without allocating a new buffer */ - queryBuffer.resize((static_cast(qlen) + 512) < 4096 ? (static_cast(qlen) + 512) : 4096); - - char* query = &queryBuffer[0]; - handler.read(query, qlen, g_tcpRecvTimeout, remainingTime); - - /* we need an accurate ("real") value for the response and - to store into the IDS, but not for insertion into the - rings for example */ - struct timespec now; - struct timespec queryRealTime; - gettime(&now); - gettime(&queryRealTime, true); - - std::shared_ptr dnsCryptQuery = nullptr; - auto dnsCryptResponse = checkDNSCryptQuery(*ci.cs, query, qlen, dnsCryptQuery, queryRealTime.tv_sec, true); - if (dnsCryptResponse) { - handler.writeSizeAndMsg(reinterpret_cast(dnsCryptResponse->data()), static_cast(dnsCryptResponse->size()), g_tcpSendTimeout); - continue; - } + if (result == ProcessQueryResult::Drop) { + return; + } - struct dnsheader* dh = reinterpret_cast(query); - if (!checkQueryHeaders(dh)) { - break; - } + if (result == ProcessQueryResult::SendAnswer) { + state->d_buffer.resize(dq.len); + state->d_responseBuffer = std::move(state->d_buffer); + state->d_responseSize = state->d_responseBuffer.size(); + sendResponse(state); + return; + } - uint16_t qtype, qclass; - unsigned int consumed = 0; - DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, queryBuffer.size(), qlen, true, &queryRealTime); - dq.dnsCryptQuery = std::move(dnsCryptQuery); + if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) { + return; + } - std::shared_ptr ds{nullptr}; - auto result = processQuery(dq, *ci.cs, holders, ds); + state->d_buffer.resize(dq.len); + setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname)); - if (result == ProcessQueryResult::Drop) { - break; - } + const uint8_t sizeBytes[] = { static_cast(dq.len / 256), static_cast(dq.len % 256) }; + /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes + that could occur if we had to deal with the size during the processing, + especially alignment issues */ + state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2); + sendQueryToBackend(state); +} - if (result == ProcessQueryResult::SendAnswer) { - handler.writeSizeAndMsg(reinterpret_cast(dq.dh), dq.len, g_tcpSendTimeout); - continue; - } +static void handleNewIOState(std::shared_ptr& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional ttd) +{ + //cerr<<"in "<<__func__<<" for fd "<d_lastIOState<<", new state is "<<(int)iostate<d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) { + state->d_threadData.mplexer->removeReadFD(fd); + //cerr<<__func__<<": remove read FD "<d_lastIOState = IOState::Done; + } + else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) { + state->d_threadData.mplexer->removeWriteFD(fd); + //cerr<<__func__<<": remove write FD "<d_lastIOState = IOState::Done; + } - int dsock = -1; - uint16_t downstreamFailures=0; -#ifdef MSG_FASTOPEN - bool freshConn = true; -#endif /* MSG_FASTOPEN */ - if(sockets.count(ds->remote) == 0) { - dsock=setupTCPDownstream(ds, downstreamFailures); - sockets[ds->remote]=dsock; - } - else { - dsock=sockets[ds->remote]; -#ifdef MSG_FASTOPEN - freshConn = false; -#endif /* MSG_FASTOPEN */ - } + if (iostate == IOState::NeedRead) { + if (state->d_lastIOState == IOState::NeedRead) { + if (ttd) { + /* let's update the TTD ! */ + state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0); + } + return; + } - ds->outstanding++; - outstanding = true; + state->d_lastIOState = IOState::NeedRead; + //cerr<<__func__<<": add read FD "<d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr); + } + else if (iostate == IOState::NeedWrite) { + if (state->d_lastIOState == IOState::NeedWrite) { + return; + } - retry:; - if (dsock < 0) { - sockets.erase(ds->remote); - break; - } + state->d_lastIOState = IOState::NeedWrite; + //cerr<<__func__<<": add write FD "<d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr); + } + else if (iostate == IOState::Done) { + state->d_lastIOState = IOState::Done; + } +} - if (ds->retries > 0 && downstreamFailures > ds->retries) { - vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstreamFailures); - close(dsock); - dsock=-1; - sockets.erase(ds->remote); - break; - } +static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param) +{ + auto state = boost::any_cast>(param); + if (state->d_downstreamSocket == nullptr) { + throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!"); + } + if (fd != state->d_downstreamSocket->getHandle()) { + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamSocket->getHandle())); + } - try { - int socketFlags = 0; -#ifdef MSG_FASTOPEN - if (ds->tcpFastOpen && freshConn) { - socketFlags |= MSG_FASTOPEN; - } -#endif /* MSG_FASTOPEN */ - sendSizeAndMsgWithTimeout(dsock, dq.len, query, ds->tcpSendTimeout, &ds->remote, &ds->sourceAddr, ds->sourceItf, 0, socketFlags); - } - catch(const runtime_error& e) { - vinfolog("Downstream connection to %s died on us (%s), getting a new one!", ds->getName(), e.what()); - close(dsock); - dsock=-1; - sockets.erase(ds->remote); - downstreamFailures++; - dsock=setupTCPDownstream(ds, downstreamFailures); - sockets[ds->remote]=dsock; -#ifdef MSG_FASTOPEN - freshConn=true; -#endif /* MSG_FASTOPEN */ - goto retry; - } + IOState iostate = IOState::Done; + bool connectionDied = false; - bool xfrStarted = false; - bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR); - if (isXFR) { - dq.skipCache = true; - } - bool firstPacket=true; - getpacket:; - - if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) { - vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName()); - close(dsock); - dsock=-1; - sockets.erase(ds->remote); - downstreamFailures++; - dsock=setupTCPDownstream(ds, downstreamFailures); - sockets[ds->remote]=dsock; + try { + if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) { + int socketFlags = 0; #ifdef MSG_FASTOPEN - freshConn=true; + if (state->d_ds->tcpFastOpen && state->d_freshDownstreamConnection) { + socketFlags |= MSG_FASTOPEN; + } #endif /* MSG_FASTOPEN */ - if(xfrStarted) { - break; - } - goto retry; - } - size_t responseSize = rlen; - uint16_t addRoom = 0; - if (dq.dnsCryptQuery && (UINT16_MAX - rlen) > static_cast(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) { - addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; + size_t sent = sendMsgWithTimeout(fd, reinterpret_cast(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, 0, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, 0, socketFlags); + if (sent == state->d_buffer.size()) { + /* request sent ! */ + state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend; + state->d_currentPos = 0; + iostate = IOState::NeedRead; + if (!state->d_isXFR) { + /* don't bother with the outstanding count for XFR queries */ + ++state->d_ds->outstanding; + state->d_outstanding = true; } + } + else { + state->d_currentPos += sent; + iostate = IOState::NeedWrite; + /* disable fast open on partial write */ + state->d_freshDownstreamConnection = false; + } + } - responseSize += addRoom; - answerBuffer.resize(responseSize); - char* response = answerBuffer.data(); - readn2WithTimeout(dsock, response, rlen, ds->tcpRecvTimeout); - uint16_t responseLen = rlen; - if (outstanding) { - /* might be false for {A,I}XFR */ - --ds->outstanding; - outstanding = false; - } + if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) { + // then we need to allocate a new buffer (new because we might need to re-send the query if the + // backend dies on us + // We also might need to read and send to the client more than one response in case of XFR (yeah!) + // should very likely be a TCPIOHandler d_downstreamHandler + iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos); + if (iostate == IOState::Done) { + state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend; + state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1); + state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize); + state->d_currentPos = 0; + } + } - if (rlen < sizeof(dnsheader)) { - break; - } + if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) { + iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos); + if (iostate == IOState::Done) { + handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback); - consumed = 0; - if (firstPacket && !responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote, consumed)) { - break; + if (state->d_isXFR) { + /* Don't reuse the TCP connection after an {A,I}XFR */ + /* but don't reset it either, we will need to read more messages */ } - firstPacket=false; - - dh = reinterpret_cast(response); - DNSResponse dr(&qname, qtype, qclass, consumed, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime); - dr.origFlags = dq.origFlags; - dr.ecsAdded = dq.ecsAdded; - dr.ednsAdded = dq.ednsAdded; - dr.useZeroScope = dq.useZeroScope; - dr.packetCache = std::move(dq.packetCache); - dr.delayMsec = dq.delayMsec; - dr.skipCache = dq.skipCache; - dr.cacheKey = dq.cacheKey; - dr.cacheKeyNoECS = dq.cacheKeyNoECS; - dr.dnssecOK = dq.dnssecOK; - dr.tempFailureTTL = dq.tempFailureTTL; - dr.qTag = std::move(dq.qTag); - dr.subnet = std::move(dq.subnet); -#ifdef HAVE_PROTOBUF - dr.uniqueId = std::move(dq.uniqueId); -#endif - if (dq.dnsCryptQuery) { - dr.dnsCryptQuery = std::move(dq.dnsCryptQuery); + else { + releaseDownstreamConnection(state->d_ds, std::move(state->d_downstreamSocket)); } + fd = -1; - memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH)); - if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, false)) { - break; - } + handleResponse(state); + return; + } + } - if (!handler.writeSizeAndMsg(response, responseLen, g_tcpSendTimeout)) { - break; - } + if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend && + state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend && + state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) { + vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast(state->d_state)); + } + } + catch(const std::exception& e) { + /* most likely an EOF because the other end closed the connection, + but it might also be a real IO error or something else. + Let's just drop the connection + */ + vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what()); + /* remove this FD from the IO multiplexer */ + ++state->d_downstreamFailures; + if (state->d_outstanding && state->d_ds != nullptr) { + --state->d_ds->outstanding; + } + iostate = IOState::Done; + connectionDied = true; + } - if (isXFR) { - if (dh->rcode == 0 && dh->ancount != 0) { - if (xfrStarted == false) { - xfrStarted = true; - if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) { - goto getpacket; - } - } - else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) { - goto getpacket; - } - } - /* Don't reuse the TCP connection after an {A,I}XFR */ - close(dsock); - dsock=-1; - sockets.erase(ds->remote); - } + if (iostate == IOState::Done) { + handleNewIOState(state, iostate, fd, handleDownstreamIOCallback); + } + else { + handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD() : state->getBackendWriteTTD()); + } + + if (connectionDied) { + sendQueryToBackend(state); + } +} + +static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param) +{ + auto state = boost::any_cast>(param); + if (fd != state->d_ci.fd) { + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd)); + } + + IOState iostate = IOState::Done; + + struct timeval now; + gettimeofday(&now, 0); + if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { + vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); + handleNewIOState(state, IOState::Done, fd, handleIOCallback); + return; + } - ++g_stats.responses; - switch (dr.dh->rcode) { - case RCode::NXDomain: - ++g_stats.frontendNXDomain; - break; - case RCode::ServFail: - ++g_stats.frontendServFail; - break; - case RCode::NoError: - ++g_stats.frontendNoError; - break; + try { + if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) { + iostate = state->d_handler.tryHandshake(); + if (iostate == IOState::Done) { + state->d_state = IncomingTCPConnectionState::State::readingQuerySize; + } + } + + if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) { + iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos); + if (iostate == IOState::Done) { + state->d_state = IncomingTCPConnectionState::State::readingQuery; + state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1); + if (state->d_querySize < sizeof(dnsheader)) { + /* go away */ + handleNewIOState(state, IOState::Done, fd, handleIOCallback); + return; } - struct timespec answertime; - gettime(&answertime); - unsigned int udiff = 1000000.0*DiffTime(now,answertime); - g_rings.insertResponse(answertime, ci.remote, qname, dq.qtype, static_cast(udiff), static_cast(responseLen), cleartextDH, ds->remote); - rewrittenResponse.clear(); + /* allocate a bit more memory to be able to spoof the content, + or to add ECS without allocating a new buffer */ + state->d_buffer.resize(state->d_querySize + 512); + state->d_currentPos = 0; + } + } + + if (state->d_state == IncomingTCPConnectionState::State::readingQuery) { + iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize); + if (iostate == IOState::Done) { + handleNewIOState(state, IOState::Done, fd, handleIOCallback); + handleQuery(state); + return; } } - catch(const std::exception& e) { - vinfolog("Got exception while handling TCP query: %s", e.what()); + + if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) { + iostate = state->d_handler.tryWrite(state->d_buffer, state->d_currentPos, state->d_buffer.size()); + if (iostate == IOState::Done) { + handleResponseSent(state); + return; + } + } + + if (state->d_state != IncomingTCPConnectionState::State::doingHandshake && + state->d_state != IncomingTCPConnectionState::State::readingQuerySize && + state->d_state != IncomingTCPConnectionState::State::readingQuery && + state->d_state != IncomingTCPConnectionState::State::sendingResponse) { + vinfolog("Unexpected state %d in handleIOCallback", static_cast(state->d_state)); } - catch(...) { + } + catch(const std::exception& e) { + /* most likely an EOF because the other end closed the connection, + but it might also be a real IO error or something else. + Let's just drop the connection + */ + if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) { + vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what()); } + else { + vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort()); + } + /* remove this FD from the IO multiplexer */ + iostate = IOState::Done; + } + + if (iostate == IOState::Done) { + handleNewIOState(state, iostate, fd, handleIOCallback); + } + else { + handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now)); + } +} + +static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param) +{ + auto threadData = boost::any_cast(param); + + ConnectionInfo* citmp{nullptr}; + + try { + readn2(pipefd, &citmp, sizeof(citmp)); + } + catch(const std::runtime_error& e) { + throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what()); + } + + g_tcpclientthreads->decrementQueuedCount(); + auto ci = std::move(*citmp); + delete citmp; + citmp = nullptr; + + struct timeval now; + gettimeofday(&now, 0); + auto state = std::make_shared(std::move(ci), *threadData, now.tv_sec); - vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort()); + /* let's update the remaining time */ + state->d_remainingTime = g_maxTCPConnectionDuration; - if (ds && outstanding) { - outstanding = false; - --ds->outstanding; + /* we could try reading right away, but let's not for now */ + handleNewIOState(state, IOState::NeedRead, state->d_ci.fd, handleIOCallback, state->getClientReadTTD(now)); +} + +void tcpClientThread(int pipefd) +{ + /* we get launched with a pipe on which we receive file descriptors from clients that we own + from that point on */ + + setThreadName("dnsdist/tcpClie"); + + TCPClientThreadData data; + + data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data); + time_t lastTCPCleanup = time(nullptr); + struct timeval now; + gettimeofday(&now, 0); + + for (;;) { + data.mplexer->run(&now); + + if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) { + cleanupClosedTCPConnections(); + lastTCPCleanup = now.tv_sec; + } + + auto expiredReadConns = data.mplexer->getTimeouts(now, false); + for(const auto& conn : expiredReadConns) { + auto state = boost::any_cast>(conn.second); + if (conn.first == state->d_ci.fd) { + vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + } + else if (state->d_ds) { + vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName()); + } + data.mplexer->removeReadFD(conn.first); + state->d_lastIOState = IOState::Done; } - decrementTCPClientCount(ci.remote); - if (g_downstreamTCPCleanupInterval > 0 && (connectionStartTime > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) { - cleanupClosedTCPConnections(sockets); - lastTCPCleanup = time(nullptr); + auto expiredWriteConns = data.mplexer->getTimeouts(now, true); + for(const auto& conn : expiredWriteConns) { + auto state = boost::any_cast>(conn.second); + if (conn.first == state->d_ci.fd) { + vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + } + else if (state->d_ds) { + vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName()); + } + data.mplexer->removeWriteFD(conn.first); + state->d_lastIOState = IOState::Done; } } } -/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and +/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and they will hand off to worker threads & spawn more of them if required */ void tcpAcceptorThread(void* p) @@ -596,7 +1066,7 @@ void tcpAcceptorThread(void* p) bool tcpClientCountIncremented = false; ComboAddress remote; remote.sin4.sin_family = cs->local.sin4.sin_family; - + g_tcpclientthreads->addTCPClientThread(); auto acl = g_ACL.getLocal(); diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 0252b95deb..acb4647786 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -574,26 +574,9 @@ try { dh->id = ids->origID; uint16_t addRoom = 0; - DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, consumed, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start); - dr.origFlags = ids->origFlags; - dr.ecsAdded = ids->ecsAdded; - dr.ednsAdded = ids->ednsAdded; - dr.useZeroScope = ids->useZeroScope; - dr.packetCache = std::move(ids->packetCache); - dr.delayMsec = ids->delayMsec; - dr.skipCache = ids->skipCache; - dr.cacheKey = ids->cacheKey; - dr.cacheKeyNoECS = ids->cacheKeyNoECS; - dr.dnssecOK = ids->dnssecOK; - dr.tempFailureTTL = ids->tempFailureTTL; - dr.qTag = std::move(ids->qTag); - dr.subnet = std::move(ids->subnet); -#ifdef HAVE_PROTOBUF - dr.uniqueId = std::move(ids->uniqueId); -#endif - if (ids->dnsCryptQuery) { + DNSResponse dr = makeDNSResponseFromIDState(*ids, dh, sizeof(packet), responseLen, false); + if (dr.dnsCryptQuery) { addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; - dr.dnsCryptQuery = std::move(ids->dnsCryptQuery); } memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH)); @@ -1577,24 +1560,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids->cs = &cs; ids->origID = dh->id; - ids->origRemote = remote; - ids->sentTime.set(queryRealTime); - ids->qname = std::move(qname); - ids->qtype = dq.qtype; - ids->qclass = dq.qclass; - ids->delayMsec = dq.delayMsec; - ids->tempFailureTTL = dq.tempFailureTTL; - ids->origFlags = dq.origFlags; - ids->cacheKey = dq.cacheKey; - ids->cacheKeyNoECS = dq.cacheKeyNoECS; - ids->subnet = dq.subnet; - ids->skipCache = dq.skipCache; - ids->packetCache = dq.packetCache; - ids->ednsAdded = dq.ednsAdded; - ids->ecsAdded = dq.ecsAdded; - ids->useZeroScope = dq.useZeroScope; - ids->qTag = dq.qTag; - ids->dnssecOK = dq.dnssecOK; + setIDStateFromDNSQuestion(*ids, dq, std::move(qname)); /* If we couldn't harvest the real dest addr, still write down the listening addr since it will be useful @@ -1611,12 +1577,6 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids->destHarvested = false; } - ids->dnsCryptQuery = std::move(dq.dnsCryptQuery); - -#ifdef HAVE_PROTOBUF - ids->uniqueId = std::move(dq.uniqueId); -#endif - dh->id = idOffset; int fd = pickBackendSocketForSending(ss); diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 6ed7faeab6..7757e69288 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -1066,3 +1066,5 @@ static const size_t s_udpIncomingBufferSize{1500}; enum class ProcessQueryResult { Drop, SendAnswer, PassToBackend }; ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend); +DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP); +void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname); diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index b0d56c303c..8c9e90b5cd 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -99,6 +99,7 @@ dnsdist_SOURCES = \ dnsdist-dnscrypt.cc \ dnsdist-dynblocks.hh \ dnsdist-ecs.cc dnsdist-ecs.hh \ + dnsdist-idstate.cc \ dnsdist-lua.hh dnsdist-lua.cc \ dnsdist-lua-actions.cc \ dnsdist-lua-bindings.cc \ diff --git a/pdns/dnsdistdist/dnsdist-idstate.cc b/pdns/dnsdistdist/dnsdist-idstate.cc new file mode 100644 index 0000000000..169ba64f3a --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-idstate.cc @@ -0,0 +1,58 @@ + +#include "dnsdist.hh" + +DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP) +{ + + DNSResponse dr(&ids.qname, ids.qtype, ids.qclass, ids.qname.wirelength(), &ids.origDest, &ids.origRemote, dh, bufferSize, responseLen, isTCP, &ids.sentTime.d_start); + dr.origFlags = ids.origFlags; + dr.ecsAdded = ids.ecsAdded; + dr.ednsAdded = ids.ednsAdded; + dr.useZeroScope = ids.useZeroScope; + dr.packetCache = std::move(ids.packetCache); + dr.delayMsec = ids.delayMsec; + dr.skipCache = ids.skipCache; + dr.cacheKey = ids.cacheKey; + dr.cacheKeyNoECS = ids.cacheKeyNoECS; + dr.dnssecOK = ids.dnssecOK; + dr.tempFailureTTL = ids.tempFailureTTL; + dr.qTag = std::move(ids.qTag); + dr.subnet = std::move(ids.subnet); +#ifdef HAVE_PROTOBUF + dr.uniqueId = std::move(ids.uniqueId); +#endif + if (ids.dnsCryptQuery) { + dr.dnsCryptQuery = std::move(ids.dnsCryptQuery); + } + + return dr; +} + +void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname) +{ + ids.origRemote = *dq.remote; + ids.origDest = *dq.local; + ids.sentTime.set(*dq.queryTime); + ids.qname = std::move(qname); + ids.qtype = dq.qtype; + ids.qclass = dq.qclass; + ids.delayMsec = dq.delayMsec; + ids.tempFailureTTL = dq.tempFailureTTL; + ids.origFlags = dq.origFlags; + ids.cacheKey = dq.cacheKey; + ids.cacheKeyNoECS = dq.cacheKeyNoECS; + ids.subnet = dq.subnet; + ids.skipCache = dq.skipCache; + ids.packetCache = dq.packetCache; + ids.ednsAdded = dq.ednsAdded; + ids.ecsAdded = dq.ecsAdded; + ids.useZeroScope = dq.useZeroScope; + ids.qTag = dq.qTag; + ids.dnssecOK = dq.dnssecOK; + + ids.dnsCryptQuery = std::move(dq.dnsCryptQuery); + +#ifdef HAVE_PROTOBUF + ids.uniqueId = std::move(dq.uniqueId); +#endif +} diff --git a/pdns/dnsdistdist/tcpiohandler.cc b/pdns/dnsdistdist/tcpiohandler.cc index 2be4a4c62f..9d44f0dba7 100644 --- a/pdns/dnsdistdist/tcpiohandler.cc +++ b/pdns/dnsdistdist/tcpiohandler.cc @@ -232,7 +232,7 @@ private: class OpenSSLTLSConnection: public TLSConnection { public: - OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr(SSL_new(tlsCtx), SSL_free)) + OpenSSLTLSConnection(int socket, unsigned int timeout, SSL_CTX* tlsCtx): d_conn(std::unique_ptr(SSL_new(tlsCtx), SSL_free)), d_timeout(timeout) { d_socket = socket; @@ -247,12 +247,59 @@ public: if (!SSL_set_fd(d_conn.get(), d_socket)) { throw std::runtime_error("Error assigning socket"); } + } + + IOState convertIORequestToIOState(int res) const + { + int error = SSL_get_error(d_conn.get(), res); + if (error == SSL_ERROR_WANT_READ) { + return IOState::NeedRead; + } + else if (error == SSL_ERROR_WANT_WRITE) { + return IOState::NeedWrite; + } + else { + throw std::runtime_error("Error while processing TLS connection:" + std::to_string(error)); + } + } + + void handleIORequest(int res, unsigned int timeout) + { + auto state = convertIORequestToIOState(res); + if (state == IOState::NeedRead) { + res = waitForData(d_socket, timeout); + if (res <= 0) { + throw std::runtime_error("Error reading from TLS connection"); + } + } + else if (state == IOState::NeedWrite) { + res = waitForRWData(d_socket, false, timeout, 0); + if (res <= 0) { + throw std::runtime_error("Error waiting to write to TLS connection"); + } + } + } + + IOState tryHandshake() + { + int res = SSL_accept(d_conn.get()); + if (res == 1) { + return IOState::Done; + } + else if (res < 0) { + return convertIORequestToIOState(res); + } + + throw std::runtime_error("Error accepting TLS connection"); + } + void doHandshake() + { int res = 0; do { res = SSL_accept(d_conn.get()); if (res < 0) { - handleIORequest(res, timeout); + handleIORequest(res, d_timeout); } } while (res < 0); @@ -262,24 +309,40 @@ public: } } - void handleIORequest(int res, unsigned int timeout) + IOState tryWrite(std::vector& buffer, size_t& pos, size_t toWrite) override { - int error = SSL_get_error(d_conn.get(), res); - if (error == SSL_ERROR_WANT_READ) { - res = waitForData(d_socket, timeout); - if (res <= 0) { - throw std::runtime_error("Error reading from TLS connection"); + do { + int res = SSL_write(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toWrite - pos)); + if (res == 0) { + throw std::runtime_error("Error writing to TLS connection"); } - } - else if (error == SSL_ERROR_WANT_WRITE) { - res = waitForRWData(d_socket, false, timeout, 0); - if (res <= 0) { - throw std::runtime_error("Error waiting to write to TLS connection"); + else if (res < 0) { + return convertIORequestToIOState(res); + } + else { + pos += static_cast(res); } } - else { - throw std::runtime_error("Error writing to TLS connection"); + while (pos < toWrite); + return IOState::Done; + } + + IOState tryRead(std::vector& buffer, size_t& pos, size_t toRead) override + { + do { + int res = SSL_read(d_conn.get(), reinterpret_cast(&buffer.at(pos)), static_cast(toRead - pos)); + if (res == 0) { + throw std::runtime_error("Error reading from TLS connection"); + } + else if (res < 0) { + return convertIORequestToIOState(res); + } + else { + pos += static_cast(res); + } } + while (pos < toRead); + return IOState::Done; } size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override @@ -300,7 +363,7 @@ public: handleIORequest(res, readTimeout); } else { - got += (size_t) res; + got += static_cast(res); } if (totalTimeout) { @@ -330,7 +393,7 @@ public: handleIORequest(res, writeTimeout); } else { - got += (size_t) res; + got += static_cast(res); } } while (got < bufferSize); @@ -346,6 +409,7 @@ public: private: std::unique_ptr d_conn; + unsigned int d_timeout; }; class OpenSSLTLSIOCtx: public TLSCtx @@ -650,7 +714,7 @@ public: GnuTLSConnection(int socket, unsigned int timeout, const gnutls_certificate_credentials_t creds, const gnutls_priority_t priorityCache, std::shared_ptr& ticketsKey, bool enableTickets): d_conn(std::unique_ptr(nullptr, gnutls_deinit)), d_ticketsKey(ticketsKey) { - unsigned int sslOptions = GNUTLS_SERVER; + unsigned int sslOptions = GNUTLS_SERVER | GNUTLS_NONBLOCK; #ifdef GNUTLS_NO_SIGNAL sslOptions |= GNUTLS_NO_SIGNAL; #endif @@ -685,12 +749,86 @@ public: /* timeouts are in milliseconds */ gnutls_handshake_set_timeout(d_conn.get(), timeout * 1000); gnutls_record_set_timeout(d_conn.get(), timeout * 1000); + } + void doHandshake() + { int ret = 0; do { ret = gnutls_handshake(d_conn.get()); + if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { + throw std::runtime_error("Error accepting a new connection"); + } + } + while (ret < 0 && ret == GNUTLS_E_INTERRUPTED); + } + + IOState tryHandshake() + { + int ret = 0; + + do { + ret = gnutls_handshake(d_conn.get()); + if (ret == GNUTLS_E_SUCCESS) { + return IOState::Done; + } + else if (ret == GNUTLS_E_AGAIN) { + return IOState::NeedRead; + } + else if (gnutls_error_is_fatal(ret) || ret == GNUTLS_E_WARNING_ALERT_RECEIVED) { + throw std::runtime_error("Error accepting a new connection"); + } + } while (ret == GNUTLS_E_INTERRUPTED); + + throw std::runtime_error("Error accepting a new connection"); + } + + IOState tryWrite(std::vector& buffer, size_t& pos, size_t toWrite) override + { + do { + ssize_t res = gnutls_record_send(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toWrite - pos); + if (res == 0) { + throw std::runtime_error("Error writing to TLS connection"); + } + else if (res > 0) { + pos += static_cast(res); + } + else if (res < 0) { + if (gnutls_error_is_fatal(res)) { + throw std::runtime_error("Error writing to TLS connection"); + } + else if (res == GNUTLS_E_AGAIN) { + return IOState::NeedWrite; + } + warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); + } + } + while (pos < toWrite); + return IOState::Done; + } + + IOState tryRead(std::vector& buffer, size_t& pos, size_t toRead) override + { + do { + ssize_t res = gnutls_record_recv(d_conn.get(), reinterpret_cast(&buffer.at(pos)), toRead - pos); + if (res == 0) { + throw std::runtime_error("Error reading from TLS connection"); + } + else if (res > 0) { + pos += static_cast(res); + } + else if (res < 0) { + if (gnutls_error_is_fatal(res)) { + throw std::runtime_error("Error reading from TLS connection"); + } + else if (res == GNUTLS_E_AGAIN) { + return IOState::NeedRead; + } + warnlog("Warning, non-fatal error while writing to TLS connection: %s", gnutls_strerror(res)); + } } - while (ret < 0 && gnutls_error_is_fatal(ret) == 0); + while (pos < toRead); + return IOState::Done; } size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout) override @@ -708,7 +846,7 @@ public: throw std::runtime_error("Error reading from TLS connection"); } else if (res > 0) { - got += (size_t) res; + got += static_cast(res); } else if (res < 0) { if (gnutls_error_is_fatal(res)) { @@ -750,7 +888,7 @@ public: throw std::runtime_error("Error writing to TLS connection"); } else if (res > 0) { - got += (size_t) res; + got += static_cast(res); } else if (res < 0) { if (gnutls_error_is_fatal(res)) { diff --git a/pdns/iputils.cc b/pdns/iputils.cc index 0d7c342b00..88fd698131 100644 --- a/pdns/iputils.cc +++ b/pdns/iputils.cc @@ -269,40 +269,112 @@ void ComboAddress::truncate(unsigned int bits) noexcept *place &= (~((1<(const_cast(dest)); + msgh.msg_namelen = dest->getSocklen(); + } + else { + msgh.msg_name = nullptr; + msgh.msg_namelen = 0; + } + + msgh.msg_flags = 0; + + if (localItf != 0 && local) { + addCMsgSrcAddr(&msgh, cbuf, local, localItf); + } + + if (localItf != 0 && local) { + addCMsgSrcAddr(&msgh, cbuf, local, localItf); + } + + iov.iov_base = reinterpret_cast(const_cast(buffer)); + iov.iov_len = len; + msgh.msg_iov = &iov; + msgh.msg_iovlen = 1; + msgh.msg_flags = 0; + + size_t sent = 0; bool firstTry = true; - fillMSGHdr(&msgh, &iov, cbuf, sizeof(cbuf), const_cast(buffer), len, &dest); - addCMsgSrcAddr(&msgh, cbuf, &local, localItf); do { - ssize_t written = sendmsg(fd, &msgh, 0); - if (written > 0) - return written; +#ifdef MSG_FASTOPEN + if (flags & MSG_FASTOPEN && firstTry == false) { + flags &= ~MSG_FASTOPEN; + } +#endif /* MSG_FASTOPEN */ - if (errno == EAGAIN) { - if (firstTry) { - int res = waitForRWData(fd, false, timeout, 0); - if (res > 0) { - /* there is room available */ - firstTry = false; + ssize_t res = sendmsg(fd, &msgh, flags); + + if (res > 0) { + size_t written = static_cast(res); + sent += written; + + if (sent == len) { + return sent; + } + + /* partial write */ + iov.iov_len -= written; + iov.iov_base = reinterpret_cast(reinterpret_cast(iov.iov_base) + written); + written = 0; + } + else if (res == -1) { + if (errno == EINTR) { + continue; + } + else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINPROGRESS) { + /* EINPROGRESS might happen with non blocking socket, + especially with TCP Fast Open */ + if (totalTimeout <= 0 && idleTimeout <= 0) { + return sent; + } + + if (firstTry) { + int res = waitForRWData(fd, false, (totalTimeout == 0 || idleTimeout <= remainingTime) ? idleTimeout : remainingTime, 0); + if (res > 0) { + /* there is room available */ + firstTry = false; + } + else if (res == 0) { + throw runtime_error("Timeout while waiting to write data"); + } else { + throw runtime_error("Error while waiting for room to write data"); + } } - else if (res == 0) { + else { throw runtime_error("Timeout while waiting to write data"); - } else { - throw runtime_error("Error while waiting for room to write data"); } } else { - throw runtime_error("Timeout while waiting to write data"); + unixDie("failed in sendMsgWithTimeout"); } } - else { - unixDie("failed in write2WithTimeout"); + if (totalTimeout) { + time_t now = time(nullptr); + int elapsed = now - start; + if (elapsed >= remainingTime) { + throw runtime_error("Timeout while sending data"); + } + start = now; + remainingTime -= elapsed; } } while (firstTry); diff --git a/pdns/iputils.hh b/pdns/iputils.hh index 490e45ac43..498e03c924 100644 --- a/pdns/iputils.hh +++ b/pdns/iputils.hh @@ -1062,7 +1062,7 @@ bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destinat bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv); void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, char* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr); ssize_t sendfromto(int sock, const char* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to); -ssize_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int timeout, ComboAddress& dest, const ComboAddress& local, unsigned int localItf); +size_t sendMsgWithTimeout(int fd, const char* buffer, size_t len, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags); bool sendSizeAndMsgWithTimeout(int sock, uint16_t bufferLen, const char* buffer, int idleTimeout, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int totalTimeout, int flags); /* requires a non-blocking, connected TCP socket */ bool isTCPSocketUsable(int sock); diff --git a/pdns/mplexer.hh b/pdns/mplexer.hh index 7de3a8a459..b42e900928 100644 --- a/pdns/mplexer.hh +++ b/pdns/mplexer.hh @@ -51,9 +51,9 @@ class FDMultiplexer { public: typedef boost::any funcparam_t; + typedef boost::function< void(int, funcparam_t&) > callbackfunc_t; protected: - typedef boost::function< void(int, funcparam_t&) > callbackfunc_t; struct Callback { callbackfunc_t d_callback; diff --git a/pdns/sstuff.hh b/pdns/sstuff.hh index b8066b8782..922315a5e2 100644 --- a/pdns/sstuff.hh +++ b/pdns/sstuff.hh @@ -60,10 +60,17 @@ public: setCloseOnExec(d_socket); } + Socket(Socket&& rhs): d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket) + { + rhs.d_socket = -1; + } + ~Socket() { try { - closesocket(d_socket); + if (d_socket != -1) { + closesocket(d_socket); + } } catch(const PDNSException& e) { } @@ -124,10 +131,10 @@ public: } //! Bind the socket to a specified endpoint - void bind(const ComboAddress &local) + void bind(const ComboAddress &local, bool reuseaddr=true) { int tmp=1; - if(setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&tmp), sizeof tmp)<0) + if(reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&tmp), sizeof tmp)<0) throw NetworkError(string("Setsockopt failed: ")+strerror(errno)); if(::bind(d_socket, reinterpret_cast(&local), local.getSocklen())<0) diff --git a/pdns/tcpiohandler.hh b/pdns/tcpiohandler.hh index 0d5bfa514e..30e19537f8 100644 --- a/pdns/tcpiohandler.hh +++ b/pdns/tcpiohandler.hh @@ -4,12 +4,18 @@ #include "misc.hh" +enum class IOState { Done, NeedRead, NeedWrite }; + class TLSConnection { public: virtual ~TLSConnection() { } + virtual void doHandshake() = 0; + virtual IOState tryHandshake() = 0; virtual size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) = 0; virtual size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) = 0; + virtual IOState tryWrite(std::vector& buffer, size_t& pos, size_t toWrite) = 0; + virtual IOState tryRead(std::vector& buffer, size_t& pos, size_t toRead) = 0; virtual void close() = 0; protected: @@ -153,12 +159,14 @@ private: class TCPIOHandler { public: + TCPIOHandler(int socket, unsigned int timeout, std::shared_ptr ctx, time_t now): d_socket(socket) { if (ctx) { d_conn = ctx->getConnection(d_socket, timeout, now); } } + ~TCPIOHandler() { if (d_conn) { @@ -168,6 +176,15 @@ public: shutdown(d_socket, SHUT_RDWR); } } + + IOState tryHandshake() + { + if (d_conn) { + return d_conn->tryHandshake(); + } + return IOState::Done; + } + size_t read(void* buffer, size_t bufferSize, unsigned int readTimeout, unsigned int totalTimeout=0) { if (d_conn) { @@ -176,6 +193,77 @@ public: return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout); } } + + /* Tries to read exactly toRead bytes into the buffer, starting at position pos. + Updates pos everytime a successful read occurs, + throws an std::runtime_error in case of IO error, + return Done when toRead bytes have been read, needRead or needWrite if the IO operation + would block. + */ + IOState tryRead(std::vector& buffer, size_t& pos, size_t toRead) + { + if (d_conn) { + return d_conn->tryRead(buffer, pos, toRead); + } + + size_t got = 0; + do { + ssize_t res = ::read(d_socket, reinterpret_cast(&buffer.at(pos)), toRead - got); + if (res == 0) { + throw runtime_error("EOF while reading message"); + } + if (res < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return IOState::NeedRead; + } + else { + throw std::runtime_error(std::string("Error while reading message: ") + strerror(errno)); + } + } + + pos += static_cast(res); + got += static_cast(res); + } + while (got < toRead); + + return IOState::Done; + } + + /* Tries to write exactly toWrite bytes from the buffer, starting at position pos. + Updates pos everytime a successful write occurs, + throws an std::runtime_error in case of IO error, + return Done when toWrite bytes have been written, needRead or needWrite if the IO operation + would block. + */ + IOState tryWrite(std::vector& buffer, size_t& pos, size_t toWrite) + { + if (d_conn) { + return d_conn->tryWrite(buffer, pos, toWrite); + } + + size_t sent = 0; + do { + ssize_t res = ::write(d_socket, reinterpret_cast(&buffer.at(pos)), toWrite - sent); + if (res == 0) { + throw runtime_error("EOF while sending message"); + } + if (res < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return IOState::NeedWrite; + } + else { + throw std::runtime_error(std::string("Error while writing message: ") + strerror(errno)); + } + } + + pos += static_cast(res); + sent += static_cast(res); + } + while (sent < toWrite); + + return IOState::Done; + } + size_t write(const void* buffer, size_t bufferSize, unsigned int writeTimeout) { if (d_conn) {