* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/
+
+#include <thread>
+#include <netinet/tcp.h>
+#include <queue>
+
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-proxy-protocol.hh"
#include "dnsdist-rings.hh"
+#include "dnsdist-tcp-downstream.hh"
+#include "dnsdist-tcp-upstream.hh"
#include "dnsdist-xpf.hh"
-
#include "dnsparser.hh"
-#include "ednsoptions.hh"
#include "dolog.hh"
-#include "lock.hh"
+#include "ednsoptions.hh"
#include "gettime.hh"
+#include "lock.hh"
+#include "sstuff.hh"
#include "tcpiohandler.hh"
+#include "tcpiohandler-mplexer.hh"
#include "threadname.hh"
-#include <thread>
-#include <atomic>
-#include <netinet/tcp.h>
-
-#include "sstuff.hh"
-
-using std::thread;
-using std::atomic;
/* TCP: the grand design.
We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
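   Each message is framed the standard DNS-over-TCP way (RFC 1035 section 4.2.2): a two-byte
   length field in network byte order, followed by that many bytes of DNS payload. A 300-byte
   query, for instance, goes out as the bytes 0x01 0x2C and then the payload itself.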
static std::mutex tcpClientsCountMutex;
static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
-static const size_t g_maxCachedConnectionsPerDownstream = 20;
+
uint64_t g_maxTCPQueuedConnections{1000};
size_t g_maxTCPQueriesPerConn{0};
size_t g_maxTCPConnectionDuration{0};
uint16_t g_downstreamTCPCleanupInterval{60};
bool g_useTCPSinglePipe{false};
-static std::unique_ptr<Socket> setupTCPDownstream(shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures)
+class DownstreamConnectionsManager
{
- std::unique_ptr<Socket> result;
+public:
- do {
- vinfolog("TCP connecting to downstream %s (%d)", ds->remote.toStringWithPort(), downstreamFailures);
- try {
- result = std::unique_ptr<Socket>(new Socket(ds->remote.sin4.sin_family, SOCK_STREAM, 0));
- if (!IsAnyAddress(ds->sourceAddr)) {
- SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
-#ifdef IP_BIND_ADDRESS_NO_PORT
- if (ds->ipBindAddrNoPort) {
- SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
- }
-#endif
-#ifdef SO_BINDTODEVICE
- if (!ds->sourceItfName.empty()) {
- int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, ds->sourceItfName.c_str(), ds->sourceItfName.length());
- if (res != 0) {
- vinfolog("Error setting up the interface on backend TCP socket '%s': %s", ds->getNameWithAddr(), stringerror());
- }
- }
-#endif
- result->bind(ds->sourceAddr, false);
- }
- result->setNonBlocking();
-#ifdef MSG_FASTOPEN
- if (!ds->tcpFastOpen) {
- SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
- }
-#else
- SConnectWithTimeout(result->getHandle(), ds->remote, /* no timeout, we will handle it ourselves */ 0);
-#endif /* MSG_FASTOPEN */
- return result;
- }
- catch(const std::runtime_error& e) {
- vinfolog("Connection to downstream server %s failed: %s", ds->getName(), e.what());
- downstreamFailures++;
- if (downstreamFailures > ds->retries) {
- throw;
+ static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::unique_ptr<FDMultiplexer>& mplexer, std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
+ {
+ std::unique_ptr<TCPConnectionToBackend> result;
+
+    auto it = t_downstreamConnections.find(ds->remote);
+ if (it != t_downstreamConnections.end()) {
+ auto& list = it->second;
+ if (!list.empty()) {
+ result = std::move(list.front());
+ list.pop_front();
+ result->setReused();
+ return result;
}
}
- } while(downstreamFailures <= ds->retries);
- return nullptr;
-}
-
-class TCPConnectionToBackend
-{
-public:
- TCPConnectionToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now): d_ds(ds), d_connectionStartTime(now), d_enableFastOpen(ds->tcpFastOpen)
- {
- d_socket = setupTCPDownstream(d_ds, downstreamFailures);
- ++d_ds->tcpCurrentConnections;
+ return make_unique<TCPConnectionToBackend>(ds, now);
}
- ~TCPConnectionToBackend()
+ static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
{
- if (d_ds && d_socket) {
- --d_ds->tcpCurrentConnections;
- struct timeval now;
- gettimeofday(&now, nullptr);
-
- auto diff = now - d_connectionStartTime;
- d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000);
+ if (conn == nullptr) {
+ return;
}
- }
- int getHandle() const
- {
- if (!d_socket) {
- throw std::runtime_error("Attempt to get the socket handle from a non-established TCP connection");
+ if (!conn->canBeReused()) {
+ conn.reset();
+ return;
}
- return d_socket->getHandle();
- }
-
- const ComboAddress& getRemote() const
- {
- return d_ds->remote;
- }
-
- bool isFresh() const
- {
- return d_fresh;
- }
-
- void incQueries()
- {
- ++d_queries;
- }
-
- void setReused()
- {
- d_fresh = false;
- }
-
- void disableFastOpen()
- {
- d_enableFastOpen = false;
- }
-
- bool isFastOpenEnabled()
- {
- return d_enableFastOpen;
- }
-
- bool canBeReused() const
- {
- /* we can't reuse a connection where a proxy protocol payload has been sent,
- since:
- - it cannot be reused for a different client
- - we might have different TLV values for each query
- */
- if (d_ds && d_ds->useProxyProtocol) {
- return false;
+ const auto& remote = conn->getRemote();
+    auto it = t_downstreamConnections.find(remote);
+ if (it != t_downstreamConnections.end()) {
+ auto& list = it->second;
+ if (list.size() >= s_maxCachedConnectionsPerDownstream) {
+ /* too many connections queued already */
+ conn.reset();
+ return;
+ }
+ list.push_back(std::move(conn));
+ }
+ else {
+ t_downstreamConnections[remote].push_back(std::move(conn));
}
- return true;
}
- bool matches(const std::shared_ptr<DownstreamState>& ds) const
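+  /* walk the per-thread connection cache, dropping entries whose socket is no
+     longer usable; this is meant to run periodically (see
+     g_downstreamTCPCleanupInterval) rather than on the hot path */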
+ static void cleanupClosedTCPConnections()
{
- if (!ds || !d_ds) {
- return false;
+ for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
+ for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
+ if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
+ ++connIt;
+ }
+ else {
+ connIt = dsIt->second.erase(connIt);
+ }
+ }
+
+ if (!dsIt->second.empty()) {
+ ++dsIt;
+ }
+ else {
+ dsIt = t_downstreamConnections.erase(dsIt);
+ }
}
- return ds == d_ds;
}
private:
- std::unique_ptr<Socket> d_socket{nullptr};
- std::shared_ptr<DownstreamState> d_ds{nullptr};
- struct timeval d_connectionStartTime;
- uint64_t d_queries{0};
- bool d_fresh{true};
- bool d_enableFastOpen{false};
+ static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;
+ static const size_t s_maxCachedConnectionsPerDownstream;
};
-static thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> t_downstreamConnections;
+thread_local map<ComboAddress, std::deque<std::unique_ptr<TCPConnectionToBackend>>> DownstreamConnectionsManager::t_downstreamConnections;
+const size_t DownstreamConnectionsManager::s_maxCachedConnectionsPerDownstream{20};
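+
+/* An illustrative sketch (not a verbatim call site) of how the pool above is meant
+   to be used by a worker thread:
+
+     auto conn = DownstreamConnectionsManager::getConnectionToDownstream(mplexer, ds, now);
+     // ... forward the query and collect the response ...
+     DownstreamConnectionsManager::releaseDownstreamConnection(std::move(conn));
+
+   releaseDownstreamConnection() silently drops connections that cannot be reused and
+   caps the cache at s_maxCachedConnectionsPerDownstream entries per backend. */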
-static std::unique_ptr<TCPConnectionToBackend> getConnectionToDownstream(std::shared_ptr<DownstreamState>& ds, uint16_t& downstreamFailures, const struct timeval& now)
+static void decrementTCPClientCount(const ComboAddress& client)
{
- std::unique_ptr<TCPConnectionToBackend> result;
-
- const auto& it = t_downstreamConnections.find(ds->remote);
- if (it != t_downstreamConnections.end()) {
- auto& list = it->second;
- if (!list.empty()) {
- result = std::move(list.front());
- list.pop_front();
- result->setReused();
- return result;
+  if (g_maxTCPConnectionsPerClient) {
+    std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
+    auto it = tcpClientsCount.find(client);
+    if (it != tcpClientsCount.end() && --it->second == 0) {
+      tcpClientsCount.erase(it);
}
}
-
- return std::unique_ptr<TCPConnectionToBackend>(new TCPConnectionToBackend(ds, downstreamFailures, now));
}
-static void releaseDownstreamConnection(std::unique_ptr<TCPConnectionToBackend>&& conn)
+IncomingTCPConnectionState::~IncomingTCPConnectionState()
{
- if (conn == nullptr) {
- return;
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+ decrementTCPClientCount(d_ci.remote);
+ // DEBUG: cerr<<"decremented"<<endl;
+
+ if (d_ci.cs != nullptr) {
+ struct timeval now;
+ gettimeofday(&now, nullptr);
+
+ auto diff = now - d_connectionStartTime;
+ // DEBUG: cerr<<"updating tcp metrics"<<endl;
+ d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
+ // DEBUG: cerr<<"updated tcp metrics"<<endl;
}
- if (!conn->canBeReused()) {
- conn.reset();
- return;
+#if 0
+ if (d_ds != nullptr) {
+
+ if (d_downstreamConnection) {
+ try {
+ if (d_lastIOState == IOState::NeedRead) {
+ // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
+ d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
+ }
+ else if (d_lastIOState == IOState::NeedWrite) {
+ // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
+ d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
+ }
+ }
+ catch(const FDMultiplexerException& e) {
+ vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
+ }
+ catch(const std::runtime_error& e) {
+ /* might be thrown by getHandle() */
+ vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
+ }
+ }
}
+#endif
- const auto& remote = conn->getRemote();
- const auto& it = t_downstreamConnections.find(remote);
- if (it != t_downstreamConnections.end()) {
- auto& list = it->second;
- if (list.size() >= g_maxCachedConnectionsPerDownstream) {
- /* too many connections queued already */
- conn.reset();
- return;
+ // DEBUG: cerr<<"about to remove left over FDs"<<endl;
+ try {
+ if (d_lastIOState == IOState::NeedRead) {
+ // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover client read FD "<<d_ci.fd<<endl;
+ d_threadData.mplexer->removeReadFD(d_ci.fd);
+ }
+ else if (d_lastIOState == IOState::NeedWrite) {
+ // DEBUG: cerr<<__PRETTY_FUNCTION__<<": removing leftover client write FD "<<d_ci.fd<<endl;
+ d_threadData.mplexer->removeWriteFD(d_ci.fd);
}
- list.push_back(std::move(conn));
}
- else {
- t_downstreamConnections[remote].push_back(std::move(conn));
+ catch (const FDMultiplexerException& e) {
+ vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+ }
+ catch (...) {
+ vinfolog("Got an unknown exception when trying to remove a pending IO operation on an incoming TCP connection from %s", d_ci.remote.toStringWithPort());
}
+ // DEBUG: cerr<<"done"<<endl;
}
-struct ConnectionInfo
+std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const struct timeval& now)
{
- ConnectionInfo(ClientState* cs_): cs(cs_), fd(-1)
- {
- }
- ConnectionInfo(ConnectionInfo&& rhs): remote(rhs.remote), cs(rhs.cs), fd(rhs.fd)
- {
- rhs.cs = nullptr;
- rhs.fd = -1;
- }
+ std::shared_ptr<TCPConnectionToBackend> downstream{nullptr};
- ConnectionInfo(const ConnectionInfo& rhs) = delete;
- ConnectionInfo& operator=(const ConnectionInfo& rhs) = delete;
-
- ConnectionInfo& operator=(ConnectionInfo&& rhs)
- {
- remote = rhs.remote;
- cs = rhs.cs;
- rhs.cs = nullptr;
- fd = rhs.fd;
- rhs.fd = -1;
- return *this;
+ if (!ds->useProxyProtocol || !d_proxyProtocolPayloadHasTLV) {
+ downstream = getActiveDownstreamConnection(ds);
}
- ~ConnectionInfo()
- {
- if (fd != -1) {
- close(fd);
- fd = -1;
- }
- if (cs) {
- --cs->tcpCurrentConnections;
- }
+ if (!downstream) {
+    /* we don't have an active connection to this backend yet, so let's ask for one (it might not be a fresh one, though) */
+ downstream = DownstreamConnectionsManager::getConnectionToDownstream(d_threadData.mplexer, ds, now);
}
- ComboAddress remote;
- ClientState* cs{nullptr};
- int fd{-1};
-};
-
-void tcpClientThread(int pipefd);
-
-static void decrementTCPClientCount(const ComboAddress& client)
-{
- if (g_maxTCPConnectionsPerClient) {
- std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
- tcpClientsCount[client]--;
- if (tcpClientsCount[client] == 0) {
- tcpClientsCount.erase(client);
- }
- }
+ return downstream;
}
+static void tcpClientThread(int pipefd);
+
void TCPClientCollection::addTCPClientThread()
{
int pipefds[2] = { -1, -1};
}
try {
- thread t1(tcpClientThread, pipefds[0]);
+ std::thread t1(tcpClientThread, pipefds[0]);
t1.detach();
}
catch(const std::runtime_error& e) {
}
}
-static void cleanupClosedTCPConnections()
-{
- for(auto dsIt = t_downstreamConnections.begin(); dsIt != t_downstreamConnections.end(); ) {
- for (auto connIt = dsIt->second.begin(); connIt != dsIt->second.end(); ) {
- if (*connIt && isTCPSocketUsable((*connIt)->getHandle())) {
- ++connIt;
- }
- else {
- connIt = dsIt->second.erase(connIt);
- }
- }
-
- if (!dsIt->second.empty()) {
- ++dsIt;
- }
- else {
- dsIt = t_downstreamConnections.erase(dsIt);
- }
- }
-}
-
/* Tries to read exactly toRead bytes into the buffer, starting at position pos.
   Updates pos every time a successful read occurs,
   throws a std::runtime_error in case of IO error,
   returns Done when toRead bytes have been read, NeedRead or NeedWrite if the IO operation
   would block.
*/
// XXX could probably be implemented as a TCPIOHandler
-static IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
+IOState tryRead(int fd, std::vector<uint8_t>& buffer, size_t& pos, size_t toRead)
{
if (buffer.size() < (pos + toRead)) {
throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead) + " bytes starting at " + std::to_string(pos));
std::unique_ptr<TCPClientCollection> g_tcpclientthreads;
-class TCPClientThreadData
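+/* called once a response has been fully written to the client: updates the rings and
+   counters, enforces the per-connection limits, then decides whether to read the next
+   query, idle until the XFR completes, or send the next queued response */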
+static IOState handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
{
-public:
- TCPClientThreadData(): localRespRulactions(g_resprulactions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
- {
- }
-
- LocalHolders holders;
- LocalStateHolder<vector<DNSDistResponseRuleAction> > localRespRulactions;
- std::unique_ptr<FDMultiplexer> mplexer{nullptr};
-};
-
-static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param);
-
-class IncomingTCPConnectionState
-{
-public:
- IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_responseBuffer(s_maxPacketCacheEntrySize), d_threadData(threadData), d_ci(std::move(ci)), d_handler(d_ci.fd, g_tcpRecvTimeout, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now)
- {
- d_ids.origDest.reset();
- d_ids.origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
- socklen_t socklen = d_ids.origDest.getSocklen();
- if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_ids.origDest), &socklen)) {
- d_ids.origDest = d_ci.cs->local;
- }
- }
-
- IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
- IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
-
- ~IncomingTCPConnectionState()
- {
- decrementTCPClientCount(d_ci.remote);
- if (d_ci.cs != nullptr) {
- struct timeval now;
- gettimeofday(&now, nullptr);
-
- auto diff = now - d_connectionStartTime;
- d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
- }
-
- if (d_ds != nullptr) {
- if (d_outstanding) {
- --d_ds->outstanding;
- d_outstanding = false;
- }
-
- if (d_downstreamConnection) {
- try {
- if (d_lastIOState == IOState::NeedRead) {
- cerr<<__func__<<": removing leftover backend read FD "<<d_downstreamConnection->getHandle()<<endl;
- d_threadData.mplexer->removeReadFD(d_downstreamConnection->getHandle());
- }
- else if (d_lastIOState == IOState::NeedWrite) {
- cerr<<__func__<<": removing leftover backend write FD "<<d_downstreamConnection->getHandle()<<endl;
- d_threadData.mplexer->removeWriteFD(d_downstreamConnection->getHandle());
- }
- }
- catch(const FDMultiplexerException& e) {
- vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
- }
- catch(const std::runtime_error& e) {
- /* might be thrown by getHandle() */
- vinfolog("Got an exception when trying to remove a pending IO operation on the socket to the %s backend: %s", d_ds->getName(), e.what());
- }
- }
- }
-
- try {
- if (d_lastIOState == IOState::NeedRead) {
- cerr<<__func__<<": removing leftover client read FD "<<d_ci.fd<<endl;
- d_threadData.mplexer->removeReadFD(d_ci.fd);
- }
- else if (d_lastIOState == IOState::NeedWrite) {
- cerr<<__func__<<": removing leftover client write FD "<<d_ci.fd<<endl;
- d_threadData.mplexer->removeWriteFD(d_ci.fd);
- }
- }
- catch(const FDMultiplexerException& e) {
- vinfolog("Got an exception when trying to remove a pending IO operation on an incoming TCP connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
- }
- }
-
- void resetForNewQuery()
- {
- d_buffer.resize(sizeof(uint16_t));
- d_currentPos = 0;
- d_querySize = 0;
- d_responseSize = 0;
- d_downstreamFailures = 0;
- d_state = State::readingQuerySize;
- d_lastIOState = IOState::Done;
- d_selfGeneratedResponse = false;
- }
-
- boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
- {
- if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
- return boost::none;
- }
-
- if (g_maxTCPConnectionDuration > 0) {
- auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
- if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
- return now;
- }
- auto remaining = g_maxTCPConnectionDuration - elapsed;
- if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
- now.tv_sec += remaining;
- return now;
- }
- }
-
- now.tv_sec += g_tcpRecvTimeout;
- return now;
- }
-
- boost::optional<struct timeval> getBackendReadTTD(const struct timeval& now) const
- {
- if (d_ds == nullptr) {
- throw std::runtime_error("getBackendReadTTD() without any backend selected");
- }
- if (d_ds->tcpRecvTimeout == 0) {
- return boost::none;
- }
-
- struct timeval res = now;
- res.tv_sec += d_ds->tcpRecvTimeout;
-
- return res;
- }
-
- boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
- {
- if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
- return boost::none;
+ if (!state->d_isXFR) {
+ const auto& currentResponse = state->d_currentResponse;
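+    /* if we have no downstream server selected, this was a self-answered response,
+       but cache hits have a selected server as well, so be careful */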
+ if (state->d_selfGeneratedResponse == false && currentResponse.d_ds) {
+ struct timespec answertime;
+ gettime(&answertime);
+ const auto& ids = currentResponse.d_idstate;
+ double udiff = ids.sentTime.udiff();
+ g_rings.insertResponse(answertime, state->d_ci.remote, ids.qname, ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, currentResponse.d_ds->remote);
+ vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", currentResponse.d_ds->remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff);
+ }
+
+ switch (currentResponse.d_cleartextDH.rcode) {
+ case RCode::NXDomain:
+ ++g_stats.frontendNXDomain;
+ break;
+ case RCode::ServFail:
+ ++g_stats.servfailResponses;
+ ++g_stats.frontendServFail;
+ break;
+ case RCode::NoError:
+ ++g_stats.frontendNoError;
+ break;
}
- struct timeval res = now;
-
- if (g_maxTCPConnectionDuration > 0) {
- auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
- if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
- return res;
- }
- auto remaining = g_maxTCPConnectionDuration - elapsed;
- if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
- res.tv_sec += remaining;
- return res;
- }
+ if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
+ return IOState::Done;
}
- res.tv_sec += g_tcpSendTimeout;
- return res;
- }
-
- boost::optional<struct timeval> getBackendWriteTTD(const struct timeval& now) const
- {
- if (d_ds == nullptr) {
- throw std::runtime_error("getBackendReadTTD() called without any backend selected");
- }
- if (d_ds->tcpSendTimeout == 0) {
- return boost::none;
+ if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+ return IOState::Done;
}
-
- struct timeval res = now;
- res.tv_sec += d_ds->tcpSendTimeout;
-
- return res;
}
- bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
- {
- if (maxConnectionDuration) {
- time_t curtime = now.tv_sec;
- unsigned int elapsed = 0;
- if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
- elapsed = curtime - d_connectionStartTime.tv_sec;
- }
- if (elapsed >= maxConnectionDuration) {
- return true;
- }
- d_remainingTime = maxConnectionDuration - elapsed;
+ if (state->d_queuedResponses.empty()) {
+ // DEBUG: cerr<<"no response remaining"<<endl;
+ if (state->d_isXFR) {
+ /* we should still be reading from the backend, and we don't want to read from the client */
+ state->d_state = IncomingTCPConnectionState::State::idle;
+ state->d_currentPos = 0;
+ // DEBUG: cerr<<"idling for XFR completion"<<endl;
+ return IOState::Done;
+ } else {
+ // DEBUG: cerr<<"reading new queries if any"<<endl;
+ state->resetForNewQuery();
+ return IOState::NeedRead;
}
-
- return false;
}
-
- void dump() const
- {
- static std::mutex s_mutex;
-
- struct timeval now;
- gettimeofday(&now, 0);
-
- {
- std::lock_guard<std::mutex> lock(s_mutex);
- fprintf(stderr, "State is %p\n", this);
- cerr << "Current state is " << static_cast<int>(d_state) << ", got "<<d_queriesCount<<" queries so far" << endl;
- cerr << "Current time is " << now.tv_sec << " - " << now.tv_usec << endl;
- cerr << "Connection started at " << d_connectionStartTime.tv_sec << " - " << d_connectionStartTime.tv_usec << endl;
- if (d_state > State::doingHandshake) {
- cerr << "Handshake done at " << d_handshakeDoneTime.tv_sec << " - " << d_handshakeDoneTime.tv_usec << endl;
- }
- if (d_state > State::readingQuerySize) {
- cerr << "Got first query size at " << d_firstQuerySizeReadTime.tv_sec << " - " << d_firstQuerySizeReadTime.tv_usec << endl;
- }
- if (d_state > State::readingQuerySize) {
- cerr << "Got query size at " << d_querySizeReadTime.tv_sec << " - " << d_querySizeReadTime.tv_usec << endl;
- }
- if (d_state > State::readingQuery) {
- cerr << "Got query at " << d_queryReadTime.tv_sec << " - " << d_queryReadTime.tv_usec << endl;
- }
- if (d_state > State::sendingQueryToBackend) {
- cerr << "Sent query at " << d_querySentTime.tv_sec << " - " << d_querySentTime.tv_usec << endl;
- }
- if (d_state > State::readingResponseFromBackend) {
- cerr << "Got response at " << d_responseReadTime.tv_sec << " - " << d_responseReadTime.tv_usec << endl;
- }
- }
+ else {
+ // DEBUG: cerr<<"queue size is "<<state->d_queuedResponses.size()<<endl;
+ TCPResponse resp = std::move(state->d_queuedResponses.front());
+ state->d_queuedResponses.pop_front();
+ state->d_state = IncomingTCPConnectionState::State::idle;
+ state->sendResponse(state, now, std::move(resp));
+ return IOState::NeedWrite;
}
+}
- enum class State { doingHandshake, readingQuerySize, readingQuery, sendingQueryToBackend, readingResponseSizeFromBackend, readingResponseFromBackend, sendingResponse };
-
- std::vector<uint8_t> d_buffer;
- std::vector<uint8_t> d_responseBuffer;
- TCPClientThreadData& d_threadData;
- IDState d_ids;
- ConnectionInfo d_ci;
- TCPIOHandler d_handler;
- std::unique_ptr<TCPConnectionToBackend> d_downstreamConnection{nullptr};
- std::shared_ptr<DownstreamState> d_ds{nullptr};
- dnsheader d_cleartextDH;
- struct timeval d_connectionStartTime;
- struct timeval d_handshakeDoneTime;
- struct timeval d_firstQuerySizeReadTime;
- struct timeval d_querySizeReadTime;
- struct timeval d_queryReadTime;
- struct timeval d_querySentTime;
- struct timeval d_responseReadTime;
- size_t d_currentPos{0};
- size_t d_queriesCount{0};
- unsigned int d_remainingTime{0};
- uint16_t d_querySize{0};
- uint16_t d_responseSize{0};
- uint16_t d_downstreamFailures{0};
- State d_state{State::doingHandshake};
- IOState d_lastIOState{IOState::Done};
- bool d_readingFirstQuery{true};
- bool d_outstanding{false};
- bool d_firstResponsePacket{true};
- bool d_isXFR{false};
- bool d_xfrStarted{false};
- bool d_selfGeneratedResponse{false};
- bool d_proxyProtocolPayloadAdded{false};
- bool d_proxyProtocolPayloadHasTLV{false};
-};
-
-static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
-static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd=boost::none);
-static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
-static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now);
-
-static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+void IncomingTCPConnectionState::resetForNewQuery()
{
- handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback);
+ d_buffer.resize(sizeof(uint16_t));
+ d_currentPos = 0;
+ d_querySize = 0;
+ //d_responseSize = 0;
+ d_downstreamFailures = 0;
+ d_state = State::readingQuerySize;
+ d_lastIOState = IOState::Done;
+ d_selfGeneratedResponse = false;
+}
- if (state->d_isXFR && state->d_downstreamConnection) {
- /* we need to resume reading from the backend! */
- state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
+/* this version is called when the buffer has been set and the rules have been processed */
+void IncomingTCPConnectionState::sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
+{
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+  // if we are already reading a query (not just the query size) or sending a response,
+  // we need to queue this response; otherwise we can start sending it right away
+ if (state->d_state == IncomingTCPConnectionState::State::idle ||
+ state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
+
+ state->d_state = IncomingTCPConnectionState::State::sendingResponse;
+
+ uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size());
+ const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) };
+ /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
+ that could occur if we had to deal with the size during the processing,
+ especially alignment issues */
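+    /* keeping the length in the same buffer also means the write path can push the
+       size bytes and the payload out together */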
+ response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2);
state->d_currentPos = 0;
- handleDownstreamIO(state, now);
- return;
- }
-
- if (state->d_selfGeneratedResponse == false && state->d_ds) {
- /* if we have no downstream server selected, this was a self-answered response
- but cache hits have a selected server as well, so be careful */
- struct timespec answertime;
- gettime(&answertime);
- double udiff = state->d_ids.sentTime.udiff();
- g_rings.insertResponse(answertime, state->d_ci.remote, state->d_ids.qname, state->d_ids.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(state->d_responseBuffer.size()), state->d_cleartextDH, state->d_ds->remote);
- vinfolog("Got answer from %s, relayed to %s (%s), took %f usec", state->d_ds->remote.toStringWithPort(), state->d_ids.origRemote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), udiff);
- }
+ state->d_currentResponse = std::move(response);
- switch (state->d_cleartextDH.rcode) {
- case RCode::NXDomain:
- ++g_stats.frontendNXDomain;
- break;
- case RCode::ServFail:
- ++g_stats.servfailResponses;
- ++g_stats.frontendServFail;
- break;
- case RCode::NoError:
- ++g_stats.frontendNoError;
- break;
- }
-
- if (g_maxTCPQueriesPerConn && state->d_queriesCount > g_maxTCPQueriesPerConn) {
- vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", state->d_ci.remote.toStringWithPort(), state->d_queriesCount, g_maxTCPQueriesPerConn);
- return;
+ //IncomingTCPConnectionState::handleIO(state, now);
+ state->d_ioState->update(IOState::NeedWrite, handleIOCallback, state, getClientWriteTTD(now));
+ // DEBUG: cerr<<"updated IO state"<<endl;
}
-
- if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
- vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
- return;
+ else {
+ // queue response
+ state->d_queuedResponses.push_back(std::move(response));
+ // DEBUG: cerr<<"queueing response because state is "<<(int)state->d_state<<", queue size is now "<<state->d_queuedResponses.size()<<endl;
}
-
- state->resetForNewQuery();
-
- handleIO(state, now);
-}
-
-static void sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
-{
- state->d_state = IncomingTCPConnectionState::State::sendingResponse;
- const uint8_t sizeBytes[] = { static_cast<uint8_t>(state->d_responseSize / 256), static_cast<uint8_t>(state->d_responseSize % 256) };
- /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
- that could occur if we had to deal with the size during the processing,
- especially alignment issues */
- state->d_responseBuffer.insert(state->d_responseBuffer.begin(), sizeBytes, sizeBytes + 2);
-
- state->d_currentPos = 0;
-
- handleIO(state, now);
}
-static void handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+/* this version is called from the backend code when a new response has been received */
+void IncomingTCPConnectionState::handleResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
{
- if (state->d_responseSize < sizeof(dnsheader) || !state->d_ds) {
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+ if (response.d_buffer.size() < sizeof(dnsheader)) {
+ // DEBUG: cerr<<"too small"<<endl;
return;
}
- auto response = reinterpret_cast<char*>(&state->d_responseBuffer.at(0));
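+  /* grow the buffer a bit so the response rules can expand the packet in place
+     without an immediate reallocation */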
+  uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size());
+ response.d_buffer.resize(responseSize + static_cast<size_t>(512));
+ size_t responseCapacity = response.d_buffer.size();
+ auto responseAsCharArray = reinterpret_cast<char*>(&response.d_buffer.at(0));
+
+ auto& ids = response.d_idstate;
+ // DEBUG: cerr<<"IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<<endl;
unsigned int consumed;
- if (state->d_firstResponsePacket && !responseContentMatches(response, state->d_responseSize, state->d_ids.qname, state->d_ids.qtype, state->d_ids.qclass, state->d_ds->remote, consumed)) {
+ // DEBUG: cerr<<"about to match response for "<<ids.qname<<endl;
+ if (!responseContentMatches(responseAsCharArray, responseSize, ids.qname, ids.qtype, ids.qclass, response.d_ds->remote, consumed)) {
+ // DEBUG: cerr<<"content does not match"<<endl;
return;
}
- state->d_firstResponsePacket = false;
-
- if (state->d_outstanding) {
- --state->d_ds->outstanding;
- state->d_outstanding = false;
- }
- auto dh = reinterpret_cast<struct dnsheader*>(response);
+ auto dh = reinterpret_cast<struct dnsheader*>(responseAsCharArray);
uint16_t addRoom = 0;
- DNSResponse dr = makeDNSResponseFromIDState(state->d_ids, dh, state->d_responseBuffer.size(), state->d_responseSize, true);
+ DNSResponse dr = makeDNSResponseFromIDState(ids, dh, responseCapacity, responseSize, true);
if (dr.dnsCryptQuery) {
addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
}
- memcpy(&state->d_cleartextDH, dr.dh, sizeof(state->d_cleartextDH));
+ memcpy(&response.d_cleartextDH, dr.dh, sizeof(response.d_cleartextDH));
std::vector<uint8_t> rewrittenResponse;
- size_t responseSize = state->d_responseBuffer.size();
- if (!processResponse(&response, &state->d_responseSize, &responseSize, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
+ if (!processResponse(&responseAsCharArray, &responseSize, &responseCapacity, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) {
+ // DEBUG: cerr<<"process said to drop it"<<endl;
return;
}
if (!rewrittenResponse.empty()) {
/* responseSize has been updated as well but we don't really care since it will match
the capacity of rewrittenResponse anyway */
- state->d_responseBuffer = std::move(rewrittenResponse);
- state->d_responseSize = state->d_responseBuffer.size();
+ response.d_buffer = std::move(rewrittenResponse);
} else {
    /* the size might have been updated (shrunk if we removed the whole OPT RR, for example) */
- state->d_responseBuffer.resize(state->d_responseSize);
+ response.d_buffer.resize(responseSize);
}
if (state->d_isXFR && !state->d_xfrStarted) {
state->d_xfrStarted = true;
++g_stats.responses;
++state->d_ci.cs->responses;
- ++state->d_ds->responses;
+ ++response.d_ds->responses;
}
if (!state->d_isXFR) {
++g_stats.responses;
++state->d_ci.cs->responses;
- ++state->d_ds->responses;
+ ++response.d_ds->responses;
}
- sendResponse(state, now);
+ sendResponse(state, now, std::move(response));
}
-static void sendQueryToBackend(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
-{
- auto ds = state->d_ds;
- state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend;
- state->d_currentPos = 0;
- state->d_firstResponsePacket = true;
-
- if (state->d_xfrStarted) {
- /* sorry, but we are not going to resume a XFR if we have already sent some packets
- to the client */
- return;
- }
-
- if (!state->d_downstreamConnection) {
- if (state->d_downstreamFailures < state->d_ds->retries) {
- try {
- state->d_downstreamConnection = getConnectionToDownstream(ds, state->d_downstreamFailures, now);
- }
- catch (const std::runtime_error& e) {
- state->d_downstreamConnection.reset();
- }
- }
-
- if (!state->d_downstreamConnection) {
- ++ds->tcpGaveUp;
- ++state->d_ci.cs->tcpGaveUp;
- vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), state->d_downstreamFailures);
- return;
- }
-
- if (ds->useProxyProtocol && !state->d_proxyProtocolPayloadAdded) {
- /* we know there is no TLV values to add, otherwise we would not have tried
- to reuse the connection and d_proxyProtocolPayloadAdded would be true already */
- addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, std::vector<ProxyProtocolValue>());
- state->d_proxyProtocolPayloadAdded = true;
- }
- }
-
- vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", state->d_ids.qname.toLogString(), QType(state->d_ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName());
-
- handleDownstreamIO(state, now);
- return;
-}
-
-static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
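+/* returns false when a write of the response to the client has already been initiated
+   (the IO state is then waiting for the socket to become writable), and true when the
+   caller may keep driving the connection, reading the next query or flushing queued
+   responses */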
+static bool handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
{
if (state->d_querySize < sizeof(dnsheader)) {
++g_stats.nonCompliantQueries;
- return;
+ return true;
}
state->d_readingFirstQuery = false;
std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true);
if (dnsCryptResponse) {
- state->d_responseBuffer = std::move(*dnsCryptResponse);
- state->d_responseSize = state->d_responseBuffer.size();
- sendResponse(state, now);
- return;
+ //state->d_responseBuffer = std::move(*dnsCryptResponse);
+ //state->d_responseSize = state->d_responseBuffer.size();
+ TCPResponse response;
+ response.d_buffer = std::move(*dnsCryptResponse);
+ state->d_state = IncomingTCPConnectionState::State::idle;
+ state->sendResponse(state, now, std::move(response));
+ return false;
}
const auto& dh = reinterpret_cast<dnsheader*>(query);
if (!checkQueryHeaders(dh)) {
- return;
+ return true;
}
uint16_t qtype, qclass;
unsigned int consumed = 0;
DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
- DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_ids.origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
+ DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_origDest, &state->d_ci.remote, reinterpret_cast<dnsheader*>(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime);
dq.dnsCryptQuery = std::move(dnsCryptQuery);
dq.sni = state->d_handler.getServerNameIndication();
dq.skipCache = true;
}
- state->d_ds.reset();
- auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, state->d_ds);
+ std::shared_ptr<DownstreamState> ds;
+ auto result = processQuery(dq, *state->d_ci.cs, state->d_threadData.holders, ds);
if (result == ProcessQueryResult::Drop) {
- return;
+ return true;
}
if (result == ProcessQueryResult::SendAnswer) {
state->d_selfGeneratedResponse = true;
state->d_buffer.resize(dq.len);
- state->d_responseBuffer = std::move(state->d_buffer);
- state->d_responseSize = state->d_responseBuffer.size();
- sendResponse(state, now);
- return;
+ TCPResponse response;
+ response.d_buffer = std::move(state->d_buffer);
+ state->d_state = IncomingTCPConnectionState::State::idle;
+ state->sendResponse(state, now, std::move(response));
+ return false;
}
- if (result != ProcessQueryResult::PassToBackend || state->d_ds == nullptr) {
- return;
+ if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
+ return true;
+ }
+
+#warning move this, we should simply never read another query on this client connection
+ if (state->d_xfrStarted) {
+    /* sorry, but we are not going to resume an XFR if we have already sent some packets
+       to the client */
+ return true;
}
- setIDStateFromDNSQuestion(state->d_ids, dq, std::move(qname));
+ IDState ids;
+ // DEBUG: cerr<<"DQ has "<<(dq.qTag?" TAGS ": "NO TAGS")<<endl;
+ setIDStateFromDNSQuestion(ids, dq, std::move(qname));
+ // DEBUG: cerr<<"query IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<<endl;
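+  /* keep a copy of the client's query ID so it can be restored in the response */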
+ ids.origID = ntohs(dh->id);
const uint8_t sizeBytes[] = { static_cast<uint8_t>(dq.len / 256), static_cast<uint8_t>(dq.len % 256) };
/* prepend the size. Yes, this is not the most efficient way but it prevents mistakes
dq.size = state->d_buffer.size();
state->d_buffer.resize(dq.len);
- if (state->d_ds->useProxyProtocol) {
+ bool proxyProtocolPayloadAdded = false;
+ std::string proxyProtocolPayload;
+
+ if (ds->useProxyProtocol) {
/* if we ever sent a TLV over a connection, we can never go back */
if (!state->d_proxyProtocolPayloadHasTLV) {
state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
}
- if (state->d_downstreamConnection && !state->d_proxyProtocolPayloadHasTLV && state->d_downstreamConnection->matches(state->d_ds)) {
- /* we have an existing connection, on which we already sent a Proxy Protocol header with no values
- (in the previous query had TLV values we would have reset the connection afterwards),
- so let's reuse it as long as we still don't have any values */
- state->d_proxyProtocolPayloadAdded = false;
- }
- else {
- state->d_downstreamConnection.reset();
- addProxyProtocol(state->d_buffer, true, state->d_ci.remote, state->d_ids.origDest, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector<ProxyProtocolValue>());
- state->d_proxyProtocolPayloadAdded = true;
+ proxyProtocolPayload = getProxyProtocolPayload(dq);
+
+ if (state->d_proxyProtocolPayloadHasTLV) {
+ /* we will not be able to reuse an existing connection anyway so let's add the payload right now */
+ addProxyProtocol(state->d_buffer, proxyProtocolPayload);
+ proxyProtocolPayloadAdded = true;
}
}
- sendQueryToBackend(state, now);
-}
-
-static void handleNewIOState(std::shared_ptr<IncomingTCPConnectionState>& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional<struct timeval> ttd)
-{
- //cerr<<"in "<<__func__<<" for fd "<<fd<<", last state was "<<(int)state->d_lastIOState<<", new state is "<<(int)iostate<<endl;
+ auto downstreamConnection = state->getDownstreamConnection(ds, now);
+ downstreamConnection->assignToClientConnection(state, state->d_isXFR);
- if (state->d_lastIOState == IOState::NeedRead && iostate != IOState::NeedRead) {
- state->d_threadData.mplexer->removeReadFD(fd);
- //cerr<<__func__<<": remove read FD "<<fd<<endl;
- state->d_lastIOState = IOState::Done;
+ if (proxyProtocolPayloadAdded) {
+ downstreamConnection->setProxyProtocolPayloadAdded(true);
}
- else if (state->d_lastIOState == IOState::NeedWrite && iostate != IOState::NeedWrite) {
- state->d_threadData.mplexer->removeWriteFD(fd);
- //cerr<<__func__<<": remove write FD "<<fd<<endl;
- state->d_lastIOState = IOState::Done;
+ else {
+ downstreamConnection->setProxyProtocolPayload(std::move(proxyProtocolPayload));
}
- if (iostate == IOState::NeedRead) {
- if (state->d_lastIOState == IOState::NeedRead) {
- if (ttd) {
- /* let's update the TTD ! */
- state->d_threadData.mplexer->setReadTTD(fd, *ttd, /* we pass 0 here because we already have a TTD */0);
- }
- return;
- }
+ vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).getName(), state->d_ci.remote.toStringWithPort(), (state->d_ci.cs->tlsFrontend ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName());
- state->d_lastIOState = IOState::NeedRead;
- //cerr<<__func__<<": add read FD "<<fd<<endl;
- state->d_threadData.mplexer->addReadFD(fd, callback, state, ttd ? &*ttd : nullptr);
- }
- else if (iostate == IOState::NeedWrite) {
- if (state->d_lastIOState == IOState::NeedWrite) {
- return;
- }
+// DEBUG: cerr<<"about to be queued query IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<<endl;
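+  /* hand the query over to the backend connection; the response will come back
+     asynchronously through IncomingTCPConnectionState::handleResponse() */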
+ downstreamConnection->queueQuery(TCPQuery(std::move(state->d_buffer), std::move(ids)), downstreamConnection);
- state->d_lastIOState = IOState::NeedWrite;
- //cerr<<__func__<<": add write FD "<<fd<<endl;
- state->d_threadData.mplexer->addWriteFD(fd, callback, state, ttd ? &*ttd : nullptr);
- }
- else if (iostate == IOState::Done) {
- state->d_lastIOState = IOState::Done;
- }
+ //sendQueryToBackend(state, now);
+ // DEBUG: cerr<<"out of "<<__PRETTY_FUNCTION__<<endl;
+ return true;
}
-static void handleDownstreamIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
+void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
- if (state->d_downstreamConnection == nullptr) {
- throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
+ auto conn = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
+ if (fd != conn->d_ci.fd) {
+ throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_ci.fd));
}
- int fd = state->d_downstreamConnection->getHandle();
- IOState iostate = IOState::Done;
- bool connectionDied = false;
+ struct timeval now;
+  gettimeofday(&now, nullptr);
+ handleIO(conn, now);
+}
- try {
- if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
- int socketFlags = 0;
-#ifdef MSG_FASTOPEN
- if (state->d_downstreamConnection->isFastOpenEnabled()) {
- socketFlags |= MSG_FASTOPEN;
- }
-#endif /* MSG_FASTOPEN */
-
- size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&state->d_buffer.at(state->d_currentPos)), state->d_buffer.size() - state->d_currentPos, &state->d_ds->remote, &state->d_ds->sourceAddr, state->d_ds->sourceItf, socketFlags);
- if (sent == state->d_buffer.size()) {
- /* request sent ! */
- state->d_downstreamConnection->incQueries();
- state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend;
- state->d_currentPos = 0;
- state->d_querySentTime = now;
- iostate = IOState::NeedRead;
- if (!state->d_isXFR && !state->d_outstanding) {
- /* don't bother with the outstanding count for XFR queries */
- ++state->d_ds->outstanding;
- state->d_outstanding = true;
- }
- }
- else {
- state->d_currentPos += sent;
- iostate = IOState::NeedWrite;
- /* disable fast open on partial write */
- state->d_downstreamConnection->disableFastOpen();
- }
- }
+void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
+{
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+ // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read
+ // even though the underlying socket is not ready, so we need to actually ask for the data first
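+  // (e.g. the TLS library may already hold a fully decrypted record in its own buffer
+  // while the kernel reports nothing readable on the socket)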
+ bool wouldBlock = false;
+ IOState iostate = IOState::Done;
+ do {
+ iostate = IOState::Done;
+ IOStateGuard ioGuard(state->d_ioState);
- if (state->d_state == IncomingTCPConnectionState::State::readingResponseSizeFromBackend) {
- // then we need to allocate a new buffer (new because we might need to re-send the query if the
- // backend dies on us
- // We also might need to read and send to the client more than one response in case of XFR (yeah!)
- // should very likely be a TCPIOHandler d_downstreamHandler
- iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, sizeof(uint16_t) - state->d_currentPos);
- if (iostate == IOState::Done) {
- state->d_state = IncomingTCPConnectionState::State::readingResponseFromBackend;
- state->d_responseSize = state->d_responseBuffer.at(0) * 256 + state->d_responseBuffer.at(1);
- state->d_responseBuffer.resize((state->d_ids.dnsCryptQuery && (UINT16_MAX - state->d_responseSize) > static_cast<uint16_t>(DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE)) ? state->d_responseSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE : state->d_responseSize);
- state->d_currentPos = 0;
- }
+ if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
+ // will be handled by the ioGuard
+ //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+ return;
}
- if (state->d_state == IncomingTCPConnectionState::State::readingResponseFromBackend) {
- iostate = tryRead(fd, state->d_responseBuffer, state->d_currentPos, state->d_responseSize - state->d_currentPos);
- if (iostate == IOState::Done) {
- handleNewIOState(state, IOState::Done, fd, handleDownstreamIOCallback);
-
- if (state->d_isXFR) {
- /* Don't reuse the TCP connection after an {A,I}XFR */
- /* but don't reset it either, we will need to read more messages */
- }
- else {
- /* if we did not send a Proxy Protocol header, let's pool the connection */
- if (state->d_ds && state->d_ds->useProxyProtocol == false) {
- releaseDownstreamConnection(std::move(state->d_downstreamConnection));
- }
- else {
- if (state->d_proxyProtocolPayloadHasTLV) {
- /* sent a Proxy Protocol header with TLV values, we can't reuse it */
- state->d_downstreamConnection.reset();
+ try {
+ if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
+ // DEBUG: cerr<<"doing handshake"<<endl;
+ iostate = state->d_handler.tryHandshake();
+ if (iostate == IOState::Done) {
+ // DEBUG: cerr<<"handshake done"<<endl;
+ if (state->d_handler.isTLS()) {
+ if (!state->d_handler.hasTLSSessionBeenResumed()) {
+ ++state->d_ci.cs->tlsNewSessions;
}
else {
- /* if we did but there was no TLV values, let's try to reuse it but only
- for this incoming connection */
+ ++state->d_ci.cs->tlsResumptions;
+ }
+ if (state->d_handler.getResumedFromInactiveTicketKey()) {
+ ++state->d_ci.cs->tlsInactiveTicketKey;
+ }
+ if (state->d_handler.getUnknownTicketKey()) {
+ ++state->d_ci.cs->tlsUnknownTicketKey;
}
}
- }
- fd = -1;
- state->d_responseReadTime = now;
- try {
- handleResponse(state, now);
+ state->d_handshakeDoneTime = now;
+ state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
}
- catch (const std::exception& e) {
- vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", state->d_ds ? state->d_ds->getName() : "unknown", state->d_ci.remote.toStringWithPort(), e.what());
+ else {
+ wouldBlock = true;
}
- return;
}
- }
- if (state->d_state != IncomingTCPConnectionState::State::sendingQueryToBackend &&
- state->d_state != IncomingTCPConnectionState::State::readingResponseSizeFromBackend &&
- state->d_state != IncomingTCPConnectionState::State::readingResponseFromBackend) {
- vinfolog("Unexpected state %d in handleDownstreamIOCallback", static_cast<int>(state->d_state));
- }
- }
- catch(const std::exception& e) {
- /* most likely an EOF because the other end closed the connection,
- but it might also be a real IO error or something else.
- Let's just drop the connection
- */
- vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading from" : "writing to"), state->d_ci.remote.toStringWithPort(), e.what());
- if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) {
- ++state->d_ds->tcpDiedSendingQuery;
- }
- else {
- ++state->d_ds->tcpDiedReadingResponse;
- }
-
- /* don't increase this counter when reusing connections */
- if (state->d_downstreamConnection && state->d_downstreamConnection->isFresh()) {
- ++state->d_downstreamFailures;
- }
-
- if (state->d_outstanding) {
- state->d_outstanding = false;
+ if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
+ // DEBUG: cerr<<"reading query size"<<endl;
+ iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
+ if (iostate == IOState::Done) {
+ // DEBUG: cerr<<"query size received"<<endl;
+ state->d_state = IncomingTCPConnectionState::State::readingQuery;
+ state->d_querySizeReadTime = now;
+ if (state->d_queriesCount == 0) {
+ state->d_firstQuerySizeReadTime = now;
+ }
+ state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
+ if (state->d_querySize < sizeof(dnsheader)) {
+ /* go away */
+ // will be handled by the guard
+ //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+ return;
+ }
- if (state->d_ds != nullptr) {
- --state->d_ds->outstanding;
+ /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
+ or to add ECS without allocating a new buffer */
+ state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
+ state->d_currentPos = 0;
+ }
+ else {
+ wouldBlock = true;
+ }
}
- }
- /* remove this FD from the IO multiplexer */
- iostate = IOState::Done;
- connectionDied = true;
- }
- if (iostate == IOState::Done) {
- handleNewIOState(state, iostate, fd, handleDownstreamIOCallback);
- }
- else {
- handleNewIOState(state, iostate, fd, handleDownstreamIOCallback, iostate == IOState::NeedRead ? state->getBackendReadTTD(now) : state->getBackendWriteTTD(now));
- }
-
- if (connectionDied) {
- state->d_downstreamConnection.reset();
- sendQueryToBackend(state, now);
- }
-}
-
-static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param)
-{
- auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
- if (state->d_downstreamConnection == nullptr) {
- throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!");
- }
- if (fd != state->d_downstreamConnection->getHandle()) {
- throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamConnection->getHandle()));
- }
-
- struct timeval now;
- gettimeofday(&now, 0);
- handleDownstreamIO(state, now);
-}
-
-static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, struct timeval& now)
-{
- int fd = state->d_ci.fd;
- IOState iostate = IOState::Done;
-
- if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
- vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
- handleNewIOState(state, IOState::Done, fd, handleIOCallback);
- return;
- }
-
- try {
- if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
- iostate = state->d_handler.tryHandshake();
- if (iostate == IOState::Done) {
- if (state->d_handler.isTLS()) {
- if (!state->d_handler.hasTLSSessionBeenResumed()) {
- ++state->d_ci.cs->tlsNewSessions;
+ if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+ // DEBUG: cerr<<"reading query"<<endl;
+ iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
+ if (iostate == IOState::Done) {
+ // DEBUG: cerr<<"query received"<<endl;
+ //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
+ if (handleQuery(state, now)) {
+ // DEBUG: cerr<<"handle query returned true"<<endl;
+ // if the query has been passed to a backend, or dropped, we can start
+ // reading again, or sending queued responses
+ if (state->d_queuedResponses.empty()) {
+ state->resetForNewQuery();
+ // DEBUG: cerr<<__LINE__<<endl;
+ iostate = IOState::NeedRead;
+ //state->d_ioState->update(IOState::NeedRead, handleIOCallback, state, state->getClientReadTTD(now));
+ // DEBUG: cerr<<__LINE__<<endl;
+ //ioGuard.release();
+ }
+ else {
+ TCPResponse resp = std::move(state->d_queuedResponses.front());
+ state->d_queuedResponses.pop_front();
+ ioGuard.release();
+ state->sendResponse(state, now, std::move(resp));
+ return;
+ }
}
else {
- ++state->d_ci.cs->tlsResumptions;
- }
- if (state->d_handler.getResumedFromInactiveTicketKey()) {
- ++state->d_ci.cs->tlsInactiveTicketKey;
- }
- if (state->d_handler.getUnknownTicketKey()) {
- ++state->d_ci.cs->tlsUnknownTicketKey;
+ /* otherwise the state should already be waiting for
+ the socket to be writable */
+ // DEBUG: cerr<<"should be waiting for writable socket"<<endl;
+ ioGuard.release();
+ return;
}
}
-
- state->d_handshakeDoneTime = now;
- state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
+ else {
+ wouldBlock = true;
+ }
}
- }
- if (state->d_state == IncomingTCPConnectionState::State::readingQuerySize) {
- iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
- if (iostate == IOState::Done) {
- state->d_state = IncomingTCPConnectionState::State::readingQuery;
- state->d_querySizeReadTime = now;
- if (state->d_queriesCount == 0) {
- state->d_firstQuerySizeReadTime = now;
- }
- state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
- if (state->d_querySize < sizeof(dnsheader)) {
- /* go away */
- handleNewIOState(state, IOState::Done, fd, handleIOCallback);
- return;
+ if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+ // DEBUG: cerr<<"sending response"<<endl;
+ iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
+ if (iostate == IOState::Done) {
+ // DEBUG: cerr<<"response sent"<<endl;
+ iostate = handleResponseSent(state, now);
+ } else {
+ wouldBlock = true;
+ // DEBUG: cerr<<"partial write"<<endl;
}
-
- /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
- or to add ECS without allocating a new buffer */
- state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
- state->d_currentPos = 0;
+ // DEBUG: cerr<<__LINE__<<endl;
+ //state->d_ioState->update(IOState::NeedRead, handleIOCallback, state, state->getClientReadTTD(now));
+ //// DEBUG: cerr<<__LINE__<<endl;
}
- }
- if (state->d_state == IncomingTCPConnectionState::State::readingQuery) {
- iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
- if (iostate == IOState::Done) {
- handleNewIOState(state, IOState::Done, fd, handleIOCallback);
- handleQuery(state, now);
- return;
+ if (state->d_state != IncomingTCPConnectionState::State::idle &&
+ state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
+ state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
+ state->d_state != IncomingTCPConnectionState::State::readingQuery &&
+ state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
+        vinfolog("Unexpected state %d in IncomingTCPConnectionState::handleIO", static_cast<int>(state->d_state));
}
}
-
- if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
- iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size());
- if (iostate == IOState::Done) {
- handleResponseSent(state, now);
- return;
+ catch(const std::exception& e) {
+ /* most likely an EOF because the other end closed the connection,
+ but it might also be a real IO error or something else.
+ Let's just drop the connection
+ */
+ if (state->d_state == IncomingTCPConnectionState::State::idle ||
+ state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
+ state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
+ state->d_state == IncomingTCPConnectionState::State::readingQuery) {
+ ++state->d_ci.cs->tcpDiedReadingQuery;
+ }
+ else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
+ ++state->d_ci.cs->tcpDiedSendingResponse;
}
- }
- if (state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
- state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
- state->d_state != IncomingTCPConnectionState::State::readingQuery &&
- state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
- vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
- }
- }
- catch(const std::exception& e) {
- /* most likely an EOF because the other end closed the connection,
- but it might also be a real IO error or something else.
- Let's just drop the connection
- */
- if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
- state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
- state->d_state == IncomingTCPConnectionState::State::readingQuery) {
- ++state->d_ci.cs->tcpDiedReadingQuery;
- }
- else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
- ++state->d_ci.cs->tcpDiedSendingResponse;
+ if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
+ vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+ }
+ else {
+ vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
+ }
+ /* remove this FD from the IO multiplexer */
+ iostate = IOState::Done;
}
- if (state->d_lastIOState == IOState::NeedWrite || state->d_readingFirstQuery) {
- vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_lastIOState == IOState::NeedRead ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
+ if (iostate == IOState::Done) {
+ state->d_ioState->update(iostate, handleIOCallback, state);
+ }
else {
- vinfolog("Closing TCP client connection with %s", state->d_ci.remote.toStringWithPort());
+ state->d_ioState->update(iostate, handleIOCallback, state, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
}
- /* remove this FD from the IO multiplexer */
- iostate = IOState::Done;
+ ioGuard.release();
+ }
+ while (state->d_state == IncomingTCPConnectionState::State::readingQuerySize && iostate == IOState::NeedRead && !wouldBlock);
+}
+
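+ /* called when the backend connection could not deliver an answer for a query
+ it owned: for XFR we mark the transfer as done, otherwise we flush any
+ responses still queued for the client, and if nothing is left to send we
+ tear down the client-side IO state as well */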
+void IncomingTCPConnectionState::notifyIOError(std::shared_ptr<IncomingTCPConnectionState>& state, IDState&& query, const struct timeval& now)
+{
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+ if (d_isXFR) {
+ d_xfrDone = true;
}
- if (iostate == IOState::Done) {
- handleNewIOState(state, iostate, fd, handleIOCallback);
+ if (d_state == State::sendingResponse) {
+ /* we are already sending a response, the queued ones (if any) will be
+ picked up once it has been fully sent */
+ }
+ else if (!d_queuedResponses.empty()) {
+ /* stop reading and send what we have */
+ TCPResponse resp = std::move(d_queuedResponses.front());
+ d_queuedResponses.pop_front();
+ sendResponse(state, now, std::move(resp));
}
else {
- handleNewIOState(state, iostate, fd, handleIOCallback, iostate == IOState::NeedRead ? state->getClientReadTTD(now) : state->getClientWriteTTD(now));
+ // the backend code already tried to reconnect if it was possible
+ d_lastIOState = IOState::Done;
+ d_ioState->reset();
}
}
-static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
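+ /* XFR responses are simply passed to the client as they arrive from the backend */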
+void IncomingTCPConnectionState::handleXFRResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
{
- auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
- if (fd != state->d_ci.fd) {
- throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_ci.fd));
- }
- struct timeval now;
- gettimeofday(&now, 0);
+ sendResponse(state, now, std::move(response));
+}
- handleIO(state, now);
+void IncomingTCPConnectionState::handleTimeout(bool write)
+{
+ // DEBUG: cerr<<"client timeout"<<endl;
+ ++d_ci.cs->tcpClientTimeouts;
+ d_lastIOState = IOState::Done;
+ d_ioState->reset();
}
static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param)
/* let's update the remaining time */
state->d_remainingTime = g_maxTCPConnectionDuration;
- handleIO(state, now);
+ IncomingTCPConnectionState::handleIO(state, now);
}
catch(...) {
delete citmp;
}
}
-void tcpClientThread(int pipefd)
+static void tcpClientThread(int pipefd)
{
/* we get launched with a pipe on which we receive file descriptors from clients that we own
from that point on */
data.mplexer->run(&now);
if (g_downstreamTCPCleanupInterval > 0 && (now.tv_sec > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
- cleanupClosedTCPConnections();
+ DownstreamConnectionsManager::cleanupClosedTCPConnections();
lastTCPCleanup = now.tv_sec;
}
if (now.tv_sec > lastTimeoutScan) {
lastTimeoutScan = now.tv_sec;
auto expiredReadConns = data.mplexer->getTimeouts(now, false);
- for(const auto& conn : expiredReadConns) {
- auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
- if (conn.first == state->d_ci.fd) {
- vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
- ++state->d_ci.cs->tcpClientTimeouts;
+ for (const auto& cbData : expiredReadConns) {
+ if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
+ auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
+ if (cbData.first == state->d_ci.fd) {
+ vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+ state->handleTimeout(false);
+ }
}
- else if (state->d_ds) {
- vinfolog("Timeout (read) from remote backend %s", state->d_ds->getName());
- ++state->d_ci.cs->tcpDownstreamTimeouts;
- ++state->d_ds->tcpReadTimeouts;
+ else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
+ auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
+ vinfolog("Timeout (read) from remote backend %s", conn->getBackendName());
+ conn->handleTimeout(now, false);
}
- data.mplexer->removeReadFD(conn.first);
- state->d_lastIOState = IOState::Done;
}
auto expiredWriteConns = data.mplexer->getTimeouts(now, true);
- for(const auto& conn : expiredWriteConns) {
- auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(conn.second);
- if (conn.first == state->d_ci.fd) {
- vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
- ++state->d_ci.cs->tcpClientTimeouts;
+ for (const auto& cbData : expiredWriteConns) {
+ if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) {
+ auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second);
+ if (cbData.first == state->d_ci.fd) {
+ vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort());
+ state->handleTimeout(true);
+ }
+ }
+ else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) {
+ auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second);
+ vinfolog("Timeout (write) from remote backend %s", conn->getBackendName());
+ conn->handleTimeout(now, true);
}
- else if (state->d_ds) {
- vinfolog("Timeout (write) from remote backend %s", state->d_ds->getName());
- ++state->d_ci.cs->tcpDownstreamTimeouts;
- ++state->d_ds->tcpWriteTimeouts;
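+ /* note: handleTimeout() already removes the FD from the multiplexer via
+ d_ioState->reset(), hence the explicit removal below stays disabled */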
+#if 0
+ try {
+ data.mplexer->removeWriteFD(cbData.first);
+ }
+ catch (const FDMultiplexerException& fde) {
+ warnlog("Exception while removing a socket (%d) after a write timeout: %s", cbData.first, fde.what());
}
- data.mplexer->removeWriteFD(conn.first);
- state->d_lastIOState = IOState::Done;
+#endif
}
}
}
--- /dev/null
+
+#include "dnsdist-tcp-downstream.hh"
+#include "dnsdist-tcp-upstream.hh"
+
+const uint16_t TCPConnectionToBackend::s_xfrID = 0;
+
+void TCPConnectionToBackend::assignToClientConnection(std::shared_ptr<IncomingTCPConnectionState>& clientConn, bool isXFR)
+{
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+ if (isXFR) {
+ d_usedForXFR = true;
+ }
+
+ d_clientConn = clientConn;
+ d_ioState = make_unique<IOStateHandler>(clientConn->getIOMPlexer(), d_socket->getHandle());
+}
+
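+ /* dequeue the next pending query and make it the current one; the caller is
+ responsible for acting on the returned NeedWrite state */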
+IOState TCPConnectionToBackend::sendNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
+{
+ conn->d_currentQuery = std::move(conn->d_pendingQueries.front());
+ conn->d_pendingQueries.pop_front();
+ conn->d_state = State::sendingQueryToBackend;
+ return IOState::NeedWrite;
+}
+
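+ /* the backend side of the state machine: write the current query
+ (sendingQueryToBackend), then read the two-byte response length
+ (readingResponseSizeFromBackend) followed by the response itself
+ (readingResponseFromBackend), reconnecting and re-queueing in-flight
+ queries if the connection dies on us */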
+void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now)
+{
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+ if (conn->d_socket == nullptr) {
+ throw std::runtime_error("No downstream socket in " + std::string(__PRETTY_FUNCTION__) + "!");
+ }
+
+ bool connectionDied = false;
+ IOState iostate = IOState::Done;
+ IOStateGuard ioGuard(conn->d_ioState);
+ int fd = conn->d_socket->getHandle();
+
+ try {
+ if (conn->d_state == State::sendingQueryToBackend) {
+ // DEBUG: cerr<<"sending query to backend over FD "<<fd<<endl;
+ int socketFlags = 0;
+#ifdef MSG_FASTOPEN
+ if (conn->isFastOpenEnabled()) {
+ socketFlags |= MSG_FASTOPEN;
+ }
+#endif /* MSG_FASTOPEN */
+
+ size_t sent = sendMsgWithOptions(fd, reinterpret_cast<const char *>(&conn->d_currentQuery.d_buffer.at(conn->d_currentPos)), conn->d_currentQuery.d_buffer.size() - conn->d_currentPos, &conn->d_ds->remote, &conn->d_ds->sourceAddr, conn->d_ds->sourceItf, socketFlags);
+ if (sent == conn->d_currentQuery.d_buffer.size()) {
+ /* request sent! */
+ conn->incQueries();
+ conn->d_currentPos = 0;
+ /* store the query, indexed by its ID, so we can match the response to it later */
+ conn->d_pendingResponses[conn->d_currentQuery.d_idstate.origID] = std::move(conn->d_currentQuery);
+ conn->d_currentQuery.d_buffer.clear();
+#if 0
+ if (!conn->d_usedForXFR) {
+ /* don't bother with the outstanding count for XFR queries */
+ ++conn->d_ds->outstanding;
+ ++conn->d_outstanding;
+ }
+#endif
+
+ if (conn->d_pendingQueries.empty()) {
+ conn->d_state = State::readingResponseSizeFromBackend;
+ conn->d_currentPos = 0;
+ conn->d_responseBuffer.resize(sizeof(uint16_t));
+ iostate = IOState::NeedRead;
+ }
+ else {
+ iostate = sendNextQuery(conn);
+ }
+ }
+ else {
+ conn->d_currentPos += sent;
+ iostate = IOState::NeedWrite;
+ /* disable fast open on partial write */
+ conn->disableFastOpen();
+ }
+ }
+
+ if (conn->d_state == State::readingResponseSizeFromBackend) {
+ // DEBUG: cerr<<"reading response size from backend"<<endl;
+ // then we need to allocate a new buffer (new because we might need to re-send the query if the
+ // backend dies on us)
+ // We also might need to read and send to the client more than one response in case of XFR (yeah!)
+ // should very likely be a TCPIOHandler d_downstreamHandler
+ conn->d_responseBuffer.resize(sizeof(uint16_t));
+ iostate = tryRead(fd, conn->d_responseBuffer, conn->d_currentPos, sizeof(uint16_t) - conn->d_currentPos);
+ if (iostate == IOState::Done) {
+ // DEBUG: cerr<<"got response size from backend"<<endl;
+ conn->d_state = State::readingResponseFromBackend;
+ conn->d_responseSize = conn->d_responseBuffer.at(0) * 256 + conn->d_responseBuffer.at(1);
+ conn->d_responseBuffer.reserve(conn->d_responseSize + /* we will need to prepend the size later */ 2);
+ conn->d_responseBuffer.resize(conn->d_responseSize);
+ conn->d_currentPos = 0;
+ }
+ }
+
+ if (conn->d_state == State::readingResponseFromBackend) {
+ // DEBUG: cerr<<"reading response from backend"<<endl;
+ iostate = tryRead(fd, conn->d_responseBuffer, conn->d_currentPos, conn->d_responseSize - conn->d_currentPos);
+ if (iostate == IOState::Done) {
+ // DEBUG: cerr<<"got response from backend"<<endl;
+ //conn->d_responseReadTime = now;
+ try {
+ iostate = conn->handleResponse(now);
+ }
+ catch (const std::exception& e) {
+ vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", conn->d_ds ? conn->d_ds->getName() : "unknown", conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what());
+ }
+ }
+ }
+
+ if (conn->d_state != State::idle &&
+ conn->d_state != State::sendingQueryToBackend &&
+ conn->d_state != State::readingResponseSizeFromBackend &&
+ conn->d_state != State::readingResponseFromBackend) {
+ vinfolog("Unexpected state %d in TCPConnectionToBackend::handleIO", static_cast<int>(conn->d_state));
+ }
+ }
+ catch (const std::exception& e) {
+ /* most likely an EOF because the other end closed the connection,
+ but it might also be a real IO error or something else.
+ Let's just drop the connection
+ */
+ vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_ioState->getState() == IOState::NeedRead ? "reading from" : "writing to"), conn->d_currentQuery.d_idstate.origRemote.toStringWithPort(), e.what());
+ if (conn->d_state == State::sendingQueryToBackend) {
+ ++conn->d_ds->tcpDiedSendingQuery;
+ }
+ else {
+ ++conn->d_ds->tcpDiedReadingResponse;
+ }
+
+ /* don't increase this counter when reusing connections */
+ if (conn->d_fresh) {
+ ++conn->d_downstreamFailures;
+ }
+
+#if 0
+ if (conn->d_outstanding) {
+ conn->d_outstanding = false;
+
+ if (conn->d_ds != nullptr) {
+ --conn->d_ds->outstanding;
+ }
+ }
+#endif
+ /* remove this FD from the IO multiplexer */
+ iostate = IOState::Done;
+ connectionDied = true;
+ }
+
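+ /* the connection died on us: unless it was used for an XFR that already sent
+ a query (a partially received transfer cannot simply be replayed), we try to
+ reconnect and move the queries awaiting a response back to the pending list
+ so they get re-sent */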
+ if (connectionDied) {
+ bool reconnected = false;
+ // DEBUG: cerr<<"connection died, number of failures is "<<conn->d_downstreamFailures<<", retries is "<<conn->d_ds->retries<<endl;
+
+ if ((!conn->d_usedForXFR || conn->d_queries == 0) && conn->d_downstreamFailures < conn->d_ds->retries) {
+ // DEBUG: cerr<<"reconnecting"<<endl;
+ conn->d_ioState->reset();
+ ioGuard.release();
+
+ if (conn->reconnect()) {
+ // DEBUG: cerr<<"reconnected"<<endl;
+
+ conn->d_ioState = make_unique<IOStateHandler>(conn->d_clientConn->getIOMPlexer(), conn->d_socket->getHandle());
+ // DEBUG: cerr<<"new state"<<endl;
+
+ for (auto& pending : conn->d_pendingResponses) {
+ conn->d_pendingQueries.push_back(std::move(pending.second));
+ }
+ conn->d_pendingResponses.clear();
+ conn->d_currentPos = 0;
+
+ if (conn->d_state == State::doingHandshake ||
+ conn->d_state == State::sendingQueryToBackend) {
+ iostate = IOState::NeedWrite;
+ // resume sending query
+ }
+ else {
+ iostate = sendNextQuery(conn);
+ }
+
+ if (!conn->d_proxyProtocolPayloadAdded && !conn->d_proxyProtocolPayload.empty()) {
+ conn->d_currentQuery.d_buffer.insert(conn->d_currentQuery.d_buffer.begin(), conn->d_proxyProtocolPayload.begin(), conn->d_proxyProtocolPayload.end());
+ conn->d_proxyProtocolPayloadAdded = true;
+ }
+
+ reconnected = true;
+ }
+ }
+
+ if (!reconnected) {
+ /* reconnect failed, we give up */
+ conn->d_connectionDied = true;
+ conn->notifyAllQueriesFailed(now);
+ }
+ }
+
+ if (iostate == IOState::Done) {
+ conn->d_ioState->update(iostate, handleIOCallback, conn);
+ }
+ else {
+ conn->d_ioState->update(iostate, handleIOCallback, conn, iostate == IOState::NeedRead ? conn->getBackendReadTTD(now) : conn->getBackendWriteTTD(now));
+ }
+ ioGuard.release();
+}
+
+void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
+{
+ auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param);
+ if (fd != conn->getHandle()) {
+ throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle()));
+ }
+
+ struct timeval now;
+ gettimeofday(&now, 0);
+ handleIO(conn, now);
+}
+
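+ /* hand a query over to this backend connection: if we can take it right away
+ we switch to sendingQueryToBackend and arm the write timeout, otherwise the
+ query is appended to the pending list */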
+void TCPConnectionToBackend::queueQuery(TCPQuery&& query, std::shared_ptr<TCPConnectionToBackend>& sharedSelf)
+{
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+ // DEBUG: cerr<<"IDS has "<<(query.d_idstate.qTag?"tags":"no tags")<<endl;
+ if (d_ioState == nullptr) {
+ throw std::runtime_error("Trying to queue a query to a TCP connection that has no incoming client connection assigned");
+ }
+
+ // if the connection is idle, or we are only waiting for a response to become
+ // readable, we can start sending this query right away
+ if (d_state == State::idle || d_state == State::waitingForResponseFromBackend) {
+ d_state = State::sendingQueryToBackend;
+ d_currentQuery = std::move(query);
+ // DEBUG: cerr<<"need write"<<endl;
+
+ struct timeval now;
+ gettimeofday(&now, 0);
+
+ d_ioState->update(IOState::NeedWrite, handleIOCallback, sharedSelf, getBackendWriteTTD(now));
+ }
+ else {
+ // store query in the list of queries to send
+ d_pendingQueries.push_back(std::move(query));
+ }
+ // DEBUG: cerr<<"out of "<<__PRETTY_FUNCTION__<<endl;
+}
+
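+ /* (re)establish the TCP connection to the backend: close the existing socket
+ if there is one, then retry until d_ds->retries is exceeded, binding to the
+ configured source address and interface, and deferring the connect() when
+ TCP Fast Open is in use */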
+bool TCPConnectionToBackend::reconnect()
+{
+ std::unique_ptr<Socket> result;
+
+ if (d_socket) {
+ // DEBUG: cerr<<"closing socket "<<d_socket->getHandle()<<endl;
+ shutdown(d_socket->getHandle(), SHUT_RDWR);
+ d_socket.reset();
+ d_ioState.reset();
+ --d_ds->tcpCurrentConnections;
+ }
+
+ do {
+ vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures);
+ try {
+ result = std::unique_ptr<Socket>(new Socket(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0));
+ // DEBUG: cerr<<"result of connect is "<<result->getHandle()<<endl;
+ if (!IsAnyAddress(d_ds->sourceAddr)) {
+ SSetsockopt(result->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1);
+#ifdef IP_BIND_ADDRESS_NO_PORT
+ if (d_ds->ipBindAddrNoPort) {
+ SSetsockopt(result->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1);
+ }
+#endif
+#ifdef SO_BINDTODEVICE
+ if (!d_ds->sourceItfName.empty()) {
+ int res = setsockopt(result->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length());
+ if (res != 0) {
+ vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror());
+ }
+ }
+#endif
+ result->bind(d_ds->sourceAddr, false);
+ }
+ result->setNonBlocking();
+#ifdef MSG_FASTOPEN
+ if (!d_ds->tcpFastOpen || !isFastOpenEnabled()) {
+ SConnectWithTimeout(result->getHandle(), d_ds->remote, /* no timeout, we will handle it ourselves */ 0);
+ }
+#else
+ SConnectWithTimeout(result->getHandle(), d_ds->remote, /* no timeout, we will handle it ourselves */ 0);
+#endif /* MSG_FASTOPEN */
+
+ d_socket = std::move(result);
+ // DEBUG: cerr<<"connected new socket "<<d_socket->getHandle()<<endl;
+ ++d_ds->tcpCurrentConnections;
+ return true;
+ }
+ catch(const std::runtime_error& e) {
+ vinfolog("Connection to downstream server %s failed: %s", d_ds->getName(), e.what());
+ d_downstreamFailures++;
+ if (d_downstreamFailures > d_ds->retries) {
+ throw;
+ }
+ }
+ }
+ while (d_downstreamFailures <= d_ds->retries);
+
+ return false;
+}
+
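+ /* a read or write timeout fired on the backend socket: account for it, tear
+ down the IO state and fail every query owned by this connection */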
+void TCPConnectionToBackend::handleTimeout(const struct timeval& now, bool write)
+{
+ if (write) {
+ ++d_ds->tcpWriteTimeouts;
+ }
+ else {
+ ++d_ds->tcpReadTimeouts;
+ }
+
+ if (d_ioState) {
+ d_ioState->reset();
+ }
+
+ notifyAllQueriesFailed(now, true);
+}
+
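+ /* the connection to the backend is gone for good: notify the client-side
+ connection about the in-flight query, the queued ones and the ones still
+ awaiting a response, then detach from the client */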
+void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, bool timeout)
+{
+ d_connectionDied = true;
+ auto& clientConn = d_clientConn;
+ if (!clientConn->active()) {
+ /* a client timeout occurred, or something similar */
+ d_connectionDied = true;
+ d_clientConn.reset();
+ return;
+ }
+
+ if (timeout) {
+ ++clientConn->d_ci.cs->tcpDownstreamTimeouts;
+ }
+
+ if (d_state == State::doingHandshake || d_state == State::sendingQueryToBackend) {
+ clientConn->notifyIOError(clientConn, std::move(d_currentQuery.d_idstate), now);
+ }
+
+ for (auto& query : d_pendingQueries) {
+ clientConn->notifyIOError(clientConn, std::move(query.d_idstate), now);
+ }
+
+ for (auto& response : d_pendingResponses) {
+ clientConn->notifyIOError(clientConn, std::move(response.second.d_idstate), now);
+ }
+
+ d_pendingQueries.clear();
+ d_pendingResponses.clear();
+
+ d_clientConn.reset();
+}
+
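+ /* a full response has been read from the backend: XFR messages are streamed
+ straight to the client, otherwise we match the response to its query via the
+ DNS ID, pass it up, and then decide whether to send the next query, read the
+ next response, or go back to idle */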
+IOState TCPConnectionToBackend::handleResponse(const struct timeval& now)
+{
+ // DEBUG: cerr<<"in "<<__PRETTY_FUNCTION__<<endl;
+ //auto clientConn = d_clientConn.lock();
+ //if (!clientConn) {
+ //d_clientConn.reset();
+ // d_connectionDied = true;
+ // // DEBUG: cerr<<"connection to client died, bye bye"<<endl;
+ // return IOState::Done;
+ //}
+
+ auto& clientConn = d_clientConn;
+ if (!clientConn->active()) {
+ // DEBUG: cerr<<"client is not active"<<endl;
+ // a client timeout occured, or something like that */
+ d_connectionDied = true;
+ d_clientConn.reset();
+ return IOState::Done;
+ }
+
+ if (d_usedForXFR) {
+ // DEBUG: cerr<<"XFR!"<<endl;
+ TCPResponse response;
+ response.d_buffer = std::move(d_responseBuffer);
+ response.d_ds = d_ds;
+ clientConn->handleXFRResponse(clientConn, now, std::move(response));
+ /* get ready to read the next packet, if any */
+ d_state = State::readingResponseSizeFromBackend;
+ d_currentPos = 0;
+ d_responseBuffer.resize(sizeof(uint16_t));
+ return IOState::NeedRead;
+ }
+ else {
+ // DEBUG: cerr<<"not XFR, phew"<<endl;
+ uint16_t queryId = 0;
+ try {
+ queryId = getQueryIdFromResponse();
+ }
+ catch (const std::exception& e) {
+ notifyAllQueriesFailed(now);
+ throw;
+ }
+
+ auto it = d_pendingResponses.find(queryId);
+ if (it == d_pendingResponses.end()) {
+ // DEBUG: cerr<<"could not found any corresponding query for ID "<<queryId<<endl;
+ notifyAllQueriesFailed(now);
+ return IOState::Done;
+ }
+ auto ids = std::move(it->second.d_idstate);
+ // DEBUG: cerr<<"IDS has "<<(ids.qTag?" TAGS ": "NO TAGS")<<endl;
+ // DEBUG: cerr<<"passing response to client connection for "<<ids.qname<<endl;
+ clientConn->handleResponse(clientConn, now, TCPResponse(std::move(d_responseBuffer), std::move(ids), d_ds));
+ d_pendingResponses.erase(it);
+
+ if (!d_pendingQueries.empty()) {
+ // DEBUG: cerr<<"still have some queries to send"<<endl;
+ d_state = State::sendingQueryToBackend;
+ d_currentQuery = std::move(d_pendingQueries.front());
+ d_pendingQueries.pop_front();
+ return IOState::NeedWrite;
+ }
+ else if (!d_pendingResponses.empty()) {
+ // DEBUG: cerr<<"still have some responses to read"<<endl;
+ d_state = State::readingResponseSizeFromBackend;
+ d_currentPos = 0;
+ d_responseBuffer.resize(sizeof(uint16_t));
+ return IOState::NeedRead;
+ }
+ else {
+ // DEBUG: cerr<<"nothing to do, phewwwww"<<endl;
+ d_state = State::idle;
+ d_clientConn.reset();
+ return IOState::Done;
+ }
+ }
+}
+
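+ /* extract the DNS ID from the response header so we can look up the matching
+ pending query */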
+uint16_t TCPConnectionToBackend::getQueryIdFromResponse()
+{
+ if (d_responseBuffer.size() < sizeof(dnsheader)) {
+ throw std::runtime_error("Unable to get query ID in a too small (" + std::to_string(d_responseBuffer.size()) + ") response from " + d_ds->getNameWithAddr());
+ }
+
+ dnsheader dh;
+ memcpy(&dh, &d_responseBuffer.at(0), sizeof(dh));
+ return ntohs(dh.id);
+}
+
+void TCPConnectionToBackend::setProxyProtocolPayload(std::string&& payload)
+{
+ d_proxyProtocolPayload = std::move(payload);
+}
+
+void TCPConnectionToBackend::setProxyProtocolPayloadAdded(bool added)
+{
+ d_proxyProtocolPayloadAdded = added;
+}
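+
+ /* A rough sketch of the intended call sequence for this class, to help
+ reviewers; the exact call sites live on the dnsdist-tcp.cc side, and the
+ variable names below are illustrative assumptions, not guarantees:
+
+ std::shared_ptr<TCPConnectionToBackend> conn = DownstreamConnectionsManager::getConnectionToDownstream(mplexer, ds, now);
+ conn->assignToClientConnection(clientState, isXFR);
+ conn->setProxyProtocolPayload(std::move(payload)); // optional
+ conn->queueQuery(std::move(query), conn); // starts the IO state machine
+ // responses then flow back via handleResponse() / handleXFRResponse()
+ */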