2 PowerDNS Versatile Database Driven Nameserver
3 Copyright (C) 2013 - 2015 PowerDNS.COM BV
5 This program is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License version 2
7 as published by the Free Software Foundation
9 Additionally, the license of this program contains a special
10 exception which allows to distribute the program in binary form when
11 it is linked against OpenSSL.
13 This program is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
18 You should have received a copy of the GNU General Public License
19 along with this program; if not, write to the Free Software
20 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
24 #include "dnsdist-ecs.hh"
25 #include "dnsparser.hh"
26 #include "ednsoptions.hh"
36 /* TCP: the grand design.
37 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
38 An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
41 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
42 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
43 to guarantee performance.
45 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
46 So whenever an answer comes in, we know where it needs to go.
51 static int setupTCPDownstream(shared_ptr
<DownstreamState
> ds
)
53 vinfolog("TCP connecting to downstream %s", ds
->remote
.toStringWithPort());
54 int sock
= SSocket(ds
->remote
.sin4
.sin_family
, SOCK_STREAM
, 0);
55 if (!IsAnyAddress(ds
->sourceAddr
)) {
56 SSetsockopt(sock
, SOL_SOCKET
, SO_REUSEADDR
, 1);
57 SBind(sock
, ds
->sourceAddr
);
59 SConnect(sock
, ds
->remote
);
71 uint64_t g_maxTCPQueuedConnections
{0};
72 void* tcpClientThread(int pipefd
);
74 // Should not be called simultaneously!
75 void TCPClientCollection::addTCPClientThread()
77 if (d_numthreads
>= d_tcpclientthreads
.capacity()) {
78 warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads
.load(), d_tcpclientthreads
.capacity());
82 vinfolog("Adding TCP Client thread");
84 int pipefds
[2] = { -1, -1};
86 unixDie("Creating pipe");
88 if (!setNonBlocking(pipefds
[1])) {
91 unixDie("Setting pipe non-blocking");
94 d_tcpclientthreads
.push_back(pipefds
[1]);
96 thread
t1(tcpClientThread
, pipefds
[0]);
100 static bool getNonBlockingMsgLen(int fd
, uint16_t* len
, int timeout
)
104 size_t ret
= readn2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
105 if(ret
!= sizeof raw
)
114 static bool putNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
)
117 uint16_t raw
= htons(len
);
118 size_t ret
= writen2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
119 return ret
== sizeof raw
;
125 static bool sendNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
, ComboAddress
& dest
, ComboAddress
& local
, unsigned int localItf
)
129 return putNonBlockingMsgLen(fd
, len
, timeout
);
131 uint16_t raw
= htons(len
);
132 ssize_t ret
= sendMsgWithTimeout(fd
, (char*) &raw
, sizeof raw
, timeout
, dest
, local
, localItf
);
133 return ret
== sizeof raw
;
139 static bool sendResponseToClient(int fd
, const char* response
, uint16_t responseLen
)
141 if (!putNonBlockingMsgLen(fd
, responseLen
, g_tcpSendTimeout
))
144 writen2WithTimeout(fd
, response
, responseLen
, g_tcpSendTimeout
);
148 std::shared_ptr
<TCPClientCollection
> g_tcpclientthreads
;
150 void* tcpClientThread(int pipefd
)
152 /* we get launched with a pipe on which we receive file descriptors from clients that we own
153 from that point on */
155 bool outstanding
= false;
156 blockfilter_t blockFilter
= 0;
159 std::lock_guard
<std::mutex
> lock(g_luamutex
);
160 auto candidate
= g_lua
.readVariable
<boost::optional
<blockfilter_t
> >("blockFilter");
162 blockFilter
= *candidate
;
165 auto localPolicy
= g_policy
.getLocal();
166 auto localRulactions
= g_rulactions
.getLocal();
167 auto localRespRulactions
= g_resprulactions
.getLocal();
168 auto localDynBlockNMG
= g_dynblockNMG
.getLocal();
169 auto localDynBlockSMT
= g_dynblockSMT
.getLocal();
170 auto localPools
= g_pools
.getLocal();
172 boost::uuids::random_generator uuidGenerator
;
175 map
<ComboAddress
,int> sockets
;
177 ConnectionInfo
* citmp
, ci
;
180 readn2(pipefd
, &citmp
, sizeof(citmp
));
182 catch(const std::runtime_error
& e
) {
183 throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd
) + ") in " + std::string(isNonBlocking(pipefd
) ? "non-blocking" : "blocking") + " mode: " + e
.what());
186 --g_tcpclientthreads
->d_queued
;
192 vector
<uint8_t> rewrittenResponse
;
193 shared_ptr
<DownstreamState
> ds
;
194 if (!setNonBlocking(ci
.fd
))
202 if(!getNonBlockingMsgLen(ci
.fd
, &qlen
, g_tcpRecvTimeout
))
208 if (qlen
< sizeof(dnsheader
)) {
209 g_stats
.nonCompliantQueries
++;
213 bool ednsAdded
= false;
214 bool ecsAdded
= false;
215 /* if the query is small, allocate a bit more
216 memory to be able to spoof the content,
217 or to add ECS without allocating a new buffer */
218 size_t querySize
= qlen
<= 4096 ? qlen
+ 512 : qlen
;
219 char queryBuffer
[querySize
];
220 const char* query
= queryBuffer
;
221 readn2WithTimeout(ci
.fd
, queryBuffer
, qlen
, g_tcpRecvTimeout
);
224 std::shared_ptr
<DnsCryptQuery
> dnsCryptQuery
= 0;
226 if (ci
.cs
->dnscryptCtx
) {
227 dnsCryptQuery
= std::make_shared
<DnsCryptQuery
>();
228 uint16_t decryptedQueryLen
= 0;
229 vector
<uint8_t> response
;
230 bool decrypted
= handleDnsCryptQuery(ci
.cs
->dnscryptCtx
, queryBuffer
, qlen
, dnsCryptQuery
, &decryptedQueryLen
, true, response
);
233 if (response
.size() > 0) {
234 sendResponseToClient(ci
.fd
, reinterpret_cast<char*>(response
.data()), (uint16_t) response
.size());
238 qlen
= decryptedQueryLen
;
241 struct dnsheader
* dh
= (struct dnsheader
*) query
;
243 if(dh
->qr
) { // don't respond to responses
244 g_stats
.nonCompliantQueries
++;
248 if(dh
->qdcount
== 0) {
249 g_stats
.emptyQueries
++;
257 const uint16_t* flags
= getFlagsFromDNSHeader(dh
);
258 uint16_t origFlags
= *flags
;
259 uint16_t qtype
, qclass
;
260 unsigned int consumed
= 0;
261 DNSName
qname(query
, qlen
, sizeof(dnsheader
), false, &qtype
, &qclass
, &consumed
);
262 DNSQuestion
dq(&qname
, qtype
, qclass
, &ci
.cs
->local
, &ci
.remote
, (dnsheader
*)query
, querySize
, qlen
, true);
264 dq
.uniqueId
= uuidGenerator();
272 if (!processQuery(localDynBlockNMG
, localDynBlockSMT
, localRulactions
, blockFilter
, dq
, poolname
, &delayMsec
, now
)) {
276 if(dq
.dh
->qr
) { // something turned it into a response
277 restoreFlags(dh
, origFlags
);
279 if (!encryptResponse(queryBuffer
, &dq
.len
, dq
.size
, true, dnsCryptQuery
)) {
283 sendResponseToClient(ci
.fd
, query
, dq
.len
);
284 g_stats
.selfAnswered
++;
288 std::shared_ptr
<ServerPool
> serverPool
= getPool(*localPools
, poolname
);
289 std::shared_ptr
<DNSDistPacketCache
> packetCache
= nullptr;
291 std::lock_guard
<std::mutex
> lock(g_luamutex
);
292 ds
= localPolicy
->policy(serverPool
->servers
, &dq
);
293 packetCache
= serverPool
->packetCache
;
296 if (ds
&& ds
->useECS
) {
297 uint16_t newLen
= dq
.len
;
298 handleEDNSClientSubnet(queryBuffer
, dq
.size
, consumed
, &newLen
, largerQuery
, &ednsAdded
, &ecsAdded
, ci
.remote
);
299 if (largerQuery
.empty() == false) {
300 query
= largerQuery
.c_str();
301 dq
.len
= (uint16_t) largerQuery
.size();
302 dq
.size
= largerQuery
.size();
308 uint32_t cacheKey
= 0;
309 if (packetCache
&& !dq
.skipCache
) {
310 char cachedResponse
[4096];
311 uint16_t cachedResponseSize
= sizeof cachedResponse
;
312 uint32_t allowExpired
= ds
? 0 : g_staleCacheEntriesTTL
;
313 if (packetCache
->get(dq
, (uint16_t) consumed
, dq
.dh
->id
, cachedResponse
, &cachedResponseSize
, &cacheKey
, allowExpired
)) {
315 if (!encryptResponse(cachedResponse
, &cachedResponseSize
, sizeof cachedResponse
, true, dnsCryptQuery
)) {
319 sendResponseToClient(ci
.fd
, cachedResponse
, cachedResponseSize
);
323 g_stats
.cacheMisses
++;
332 if(sockets
.count(ds
->remote
) == 0) {
333 dsock
=sockets
[ds
->remote
]=setupTCPDownstream(ds
);
336 dsock
=sockets
[ds
->remote
];
342 uint16_t downstream_failures
=0;
345 sockets
.erase(ds
->remote
);
349 if (ds
->retries
> 0 && downstream_failures
> ds
->retries
) {
350 vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds
->getName(), downstream_failures
);
353 sockets
.erase(ds
->remote
);
357 if(!sendNonBlockingMsgLen(dsock
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
)) {
358 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
361 sockets
.erase(ds
->remote
);
362 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
363 downstream_failures
++;
368 if (ds
->sourceItf
== 0) {
369 writen2WithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
);
372 sendMsgWithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
);
375 catch(const runtime_error
& e
) {
376 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
379 sockets
.erase(ds
->remote
);
380 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
381 downstream_failures
++;
385 bool xfrStarted
= false;
386 bool isXFR
= (dq
.qtype
== QType::AXFR
|| dq
.qtype
== QType::IXFR
);
393 if(!getNonBlockingMsgLen(dsock
, &rlen
, ds
->tcpRecvTimeout
)) {
394 vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds
->getName());
397 sockets
.erase(ds
->remote
);
398 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
399 downstream_failures
++;
406 size_t responseSize
= rlen
;
407 uint16_t addRoom
= 0;
409 if (dnsCryptQuery
&& (UINT16_MAX
- rlen
) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
) {
410 addRoom
= DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
;
413 responseSize
+= addRoom
;
414 char answerbuffer
[responseSize
];
415 readn2WithTimeout(dsock
, answerbuffer
, rlen
, ds
->tcpRecvTimeout
);
416 char* response
= answerbuffer
;
417 uint16_t responseLen
= rlen
;
421 if (rlen
< sizeof(dnsheader
)) {
425 if (!responseContentMatches(response
, responseLen
, qname
, qtype
, qclass
, ds
->remote
)) {
429 if (!fixUpResponse(&response
, &responseLen
, &responseSize
, qname
, origFlags
, ednsAdded
, ecsAdded
, rewrittenResponse
, addRoom
)) {
433 dh
= (struct dnsheader
*) response
;
434 DNSResponse
dr(&qname
, qtype
, qclass
, &ci
.cs
->local
, &ci
.remote
, dh
, responseSize
, responseLen
, true, &now
);
436 dr
.uniqueId
= dq
.uniqueId
;
438 if (!processResponse(localRespRulactions
, dr
)) {
442 if (packetCache
&& !dq
.skipCache
) {
443 packetCache
->insert(cacheKey
, qname
, qtype
, qclass
, response
, responseLen
, true, dh
->rcode
== RCode::ServFail
);
447 if (!encryptResponse(response
, &responseLen
, responseSize
, true, dnsCryptQuery
)) {
451 if (!sendResponseToClient(ci
.fd
, response
, responseLen
)) {
455 if (isXFR
&& dh
->rcode
== 0 && dh
->ancount
!= 0) {
456 if (xfrStarted
== false) {
458 if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 1) {
462 else if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 0) {
468 struct timespec answertime
;
469 gettime(&answertime
);
470 unsigned int udiff
= 1000000.0*DiffTime(now
,answertime
);
472 std::lock_guard
<std::mutex
> lock(g_rings
.respMutex
);
473 g_rings
.respRing
.push_back({answertime
, ci
.remote
, qname
, dq
.qtype
, (unsigned int)udiff
, (unsigned int)responseLen
, *dq
.dh
, ds
->remote
});
477 rewrittenResponse
.clear();
484 vinfolog("Closing TCP client connection with %s", ci
.remote
.toStringWithPort());
487 if (ds
&& outstanding
) {
496 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
497 they will hand off to worker threads & spawn more of them if required
499 void* tcpAcceptorThread(void* p
)
501 ClientState
* cs
= (ClientState
*) p
;
504 remote
.sin4
.sin_family
= cs
->local
.sin4
.sin_family
;
506 g_tcpclientthreads
->addTCPClientThread();
508 auto acl
= g_ACL
.getLocal();
513 ci
= new ConnectionInfo
;
516 ci
->fd
= SAccept(cs
->tcpFD
, remote
);
518 if(!acl
->match(remote
)) {
523 vinfolog("Dropped TCP connection from %s because of ACL", remote
.toStringWithPort());
527 if(g_maxTCPQueuedConnections
> 0 && g_tcpclientthreads
->d_queued
>= g_maxTCPQueuedConnections
) {
531 vinfolog("Dropping TCP connection from %s because we have too many queued already", remote
.toStringWithPort());
535 vinfolog("Got TCP connection from %s", remote
.toStringWithPort());
538 int pipe
= g_tcpclientthreads
->getThread();
540 writen2WithTimeout(pipe
, &ci
, sizeof(ci
), 0);
543 --g_tcpclientthreads
->d_queued
;
549 catch(std::exception
& e
) {
550 errlog("While reading a TCP question: %s", e
.what());
551 if(ci
&& ci
->fd
>= 0)
562 bool getMsgLen32(int fd
, uint32_t* len
)
566 size_t ret
= readn2(fd
, &raw
, sizeof raw
);
567 if(ret
!= sizeof raw
)
570 if(*len
> 10000000) // arbitrary 10MB limit
578 bool putMsgLen32(int fd
, uint32_t len
)
581 uint32_t raw
= htonl(len
);
582 size_t ret
= writen2(fd
, &raw
, sizeof raw
);
583 return ret
==sizeof raw
;