2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
23 #include "dnsdist-ecs.hh"
24 #include "dnsparser.hh"
25 #include "ednsoptions.hh"
35 /* TCP: the grand design.
36 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
37 An answer might theoretically consist of multiple messages (for example, in the case of AXFR).
40 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been setup.
41 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
42 to guarantee performance.
44 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
45 So whenever an answer comes in, we know where it needs to go. */
/* Open a TCP socket towards the downstream server `ds` and connect it.
   The socket is created with SSocket() using the address family of
   ds->remote. When IsAnyAddress(ds->sourceAddr) is false, SO_REUSEADDR is
   set and the socket is bound to ds->sourceAddr before SConnect() is called.
   NOTE(review): the `try`, the body of the catch handler and the return
   statement are not visible in this listing. Callers assign the result to
   their downstream socket variable, so this presumably returns `sock` and
   closes it before re-throwing on error - confirm against the full file. */
50 static int setupTCPDownstream(shared_ptr
<DownstreamState
> ds
)
52 vinfolog("TCP connecting to downstream %s", ds
->remote
.toStringWithPort());
/* stream socket in the same address family as the downstream's address */
53 int sock
= SSocket(ds
->remote
.sin4
.sin_family
, SOCK_STREAM
, 0);
/* bind to the source address only when it is not the "any" address */
55 if (!IsAnyAddress(ds
->sourceAddr
)) {
56 SSetsockopt(sock
, SOL_SOCKET
, SO_REUSEADDR
, 1);
57 SBind(sock
, ds
->sourceAddr
);
59 SConnect(sock
, ds
->remote
);
62 catch(const std::runtime_error
& e
) {
63 /* don't leak our file descriptor if SConnect() (for example) throws */
/* Limit on the number of queued TCP connections, initialised to 1000.
   tcpAcceptorThread() compares the collection's queued count against it and
   logs that the connection is dropped once the count reaches this value;
   the check is skipped when the value is 0. */
77 uint64_t g_maxTCPQueuedConnections
{1000};
/* Forward declaration: addTCPClientThread() below starts threads running
   this function, which is defined further down in this file. */
78 void* tcpClientThread(int pipefd
);
/* Start one more TCP client worker thread.
   A pipe is created and its write end (pipefds[1]) is made non-blocking.
   With d_mutex held, d_numthreads is checked against the capacity of
   d_tcpclientthreads, a thread running tcpClientThread() is created with the
   read end (pipefds[0]) as its argument, and the write end is appended to
   d_tcpclientthreads.
   Failures are reported through errlog()/warnlog().
   NOTE(review): what each error path does after logging (closing the pipe
   descriptors, returning) and how the thread object and d_numthreads are
   handled afterwards is not visible in this listing - confirm. */
80 void TCPClientCollection::addTCPClientThread()
82 vinfolog("Adding TCP Client thread");
84 int pipefds
[2] = { -1, -1};
85 if (pipe(pipefds
) < 0) {
86 errlog("Error creating the TCP thread communication pipe: %s", strerror(errno
));
/* the write end is the one handed to acceptors, see push_back() below */
90 if (!setNonBlocking(pipefds
[1])) {
93 errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno
));
98 std::lock_guard
<std::mutex
> lock(d_mutex
);
/* refuse to go past the capacity of the vector holding the pipe descriptors */
100 if (d_numthreads
>= d_tcpclientthreads
.capacity()) {
101 warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads
.load(), d_tcpclientthreads
.capacity());
/* the worker reads from pipefds[0] */
108 thread
t1(tcpClientThread
, pipefds
[0]);
111 catch(const std::runtime_error
& e
) {
112 /* the thread creation failed, don't leak */
113 errlog("Error creating a TCP thread: %s", e
.what());
119 d_tcpclientthreads
.push_back(pipefds
[1]);
/* Read a message length prefix from `fd`, waiting at most `timeout`.
   readn2WithTimeout() is asked for sizeof raw bytes and the number of bytes
   read is compared against that size.
   NOTE(review): the declaration of `raw`, the store into `*len` and the
   return statements are not visible in this listing. putNonBlockingMsgLen()
   writes a uint16_t in network byte order, so this presumably reads the same
   two bytes, converts them with ntohs() and returns false on a short read -
   confirm. */
125 static bool getNonBlockingMsgLen(int fd
, uint16_t* len
, int timeout
)
129 size_t ret
= readn2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
130 if(ret
!= sizeof raw
)
/* Write the length prefix `len` to `fd`, waiting at most `timeout`.
   The value is converted to network byte order with htons() and written
   with writen2WithTimeout(); the visible return is true when exactly
   sizeof raw bytes were written.
   NOTE(review): any enclosing try/catch is not visible in this listing. */
139 static bool putNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
)
142 uint16_t raw
= htons(len
);
143 size_t ret
= writen2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
144 return ret
== sizeof raw
;
/* Send the length prefix `len` on `fd`, waiting at most `timeout`.
   Two paths are visible: a plain putNonBlockingMsgLen() call, and one that
   converts `len` with htons() and passes it to sendMsgWithTimeout() together
   with `dest`, `local` and `localItf`, returning true when sizeof raw bytes
   were sent.
   NOTE(review): the condition choosing between the two paths is not visible
   in this listing. tcpClientThread() picks between writen2WithTimeout() and
   sendMsgWithTimeout() for the payload with `ds->sourceItf == 0`, so this
   presumably tests `localItf == 0` - confirm. */
150 static bool sendNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
, ComboAddress
& dest
, ComboAddress
& local
, unsigned int localItf
)
154 return putNonBlockingMsgLen(fd
, len
, timeout
);
156 uint16_t raw
= htons(len
);
157 ssize_t ret
= sendMsgWithTimeout(fd
, (char*) &raw
, sizeof raw
, timeout
, dest
, local
, localItf
);
158 return ret
== sizeof raw
;
/* Send a response of `responseLen` bytes to the client socket `fd`.
   The length prefix goes out first through putNonBlockingMsgLen(), then the
   payload through writen2WithTimeout(); both use g_tcpSendTimeout.
   The result of the payload write is not inspected in the visible code.
   NOTE(review): the return statements are not visible in this listing;
   presumably false when the length prefix could not be sent and true
   otherwise - confirm. */
164 static bool sendResponseToClient(int fd
, const char* response
, uint16_t responseLen
)
166 if (!putNonBlockingMsgLen(fd
, responseLen
, g_tcpSendTimeout
))
169 writen2WithTimeout(fd
, response
, responseLen
, g_tcpSendTimeout
);
/* Shared collection of TCP client worker threads.
   tcpAcceptorThread() uses it to add a worker, to pick a worker pipe with
   getThread() and to read the queued count; tcpClientThread() calls
   decrementQueuedCount() on it after taking a connection from its pipe. */
173 std::shared_ptr
<TCPClientCollection
> g_tcpclientthreads
;
/* Worker thread body for TCP clients.
   `pipefd` is the descriptor the thread receives its work on:
   TCPClientCollection::addTCPClientThread() passes the read end of a pipe,
   and tcpAcceptorThread() writes ConnectionInfo pointers into a worker pipe.
   For a received connection the visible code reads length-prefixed queries
   from ci.fd, optionally handles DNSCrypt, runs processQuery(), consults the
   pool's packet cache, selects a downstream through the policy, forwards the
   query over a reused or freshly set up downstream TCP socket, reads the
   response and relays it with sendResponseToClient().
   NOTE(review): the loop structure, most early exits, the declarations of
   several locals (dest, qlen, rlen, dsock, poolname, delayMsec, now,
   largerQuery) and the cleanup of `citmp` / `ci.fd` are not visible in this
   listing - confirm against the complete file. */
175 void* tcpClientThread(int pipefd
)
177 /* we get launched with a pipe on which we receive file descriptors from clients that we own
178 from that point on */
180 bool outstanding
= false;
181 blockfilter_t blockFilter
= 0;
/* read the optional Lua variable "blockFilter" while holding the Lua mutex */
184 std::lock_guard
<std::mutex
> lock(g_luamutex
);
185 auto candidate
= g_lua
.readVariable
<boost::optional
<blockfilter_t
> >("blockFilter");
187 blockFilter
= *candidate
;
/* handles obtained through getLocal() on the global policy, rule, dynamic block and pool holders */
190 auto localPolicy
= g_policy
.getLocal();
191 auto localRulactions
= g_rulactions
.getLocal();
192 auto localRespRulactions
= g_resprulactions
.getLocal();
193 auto localDynBlockNMG
= g_dynblockNMG
.getLocal();
194 auto localDynBlockSMT
= g_dynblockSMT
.getLocal();
195 auto localPools
= g_pools
.getLocal();
/* used below to give every DNSQuestion a unique id */
197 boost::uuids::random_generator uuidGenerator
;
/* downstream sockets opened by this thread, keyed by downstream address, so that they can be reused */
200 map
<ComboAddress
,int> sockets
;
202 ConnectionInfo
* citmp
, ci
;
/* a ConnectionInfo pointer is read from the pipe; tcpAcceptorThread() writes such a pointer into a worker pipe */
205 readn2(pipefd
, &citmp
, sizeof(citmp
));
207 catch(const std::runtime_error
& e
) {
208 throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd
) + ") in " + std::string(isNonBlocking(pipefd
) ? "non-blocking" : "blocking") + " mode: " + e
.what());
/* the connection has been taken from the pipe, so it no longer counts as queued */
211 g_tcpclientthreads
->decrementQueuedCount();
217 vector
<uint8_t> rewrittenResponse
;
218 shared_ptr
<DownstreamState
> ds
;
/* dest is filled with the local address of the client socket via getsockname() */
220 memset(&dest
, 0, sizeof(dest
));
221 dest
.sin4
.sin_family
= ci
.remote
.sin4
.sin_family
;
222 socklen_t len
= dest
.getSocklen();
223 if (!setNonBlocking(ci
.fd
))
226 if (getsockname(ci
.fd
, (sockaddr
*)&dest
, &len
)) {
/* read the length prefix of the next query from the client */
235 if(!getNonBlockingMsgLen(ci
.fd
, &qlen
, g_tcpRecvTimeout
))
/* a query shorter than a DNS header is counted as non-compliant */
241 if (qlen
< sizeof(dnsheader
)) {
242 g_stats
.nonCompliantQueries
++;
246 bool ednsAdded
= false;
247 bool ecsAdded
= false;
248 /* if the query is small, allocate a bit more
249 memory to be able to spoof the content,
250 or to add ECS without allocating a new buffer */
251 size_t querySize
= qlen
<= 4096 ? qlen
+ 512 : qlen
;
252 char queryBuffer
[querySize
];
253 const char* query
= queryBuffer
;
254 readn2WithTimeout(ci
.fd
, queryBuffer
, qlen
, g_tcpRecvTimeout
);
/* DNSCrypt: only when the listening ClientState carries a DNSCrypt context */
257 std::shared_ptr
<DnsCryptQuery
> dnsCryptQuery
= 0;
259 if (ci
.cs
->dnscryptCtx
) {
260 dnsCryptQuery
= std::make_shared
<DnsCryptQuery
>();
261 uint16_t decryptedQueryLen
= 0;
262 vector
<uint8_t> response
;
263 bool decrypted
= handleDnsCryptQuery(ci
.cs
->dnscryptCtx
, queryBuffer
, qlen
, dnsCryptQuery
, &decryptedQueryLen
, true, response
);
/* handleDnsCryptQuery() may have filled in a response to send straight back to the client */
266 if (response
.size() > 0) {
267 sendResponseToClient(ci
.fd
, reinterpret_cast<char*>(response
.data()), (uint16_t) response
.size());
271 qlen
= decryptedQueryLen
;
/* basic header checks on the query */
274 struct dnsheader
* dh
= (struct dnsheader
*) query
;
276 if(dh
->qr
) { // don't respond to responses
277 g_stats
.nonCompliantQueries
++;
281 if(dh
->qdcount
== 0) {
282 g_stats
.emptyQueries
++;
/* remember the original flags so that they can be restored with restoreFlags() later */
290 const uint16_t* flags
= getFlagsFromDNSHeader(dh
);
291 uint16_t origFlags
= *flags
;
292 uint16_t qtype
, qclass
;
293 unsigned int consumed
= 0;
/* parse the question name, type and class starting right after the header */
294 DNSName
qname(query
, qlen
, sizeof(dnsheader
), false, &qtype
, &qclass
, &consumed
);
295 DNSQuestion
dq(&qname
, qtype
, qclass
, &dest
, &ci
.remote
, (dnsheader
*)query
, querySize
, qlen
, true);
297 dq
.uniqueId
= uuidGenerator();
302 /* we need this one to be accurate ("real") for the protobuf message */
303 struct timespec queryRealTime
;
306 gettime(&queryRealTime
, true);
/* NOTE(review): what happens when processQuery() returns false is not visible in this listing */
308 if (!processQuery(localDynBlockNMG
, localDynBlockSMT
, localRulactions
, blockFilter
, dq
, poolname
, &delayMsec
, now
)) {
/* answered locally: restore the flags, encrypt if needed and send the buffer back */
312 if(dq
.dh
->qr
) { // something turned it into a response
313 restoreFlags(dh
, origFlags
);
315 if (!encryptResponse(queryBuffer
, &dq
.len
, dq
.size
, true, dnsCryptQuery
)) {
319 sendResponseToClient(ci
.fd
, query
, dq
.len
);
320 g_stats
.selfAnswered
++;
/* look up the pool named by poolname and let the policy pick a downstream, under the Lua mutex */
324 std::shared_ptr
<ServerPool
> serverPool
= getPool(*localPools
, poolname
);
325 std::shared_ptr
<DNSDistPacketCache
> packetCache
= nullptr;
327 std::lock_guard
<std::mutex
> lock(g_luamutex
);
328 ds
= localPolicy
->policy(serverPool
->servers
, &dq
);
329 packetCache
= serverPool
->packetCache
;
/* EDNS Client Subnet: only when both the question and the selected downstream have useECS set */
332 if (dq
.useECS
&& ds
&& ds
->useECS
) {
333 uint16_t newLen
= dq
.len
;
334 handleEDNSClientSubnet(queryBuffer
, dq
.size
, consumed
, &newLen
, largerQuery
, &ednsAdded
, &ecsAdded
, ci
.remote
, dq
.ecsOverride
, dq
.ecsPrefixLength
);
/* a non-empty largerQuery replaces the query from here on */
335 if (largerQuery
.empty() == false) {
336 query
= largerQuery
.c_str();
337 dq
.len
= (uint16_t) largerQuery
.size();
338 dq
.size
= largerQuery
.size();
/* packet cache lookup; g_staleCacheEntriesTTL is passed as allowExpired only when no downstream was selected */
344 uint32_t cacheKey
= 0;
345 if (packetCache
&& !dq
.skipCache
) {
346 char cachedResponse
[4096];
347 uint16_t cachedResponseSize
= sizeof cachedResponse
;
348 uint32_t allowExpired
= ds
? 0 : g_staleCacheEntriesTTL
;
349 if (packetCache
->get(dq
, (uint16_t) consumed
, dq
.dh
->id
, cachedResponse
, &cachedResponseSize
, &cacheKey
, allowExpired
)) {
351 if (!encryptResponse(cachedResponse
, &cachedResponseSize
, sizeof cachedResponse
, true, dnsCryptQuery
)) {
355 sendResponseToClient(ci
.fd
, cachedResponse
, cachedResponseSize
);
359 g_stats
.cacheMisses
++;
/* NOTE(review): the enclosing condition is not visible; presumably reached when no downstream was selected - confirm */
365 if (g_servFailOnNoPolicy
) {
366 restoreFlags(dh
, origFlags
);
367 dq
.dh
->rcode
= RCode::ServFail
;
371 if (!encryptResponse(queryBuffer
, &dq
.len
, dq
.size
, true, dnsCryptQuery
)) {
375 sendResponseToClient(ci
.fd
, query
, dq
.len
);
/* reuse the socket already opened to this downstream, or set up a new one and remember it */
382 if(sockets
.count(ds
->remote
) == 0) {
383 dsock
=setupTCPDownstream(ds
);
384 sockets
[ds
->remote
]=dsock
;
387 dsock
=sockets
[ds
->remote
];
/* incremented each time the downstream socket is re-established below */
393 uint16_t downstream_failures
=0;
396 sockets
.erase(ds
->remote
);
/* give up once the failures exceed ds->retries (only enforced when retries is positive) */
400 if (ds
->retries
> 0 && downstream_failures
> ds
->retries
) {
401 vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds
->getName(), downstream_failures
);
404 sockets
.erase(ds
->remote
);
/* send the length prefix; on failure drop the cached socket and open a new one */
408 if(!sendNonBlockingMsgLen(dsock
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
)) {
409 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
412 sockets
.erase(ds
->remote
);
413 dsock
=setupTCPDownstream(ds
);
414 sockets
[ds
->remote
]=dsock
;
415 downstream_failures
++;
/* send the query itself; sendMsgWithTimeout() is used when a source interface is set */
420 if (ds
->sourceItf
== 0) {
421 writen2WithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
);
424 sendMsgWithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
);
427 catch(const runtime_error
& e
) {
428 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
431 sockets
.erase(ds
->remote
);
432 dsock
=setupTCPDownstream(ds
);
433 sockets
[ds
->remote
]=dsock
;
434 downstream_failures
++;
/* AXFR and IXFR queries are flagged so that the SOA counting further down applies to them */
438 bool xfrStarted
= false;
439 bool isXFR
= (dq
.qtype
== QType::AXFR
|| dq
.qtype
== QType::IXFR
);
/* read the length prefix of the response; on failure reconnect to the downstream */
446 if(!getNonBlockingMsgLen(dsock
, &rlen
, ds
->tcpRecvTimeout
)) {
447 vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds
->getName());
450 sockets
.erase(ds
->remote
);
451 dsock
=setupTCPDownstream(ds
);
452 sockets
[ds
->remote
]=dsock
;
453 downstream_failures
++;
/* size the answer buffer, adding room for DNSCrypt padding and MAC when that still fits below UINT16_MAX */
460 size_t responseSize
= rlen
;
461 uint16_t addRoom
= 0;
463 if (dnsCryptQuery
&& (UINT16_MAX
- rlen
) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
) {
464 addRoom
= DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
;
467 responseSize
+= addRoom
;
468 char answerbuffer
[responseSize
];
469 readn2WithTimeout(dsock
, answerbuffer
, rlen
, ds
->tcpRecvTimeout
);
470 char* response
= answerbuffer
;
471 uint16_t responseLen
= rlen
;
473 /* might be false for {A,I}XFR */
/* checks on the response: minimum size, then that it matches the question sent to this downstream */
478 if (rlen
< sizeof(dnsheader
)) {
482 if (!responseContentMatches(response
, responseLen
, qname
, qtype
, qclass
, ds
->remote
)) {
486 if (!fixUpResponse(&response
, &responseLen
, &responseSize
, qname
, origFlags
, ednsAdded
, ecsAdded
, rewrittenResponse
, addRoom
)) {
/* fixUpResponse() takes `response` by address, so the header pointer is taken again from it */
490 dh
= (struct dnsheader
*) response
;
491 DNSResponse
dr(&qname
, qtype
, qclass
, &dest
, &ci
.remote
, dh
, responseSize
, responseLen
, true, &queryRealTime
);
493 dr
.uniqueId
= dq
.uniqueId
;
495 if (!processResponse(localRespRulactions
, dr
, &delayMsec
)) {
/* store the response in the packet cache; the last argument tells insert() whether it is a ServFail */
499 if (packetCache
&& !dq
.skipCache
) {
500 packetCache
->insert(cacheKey
, qname
, qtype
, qclass
, response
, responseLen
, true, dh
->rcode
== RCode::ServFail
);
504 if (!encryptResponse(response
, &responseLen
, responseSize
, true, dnsCryptQuery
)) {
508 if (!sendResponseToClient(ci
.fd
, response
, responseLen
)) {
/* zone transfers: SOA records in the answer section are counted; NOTE(review): the branch bodies are not visible, presumably they track the start and end of the transfer - confirm */
512 if (isXFR
&& dh
->rcode
== 0 && dh
->ancount
!= 0) {
513 if (xfrStarted
== false) {
515 if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 1) {
519 else if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 0) {
/* record the elapsed time (presumably microseconds - confirm the unit of DiffTime()) in the response ring, under its mutex */
525 struct timespec answertime
;
526 gettime(&answertime
);
527 unsigned int udiff
= 1000000.0*DiffTime(now
,answertime
);
529 std::lock_guard
<std::mutex
> lock(g_rings
.respMutex
);
530 g_rings
.respRing
.push_back({answertime
, ci
.remote
, qname
, dq
.qtype
, (unsigned int)udiff
, (unsigned int)responseLen
, *dh
, ds
->remote
});
534 rewrittenResponse
.clear();
541 vinfolog("Closing TCP client connection with %s", ci
.remote
.toStringWithPort());
/* NOTE(review): the body of this branch is not visible in this listing */
544 if (ds
&& outstanding
) {
553 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
554 they will hand off to worker threads & spawn more of them if required */
/* Acceptor thread body for one listening TCP socket.
   `p` is cast to the ClientState describing the listener. The visible code
   adds a TCP client worker thread, then allocates a ConnectionInfo, accepts
   a connection on cs->tcpFD, checks the peer against the ACL and against
   g_maxTCPQueuedConnections, picks a worker pipe with getThread() and writes
   the ConnectionInfo pointer into it.
   The exception handler logs the error, looks at ci->fd and decrements the
   queued count again when it had been incremented for this connection.
   NOTE(review): the accept loop, the declaration of `remote`, the call that
   increments the queued count and the cleanup of `ci` on the drop and error
   paths are not visible in this listing - confirm against the full file. */
556 void* tcpAcceptorThread(void* p
)
558 ClientState
* cs
= (ClientState
*) p
;
/* the peer address uses the same address family as the listening address */
561 remote
.sin4
.sin_family
= cs
->local
.sin4
.sin_family
;
563 g_tcpclientthreads
->addTCPClientThread();
565 auto acl
= g_ACL
.getLocal();
/* tracks whether the queued count has to be decremented again in the exception handler */
567 bool queuedCounterIncremented
= false;
568 ConnectionInfo
* ci
= nullptr;
570 ci
= new ConnectionInfo
;
573 ci
->fd
= SAccept(cs
->tcpFD
, remote
);
575 if(!acl
->match(remote
)) {
580 vinfolog("Dropped TCP connection from %s because of ACL", remote
.toStringWithPort());
/* a limit of 0 disables this check */
584 if(g_maxTCPQueuedConnections
> 0 && g_tcpclientthreads
->getQueuedCount() >= g_maxTCPQueuedConnections
) {
588 vinfolog("Dropping TCP connection from %s because we have too many queued already", remote
.toStringWithPort());
592 vinfolog("Got TCP connection from %s", remote
.toStringWithPort());
/* getThread() returns the pipe descriptor of a worker; the pointer itself, not the object, is written to it */
595 int pipe
= g_tcpclientthreads
->getThread();
597 queuedCounterIncremented
= true;
598 writen2WithTimeout(pipe
, &ci
, sizeof(ci
), 0);
601 g_tcpclientthreads
->decrementQueuedCount();
602 queuedCounterIncremented
= false;
608 catch(std::exception
& e
) {
609 errlog("While reading a TCP question: %s", e
.what());
610 if(ci
&& ci
->fd
>= 0)
614 if (queuedCounterIncremented
) {
615 g_tcpclientthreads
->decrementQueuedCount();
/* Read a 32-bit message length from `fd` into `*len`.
   readn2() is asked for sizeof raw bytes and the number of bytes read is
   compared against that size; the resulting length is then checked against
   a limit of 10000000.
   NOTE(review): the declaration of `raw`, the store into `*len` and the
   return statements are not visible in this listing. putMsgLen32() writes a
   uint32_t converted with htonl(), so this presumably converts with ntohl()
   and returns false on a short read or an oversized length - confirm. */
625 bool getMsgLen32(int fd
, uint32_t* len
)
629 size_t ret
= readn2(fd
, &raw
, sizeof raw
);
630 if(ret
!= sizeof raw
)
633 if(*len
> 10000000) // arbitrary 10MB limit
/* Write the 32-bit message length `len` to `fd`.
   The value is converted to network byte order with htonl() and written
   with writen2(); the visible return is true when exactly sizeof raw bytes
   were written.
   NOTE(review): any enclosing try/catch is not visible in this listing. */
641 bool putMsgLen32(int fd
, uint32_t len
)
644 uint32_t raw
= htonl(len
);
645 size_t ret
= writen2(fd
, &raw
, sizeof raw
);
646 return ret
==sizeof raw
;