static int pickBackendSocketForSending(DownstreamState* state)
{
- return state->fds[state->fdOffset++ % state->fds.size()];
+ return state->sockets[state->socketsOffset++ % state->sockets.size()];
}
-static int pickBackendSocketForReceiving(const std::shared_ptr<DownstreamState>& state)
+static void pickBackendSocketsReadyForReceiving(const std::shared_ptr<DownstreamState>& state, std::vector<int>& ready)
{
- if (state->fds.size() == 1) {
- return state->fds[0];
- }
+ ready.clear();
- std::set<int> fds;
- for (auto fd : state->fds) {
- if (fd >= 0) {
- fds.insert(fd);
- }
+ if (state->sockets.size() == 1) {
+ ready.push_back(state->sockets[0]);
+ return ;
}
- int selected = -1;
- int res = waitForMultiData(fds, -1, -1, &selected);
- if (res != 1) {
- throw std::runtime_error("Error selecting a socket for a backend " + state->remote.toStringWithPort() + ": " + strerror(errno));
+ {
+ std::lock_guard<std::mutex> lock(state->socketsLock);
+ state->mplexer->getAvailableFDs(ready, -1);
}
-
- return selected;
}
// listens on a dedicated socket, lobs answers from downstream servers to original requestors
vector<uint8_t> rewrittenResponse;
uint16_t queryId = 0;
+ std::vector<int> sockets;
+ sockets.reserve(dss->sockets.size());
+
for(;;) {
dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
bool outstandingDecreased = false;
try {
- int fd = pickBackendSocketForReceiving(dss);
- ssize_t got = recv(fd, packet, sizeof(packet), 0);
- char * response = packet;
- size_t responseSize = sizeof(packet);
+ pickBackendSocketsReadyForReceiving(dss, sockets);
+ for (const auto& fd : sockets) {
+ ssize_t got = recv(fd, packet, sizeof(packet), 0);
+ char * response = packet;
+ size_t responseSize = sizeof(packet);
- if (got < (ssize_t) sizeof(dnsheader))
- continue;
+ if (got < (ssize_t) sizeof(dnsheader))
+ continue;
- uint16_t responseLen = (uint16_t) got;
- queryId = dh->id;
+ uint16_t responseLen = (uint16_t) got;
+ queryId = dh->id;
- if(queryId >= dss->idStates.size())
- continue;
+ if(queryId >= dss->idStates.size())
+ continue;
- IDState* ids = &dss->idStates[queryId];
- int origFD = ids->origFD;
+ IDState* ids = &dss->idStates[queryId];
+ int origFD = ids->origFD;
- if(origFD < 0) // duplicate
- continue;
+ if(origFD < 0) // duplicate
+ continue;
- /* setting age to 0 to prevent the maintainer thread from
- cleaning this IDS while we process the response.
- We have already a copy of the origFD, so it would
- mostly mess up the outstanding counter.
- */
- ids->age = 0;
+ /* setting age to 0 to prevent the maintainer thread from
+ cleaning this IDS while we process the response.
+ We have already a copy of the origFD, so it would
+ mostly mess up the outstanding counter.
+ */
+ ids->age = 0;
- if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, dss->remote)) {
- continue;
- }
+ if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, dss->remote)) {
+ continue;
+ }
- --dss->outstanding; // you'd think an attacker could game this, but we're using connected socket
- outstandingDecreased = true;
+ --dss->outstanding; // you'd think an attacker could game this, but we're using connected socket
+ outstandingDecreased = true;
- if(dh->tc && g_truncateTC) {
- truncateTC(response, &responseLen);
- }
+ if(dh->tc && g_truncateTC) {
+ truncateTC(response, &responseLen);
+ }
- dh->id = ids->origID;
+ dh->id = ids->origID;
- uint16_t addRoom = 0;
- DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start);
+ uint16_t addRoom = 0;
+ DNSResponse dr(&ids->qname, ids->qtype, ids->qclass, &ids->origDest, &ids->origRemote, dh, sizeof(packet), responseLen, false, &ids->sentTime.d_start);
#ifdef HAVE_PROTOBUF
- dr.uniqueId = ids->uniqueId;
+ dr.uniqueId = ids->uniqueId;
#endif
- dr.qTag = ids->qTag;
+ dr.qTag = ids->qTag;
- if (!processResponse(localRespRulactions, dr, &ids->delayMsec)) {
- continue;
- }
+ if (!processResponse(localRespRulactions, dr, &ids->delayMsec)) {
+ continue;
+ }
#ifdef HAVE_DNSCRYPT
- if (ids->dnsCryptQuery) {
- addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
- }
+ if (ids->dnsCryptQuery) {
+ addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
+ }
#endif
- if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded, ids->ecsAdded, rewrittenResponse, addRoom)) {
- continue;
- }
+ if (!fixUpResponse(&response, &responseLen, &responseSize, ids->qname, ids->origFlags, ids->ednsAdded, ids->ecsAdded, rewrittenResponse, addRoom)) {
+ continue;
+ }
- if (ids->packetCache && !ids->skipCache) {
- ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL);
- }
+ if (ids->packetCache && !ids->skipCache) {
+ ids->packetCache->insert(ids->cacheKey, ids->qname, ids->qtype, ids->qclass, response, responseLen, false, dh->rcode, ids->tempFailureTTL);
+ }
- if (ids->cs && !ids->cs->muted) {
+ if (ids->cs && !ids->cs->muted) {
#ifdef HAVE_DNSCRYPT
- if (!encryptResponse(response, &responseLen, responseSize, false, ids->dnsCryptQuery, &dh, &dhCopy)) {
- continue;
- }
+ if (!encryptResponse(response, &responseLen, responseSize, false, ids->dnsCryptQuery, &dh, &dhCopy)) {
+ continue;
+ }
#endif
- ComboAddress empty;
- empty.sin4.sin_family = 0;
- /* if ids->destHarvested is false, origDest holds the listening address.
- We don't want to use that as a source since it could be 0.0.0.0 for example. */
- sendUDPResponse(origFD, response, responseLen, ids->delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
- }
+ ComboAddress empty;
+ empty.sin4.sin_family = 0;
+ /* if ids->destHarvested is false, origDest holds the listening address.
+ We don't want to use that as a source since it could be 0.0.0.0 for example. */
+ sendUDPResponse(origFD, response, responseLen, ids->delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
+ }
- g_stats.responses++;
+ g_stats.responses++;
- double udiff = ids->sentTime.udiff();
- vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
+ double udiff = ids->sentTime.udiff();
+ vinfolog("Got answer from %s, relayed to %s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(), udiff);
- {
- struct timespec ts;
- gettime(&ts);
- std::lock_guard<std::mutex> lock(g_rings.respMutex);
- g_rings.respRing.push_back({ts, ids->origRemote, ids->qname, ids->qtype, (unsigned int)udiff, (unsigned int)got, *dh, dss->remote});
- }
+ {
+ struct timespec ts;
+ gettime(&ts);
+ std::lock_guard<std::mutex> lock(g_rings.respMutex);
+ g_rings.respRing.push_back({ts, ids->origRemote, ids->qname, ids->qtype, (unsigned int)udiff, (unsigned int)got, *dh, dss->remote});
+ }
- if(dh->rcode == RCode::ServFail)
- g_stats.servfailResponses++;
- dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff/128.0;
+ if(dh->rcode == RCode::ServFail)
+ g_stats.servfailResponses++;
+ dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff/128.0;
- doLatencyStats(udiff);
+ doLatencyStats(udiff);
- if (ids->origFD == origFD) {
+ if (ids->origFD == origFD) {
#ifdef HAVE_DNSCRYPT
- ids->dnsCryptQuery = nullptr;
+ ids->dnsCryptQuery = nullptr;
#endif
- ids->origFD = -1;
- outstandingDecreased = false;
- }
+ ids->origFD = -1;
+ outstandingDecreased = false;
+ }
- rewrittenResponse.clear();
+ rewrittenResponse.clear();
+ }
}
catch(const std::exception& e){
vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->remote.toStringWithPort(), queryId, e.what());
void DownstreamState::reconnect()
{
connected = false;
- for (auto& fd : fds) {
+ for (auto& fd : sockets) {
if (fd != -1) {
+ {
+ std::lock_guard<std::mutex> lock(socketsLock);
+ mplexer->removeReadFD(fd);
+ }
/* shutdown() is needed to wake up recv() in the responderThread */
shutdown(fd, SHUT_RDWR);
close(fd);
}
try {
SConnect(fd, remote);
+ {
+ std::lock_guard<std::mutex> lock(socketsLock);
+ mplexer->addReadFD(fd, [](int, boost::any) {});
+ }
connected = true;
}
catch(const std::runtime_error& error) {
/* if at least one (re-)connection failed, close all sockets */
if (!connected) {
- for (auto& fd : fds) {
+ for (auto& fd : sockets) {
if (fd != -1) {
/* shutdown() is needed to wake up recv() in the responderThread */
shutdown(fd, SHUT_RDWR);
DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, size_t numberOfSockets): remote(remote_), sourceAddr(sourceAddr_), sourceItf(sourceItf_)
{
- fds.resize(numberOfSockets);
- for (auto& fd : fds) {
+ mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
+
+ sockets.resize(numberOfSockets);
+ for (auto& fd : sockets) {
fd = -1;
}
warnlog("Marking downstream %s as '%s'", dss->getNameWithAddr(), newState ? "up" : "down");
if (newState && !dss->connected) {
- for (auto& fd : dss->fds) {
+ for (auto& fd : dss->sockets) {
try {
SConnect(fd, dss->remote);
+ {
+ std::lock_guard<std::mutex> lock(dss->socketsLock);
+ dss->mplexer->addReadFD(fd, [](int, boost::any) {});
+ }
dss->connected = true;
}
catch(const std::runtime_error& error) {
/* UDP sockets to backends */
size_t backendUDPSocketsCount = 0;
for (const auto& backend : backends) {
- backendUDPSocketsCount += backend->fds.size();
+ backendUDPSocketsCount += backend->sockets.size();
}
requiredFDsCount += backendUDPSocketsCount;
/* TCP sockets to backends */
{
}
- virtual int run(struct timeval* tv, int timeout=500);
+ virtual int run(struct timeval* tv, int timeout=500) override;
+ virtual void getAvailableFDs(std::vector<int>& fds, int timeout) override;
- virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter);
- virtual void removeFD(callbackmap_t& cbmap, int fd);
- string getName()
+ virtual void addFD(callbackmap_t& cbmap, int fd, callbackfunc_t toDo, const funcparam_t& parameter) override;
+ virtual void removeFD(callbackmap_t& cbmap, int fd) override;
+
+ string getName() const override
{
return "poll";
}
private:
+ vector<struct pollfd> preparePollFD() const;
};
static FDMultiplexer* make()
throw FDMultiplexerException("Tried to remove unlisted fd "+std::to_string(fd)+ " from multiplexer");
}
-bool pollfdcomp(const struct pollfd& a, const struct pollfd& b)
+vector<struct pollfd> PollFDMultiplexer::preparePollFD() const
{
- return a.fd < b.fd;
-}
-
-int PollFDMultiplexer::run(struct timeval* now, int timeout)
-{
- if(d_inrun) {
- throw FDMultiplexerException("FDMultiplexer::run() is not reentrant!\n");
- }
-
vector<struct pollfd> pollfds;
-
+ pollfds.reserve(d_readCallbacks.size() + d_writeCallbacks.size());
+
struct pollfd pollfd;
- for(callbackmap_t::const_iterator i=d_readCallbacks.begin(); i != d_readCallbacks.end(); ++i) {
- pollfd.fd = i->first;
+ for(const auto& cb : d_readCallbacks) {
+ pollfd.fd = cb.first;
pollfd.events = POLLIN;
pollfds.push_back(pollfd);
}
- for(callbackmap_t::const_iterator i=d_writeCallbacks.begin(); i != d_writeCallbacks.end(); ++i) {
- pollfd.fd = i->first;
+ for(const auto& cb : d_writeCallbacks) {
+ pollfd.fd = cb.first;
pollfd.events = POLLOUT;
pollfds.push_back(pollfd);
}
+ return pollfds;
+}
+
+void PollFDMultiplexer::getAvailableFDs(std::vector<int>& fds, int timeout)
+{
+ auto pollfds = preparePollFD();
+ int ret = poll(&pollfds[0], pollfds.size(), timeout);
+
+ if (ret < 0 && errno != EINTR)
+ throw FDMultiplexerException("poll returned error: " + stringerror());
+
+ for(const auto& pollfd : pollfds) {
+ if (pollfd.revents == POLLIN || pollfd.revents == POLLOUT) {
+ fds.push_back(pollfd.fd);
+ }
+ }
+}
+
+int PollFDMultiplexer::run(struct timeval* now, int timeout)
+{
+ if(d_inrun) {
+ throw FDMultiplexerException("FDMultiplexer::run() is not reentrant!\n");
+ }
+
+ auto pollfds = preparePollFD();
+
int ret=poll(&pollfds[0], pollfds.size(), timeout);
gettimeofday(now, 0); // MANDATORY!
d_iter=d_readCallbacks.end();
d_inrun=true;
- for(unsigned int n = 0; n < pollfds.size(); ++n) {
- if(pollfds[n].revents == POLLIN) {
- d_iter=d_readCallbacks.find(pollfds[n].fd);
+ for(const auto& pollfd : pollfds) {
+ if(pollfd.revents == POLLIN) {
+ d_iter=d_readCallbacks.find(pollfd.fd);
if(d_iter != d_readCallbacks.end()) {
d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);
continue; // so we don't refind ourselves as writable!
}
}
- else if(pollfds[n].revents == POLLOUT) {
- d_iter=d_writeCallbacks.find(pollfds[n].fd);
+ else if(pollfd.revents == POLLOUT) {
+ d_iter=d_writeCallbacks.find(pollfd.fd);
if(d_iter != d_writeCallbacks.end()) {
d_iter->second.d_callback(d_iter->first, d_iter->second.d_parameter);