/*
- PowerDNS Versatile Database Driven Nameserver
- Copyright (C) 2013 - 2015 PowerDNS.COM BV
-
- This program is free software; you can redistribute it and/or modify
- it under the terms of the GNU General Public License version 2
- as published by the Free Software Foundation
-
- Additionally, the license of this program contains a special
- exception which allows to distribute the program in binary form when
- it is linked against OpenSSL.
-
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
-
- You should have received a copy of the GNU General Public License
- along with this program; if not, write to the Free Software
- Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
-*/
-
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsparser.hh"
{
vinfolog("TCP connecting to downstream %s", ds->remote.toStringWithPort());
int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
- if (!IsAnyAddress(ds->sourceAddr)) {
- SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
- SBind(sock, ds->sourceAddr);
+ try {
+ if (!IsAnyAddress(ds->sourceAddr)) {
+ SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
+ SBind(sock, ds->sourceAddr);
+ }
+ SConnect(sock, ds->remote);
+ setNonBlocking(sock);
+ }
+ catch(const std::runtime_error& e) {
+ /* don't leak our file descriptor if SConnect() (for example) throws */
+ close(sock);
+ throw;
}
- SConnect(sock, ds->remote);
- setNonBlocking(sock);
return sock;
}
ClientState* cs;
};
-uint64_t g_maxTCPQueuedConnections{0};
+uint64_t g_maxTCPQueuedConnections{1000};
+size_t g_maxTCPQueriesPerConn{0};
+size_t g_maxTCPConnectionDuration{0};
+size_t g_maxTCPConnectionsPerClient{0};
+static std::mutex tcpClientsCountMutex;
+static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
+
void* tcpClientThread(int pipefd);
-// Should not be called simultaneously!
-void TCPClientCollection::addTCPClientThread()
+static void decrementTCPClientCount(const ComboAddress& client)
{
- if (d_numthreads >= d_tcpclientthreads.capacity()) {
- warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
- return;
+ if (g_maxTCPConnectionsPerClient) {
+ std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
+ tcpClientsCount[client]--;
+ if (tcpClientsCount[client] == 0) {
+ tcpClientsCount.erase(client);
+ }
}
+}
+void TCPClientCollection::addTCPClientThread()
+{
vinfolog("Adding TCP Client thread");
int pipefds[2] = { -1, -1};
- if(pipe(pipefds) < 0)
- unixDie("Creating pipe");
+ if (pipe(pipefds) < 0) {
+ errlog("Error creating the TCP thread communication pipe: %s", strerror(errno));
+ return;
+ }
if (!setNonBlocking(pipefds[1])) {
close(pipefds[0]);
close(pipefds[1]);
- unixDie("Setting pipe non-blocking");
+ errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
+ return;
+ }
+
+ {
+ std::lock_guard<std::mutex> lock(d_mutex);
+
+ if (d_numthreads >= d_tcpclientthreads.capacity()) {
+ warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
+ close(pipefds[0]);
+ close(pipefds[1]);
+ return;
+ }
+
+ try {
+ thread t1(tcpClientThread, pipefds[0]);
+ t1.detach();
+ }
+ catch(const std::runtime_error& e) {
+ /* the thread creation failed, don't leak */
+ errlog("Error creating a TCP thread: %s", e.what());
+ close(pipefds[0]);
+ close(pipefds[1]);
+ return;
+ }
+
+ d_tcpclientthreads.push_back(pipefds[1]);
}
- d_tcpclientthreads.push_back(pipefds[1]);
++d_numthreads;
- thread t1(tcpClientThread, pipefds[0]);
- t1.detach();
}
static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
return true;
}
+static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
+{
+ if (maxConnectionDuration) {
+ time_t elapsed = time(NULL) - start;
+ if (elapsed >= maxConnectionDuration) {
+ return true;
+ }
+ remainingTime = maxConnectionDuration - elapsed;
+ }
+ return false;
+}
+
std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
void* tcpClientThread(int pipefd)
throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
}
- --g_tcpclientthreads->d_queued;
+ g_tcpclientthreads->decrementQueuedCount();
ci=*citmp;
delete citmp;
string largerQuery;
vector<uint8_t> rewrittenResponse;
shared_ptr<DownstreamState> ds;
+ ComboAddress dest;
+ memset(&dest, 0, sizeof(dest));
+ dest.sin4.sin_family = ci.remote.sin4.sin_family;
+ socklen_t len = dest.getSocklen();
+ size_t queriesCount = 0;
+ time_t connectionStartTime = time(NULL);
+
if (!setNonBlocking(ci.fd))
goto drop;
+ if (getsockname(ci.fd, (sockaddr*)&dest, &len)) {
+ dest = ci.cs->local;
+ }
+
try {
for(;;) {
+ unsigned int remainingTime = 0;
ds = nullptr;
outstanding = false;
ci.cs->queries++;
g_stats.queries++;
+ queriesCount++;
+
+ if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
+ break;
+ }
+
+ if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) {
+ vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort());
+ break;
+ }
+
if (qlen < sizeof(dnsheader)) {
g_stats.nonCompliantQueries++;
break;
size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
char queryBuffer[querySize];
const char* query = queryBuffer;
- readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);
+ readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout, remainingTime);
#ifdef HAVE_DNSCRYPT
std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
uint16_t qtype, qclass;
unsigned int consumed = 0;
DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
- DNSQuestion dq(&qname, qtype, qclass, &ci.cs->local, &ci.remote, (dnsheader*)query, querySize, qlen, true);
+ DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, (dnsheader*)query, querySize, qlen, true);
#ifdef HAVE_PROTOBUF
dq.uniqueId = uuidGenerator();
#endif
string poolname;
int delayMsec=0;
+ /* we need this one to be accurate ("real") for the protobuf message */
+ struct timespec queryRealTime;
struct timespec now;
- gettime(&now, true);
+ gettime(&now);
+ gettime(&queryRealTime, true);
if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, blockFilter, dq, poolname, &delayMsec, now)) {
goto drop;
packetCache = serverPool->packetCache;
}
- if (ds && ds->useECS) {
+ if (dq.useECS && ds && ds->useECS) {
uint16_t newLen = dq.len;
- handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote);
+ handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength);
if (largerQuery.empty() == false) {
query = largerQuery.c_str();
dq.len = (uint16_t) largerQuery.size();
g_stats.cacheMisses++;
}
- if(!ds) {
- g_stats.noPolicy++;
- break;
- }
+ if(!ds) {
+ g_stats.noPolicy++;
+
+ if (g_servFailOnNoPolicy) {
+ restoreFlags(dh, origFlags);
+ dq.dh->rcode = RCode::ServFail;
+ dq.dh->qr = true;
+
+#ifdef HAVE_DNSCRYPT
+ if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
+ goto drop;
+ }
+#endif
+ sendResponseToClient(ci.fd, query, dq.len);
+ }
+
+ break;
+ }
int dsock = -1;
if(sockets.count(ds->remote) == 0) {
- dsock=sockets[ds->remote]=setupTCPDownstream(ds);
+ dsock=setupTCPDownstream(ds);
+ sockets[ds->remote]=dsock;
}
else
dsock=sockets[ds->remote];
close(dsock);
dsock=-1;
sockets.erase(ds->remote);
- sockets[ds->remote]=dsock=setupTCPDownstream(ds);
+ dsock=setupTCPDownstream(ds);
+ sockets[ds->remote]=dsock;
downstream_failures++;
goto retry;
}
close(dsock);
dsock=-1;
sockets.erase(ds->remote);
- sockets[ds->remote]=dsock=setupTCPDownstream(ds);
+ dsock=setupTCPDownstream(ds);
+ sockets[ds->remote]=dsock;
downstream_failures++;
goto retry;
}
close(dsock);
dsock=-1;
sockets.erase(ds->remote);
- sockets[ds->remote]=dsock=setupTCPDownstream(ds);
+ dsock=setupTCPDownstream(ds);
+ sockets[ds->remote]=dsock;
downstream_failures++;
if(xfrStarted) {
goto drop;
readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
char* response = answerbuffer;
uint16_t responseLen = rlen;
- --ds->outstanding;
- outstanding = false;
+ if (outstanding) {
+ /* might be false for {A,I}XFR */
+ --ds->outstanding;
+ outstanding = false;
+ }
if (rlen < sizeof(dnsheader)) {
break;
}
dh = (struct dnsheader*) response;
- DNSResponse dr(&qname, qtype, qclass, &ci.cs->local, &ci.remote, dh, responseSize, responseLen, true, &now);
+ DNSResponse dr(&qname, qtype, qclass, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
#ifdef HAVE_PROTOBUF
dr.uniqueId = dq.uniqueId;
#endif
unsigned int udiff = 1000000.0*DiffTime(now,answertime);
{
std::lock_guard<std::mutex> lock(g_rings.respMutex);
- g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dq.dh, ds->remote});
+ g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote});
}
largerQuery.clear();
outstanding = false;
--ds->outstanding;
}
+ decrementTCPClientCount(ci.remote);
}
return 0;
}
-
/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
they will hand off to worker threads & spawn more of them if required
*/
void* tcpAcceptorThread(void* p)
{
ClientState* cs = (ClientState*) p;
-
+ bool tcpClientCountIncremented = false;
ComboAddress remote;
remote.sin4.sin_family = cs->local.sin4.sin_family;
auto acl = g_ACL.getLocal();
for(;;) {
- ConnectionInfo* ci;
+ bool queuedCounterIncremented = false;
+ ConnectionInfo* ci = nullptr;
+ tcpClientCountIncremented = false;
try {
- ci=0;
ci = new ConnectionInfo;
ci->cs = cs;
ci->fd = -1;
g_stats.aclDrops++;
close(ci->fd);
delete ci;
- ci=0;
+ ci=nullptr;
vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
continue;
}
- if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->d_queued >= g_maxTCPQueuedConnections) {
+ if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
close(ci->fd);
delete ci;
ci=nullptr;
continue;
}
+ if (g_maxTCPConnectionsPerClient) {
+ std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
+
+ if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
+ close(ci->fd);
+ delete ci;
+ ci=nullptr;
+ vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
+ continue;
+ }
+ tcpClientsCount[remote]++;
+ tcpClientCountIncremented = true;
+ }
+
vinfolog("Got TCP connection from %s", remote.toStringWithPort());
-
+
ci->remote = remote;
int pipe = g_tcpclientthreads->getThread();
if (pipe >= 0) {
+ queuedCounterIncremented = true;
writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
}
else {
- --g_tcpclientthreads->d_queued;
+ g_tcpclientthreads->decrementQueuedCount();
+ queuedCounterIncremented = false;
close(ci->fd);
delete ci;
ci=nullptr;
+ if(tcpClientCountIncremented) {
+ decrementTCPClientCount(remote);
+ }
}
}
catch(std::exception& e) {
errlog("While reading a TCP question: %s", e.what());
if(ci && ci->fd >= 0)
close(ci->fd);
+ if(tcpClientCountIncremented) {
+ decrementTCPClientCount(remote);
+ }
delete ci;
+ ci = nullptr;
+ if (queuedCounterIncremented) {
+ g_tcpclientthreads->decrementQueuedCount();
+ }
}
catch(...){}
}