2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
23 #include "dnsdist-ecs.hh"
24 #include "dnsparser.hh"
25 #include "ednsoptions.hh"
35 /* TCP: the grand design.
36 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
37 An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
40 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
41 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
42 to guarantee performance.
44 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
45 So whenever an answer comes in, we know where it needs to go.
// Opens a TCP connection to the downstream server `ds`.
// Visible steps: log the target, create a SOCK_STREAM socket in the remote's
// address family, bind to ds->sourceAddr unless IsAnyAddress() says it is the
// "any" address, then connect to ds->remote.
// NOTE(review): this listing is fragmentary -- the braces, the return statement
// and any error handling are not shown. Presumably the connected descriptor is
// returned (callers store the result in their socket map); confirm in the full file.
50 static int setupTCPDownstream(shared_ptr
<DownstreamState
> ds
)
52 vinfolog("TCP connecting to downstream %s", ds
->remote
.toStringWithPort());
53 int sock
= SSocket(ds
->remote
.sin4
.sin_family
, SOCK_STREAM
, 0);
// Only bind when ds->sourceAddr is a specific (non-"any") address.
54 if (!IsAnyAddress(ds
->sourceAddr
)) {
// SO_REUSEADDR is switched on before binding to that source address.
55 SSetsockopt(sock
, SOL_SOCKET
, SO_REUSEADDR
, 1);
56 SBind(sock
, ds
->sourceAddr
);
58 SConnect(sock
, ds
->remote
);
// Limit on connections queued for the workers: the acceptor compares it with
// d_queued before handing a connection over, and skips the check when it is 0.
70 uint64_t g_maxTCPQueuedConnections
{1000};
// Forward declaration of the worker entry point started by addTCPClientThread().
71 void* tcpClientThread(int pipefd
);
// Starts one more TCP worker thread and registers the pipe used to feed it.
// When d_numthreads has reached the capacity of d_tcpclientthreads a warning
// saying it is "skipping" is logged instead.
// NOTE(review): fragmentary listing -- the pipe creation call, the early return
// after the warning and what is done with the thread object `t1` are not visible.
73 // Should not be called simultaneously!
74 void TCPClientCollection::addTCPClientThread()
76 if (d_numthreads
>= d_tcpclientthreads
.capacity()) {
77 warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads
.load(), d_tcpclientthreads
.capacity());
81 vinfolog("Adding TCP Client thread");
// Both descriptors start out as -1 (no valid descriptor yet).
83 int pipefds
[2] = { -1, -1};
85 unixDie("Creating pipe");
// The end kept by the collection (index 1) is made non-blocking.
87 if (!setNonBlocking(pipefds
[1])) {
90 unixDie("Setting pipe non-blocking");
// Index 1 is stored for the acceptor side; index 0 is handed to the new worker,
// which reads connection pointers from it.
93 d_tcpclientthreads
.push_back(pipefds
[1]);
95 thread
t1(tcpClientThread
, pipefds
[0]);
// Reads a message length prefix from `fd`, waiting at most `timeout`
// (as interpreted by readn2WithTimeout).
// Callers treat a false return as failure to obtain the length.
// NOTE(review): fragmentary listing -- the declaration of `raw`, the store into
// `*len` (presumably a network-to-host conversion, mirroring the htons() used by
// putNonBlockingMsgLen) and the return statements are not visible; confirm.
99 static bool getNonBlockingMsgLen(int fd
, uint16_t* len
, int timeout
)
103 size_t ret
= readn2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
// Anything other than a complete read of the prefix is checked for here.
104 if(ret
!= sizeof raw
)
// Writes `len` to `fd` as a 16-bit length prefix in network byte order.
// Returns true only when all sizeof(raw) bytes were written.
// NOTE(review): fragmentary listing -- braces and any exception handling around
// these statements are not visible.
113 static bool putNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
)
116 uint16_t raw
= htons(len
);
117 size_t ret
= writen2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
118 return ret
== sizeof raw
;
// Sends the 16-bit length prefix `len` on `fd`, either through
// putNonBlockingMsgLen() or through sendMsgWithTimeout() with explicit
// destination, local address and interface.
// NOTE(review): the condition choosing between the two paths is not visible in
// this listing -- presumably `localItf == 0` selects the plain write, as the
// payload send in tcpClientThread does with ds->sourceItf; confirm.
124 static bool sendNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
, ComboAddress
& dest
, ComboAddress
& local
, unsigned int localItf
)
128 return putNonBlockingMsgLen(fd
, len
, timeout
);
// Second path: convert to network byte order and send via sendMsgWithTimeout.
130 uint16_t raw
= htons(len
);
131 ssize_t ret
= sendMsgWithTimeout(fd
, (char*) &raw
, sizeof raw
, timeout
, dest
, local
, localItf
);
// Success means the whole prefix went out.
132 return ret
== sizeof raw
;
// Sends one response to a client: first the 16-bit length prefix, then
// `responseLen` bytes of `response`, both using g_tcpSendTimeout.
// NOTE(review): fragmentary listing -- the return statements are not visible;
// at least one caller tests the result with `!sendResponseToClient(...)`.
138 static bool sendResponseToClient(int fd
, const char* response
, uint16_t responseLen
)
140 if (!putNonBlockingMsgLen(fd
, responseLen
, g_tcpSendTimeout
))
143 writen2WithTimeout(fd
, response
, responseLen
, g_tcpSendTimeout
);
// Shared collection of TCP worker threads: the acceptor uses it to add workers,
// to pick a worker pipe (getThread()) and to track queued connections (d_queued).
147 std::shared_ptr
<TCPClientCollection
> g_tcpclientthreads
;
// Worker thread body: receives ConnectionInfo pointers over `pipefd` and, for a
// client connection, reads length-prefixed DNS queries, applies rules, packet
// cache and server policy, relays the query to a downstream over TCP and sends
// the answer back to the client.
// NOTE(review): fragmentary listing -- loops, braces, returns and several
// declarations (e.g. dest, qlen, rlen, now, poolname, delayMsec, largerQuery,
// dsock) are not visible, so the control flow between the statements below must
// be confirmed against the full file.
149 void* tcpClientThread(int pipefd
)
151 /* we get launched with a pipe on which we receive file descriptors from clients that we own
152 from that point on */
// `outstanding` is tested together with `ds` near the end of the function.
154 bool outstanding
= false;
155 blockfilter_t blockFilter
= 0;
// Under the Lua mutex, fetch the optional "blockFilter" Lua variable.
// NOTE(review): the presence check guarding the dereference is not visible.
158 std::lock_guard
<std::mutex
> lock(g_luamutex
);
159 auto candidate
= g_lua
.readVariable
<boost::optional
<blockfilter_t
> >("blockFilter");
161 blockFilter
= *candidate
;
// Local handles (getLocal()) on the shared policy, rule, dynamic-block and pool holders.
164 auto localPolicy
= g_policy
.getLocal();
165 auto localRulactions
= g_rulactions
.getLocal();
166 auto localRespRulactions
= g_resprulactions
.getLocal();
167 auto localDynBlockNMG
= g_dynblockNMG
.getLocal();
168 auto localDynBlockSMT
= g_dynblockSMT
.getLocal();
169 auto localPools
= g_pools
.getLocal();
// Generator for the per-query unique id assigned to dq.uniqueId.
171 boost::uuids::random_generator uuidGenerator
;
// Downstream sockets opened by this thread, keyed by downstream address, so an
// existing connection can be reused (see the lookup further down).
174 map
<ComboAddress
,int> sockets
;
176 ConnectionInfo
* citmp
, ci
;
// Receive the next ConnectionInfo pointer written by the acceptor to the pipe.
179 readn2(pipefd
, &citmp
, sizeof(citmp
));
// Re-throw with the pipe descriptor and its blocking mode added for diagnosis.
181 catch(const std::runtime_error
& e
) {
182 throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd
) + ") in " + std::string(isNonBlocking(pipefd
) ? "non-blocking" : "blocking") + " mode: " + e
.what());
// Presumably: this connection is no longer waiting in the queue -- confirm.
185 --g_tcpclientthreads
->d_queued
;
191 vector
<uint8_t> rewrittenResponse
;
192 shared_ptr
<DownstreamState
> ds
;
// Fill `dest` with the local address of the client socket (getsockname), after
// presetting its address family to that of the client's address.
194 memset(&dest
, 0, sizeof(dest
));
195 dest
.sin4
.sin_family
= ci
.remote
.sin4
.sin_family
;
196 socklen_t len
= dest
.getSocklen();
197 if (!setNonBlocking(ci
.fd
))
200 if (getsockname(ci
.fd
, (sockaddr
*)&dest
, &len
)) {
// Read the 16-bit length that precedes the query, bounded by g_tcpRecvTimeout.
209 if(!getNonBlockingMsgLen(ci
.fd
, &qlen
, g_tcpRecvTimeout
))
// A query shorter than a DNS header is counted as non-compliant.
215 if (qlen
< sizeof(dnsheader
)) {
216 g_stats
.nonCompliantQueries
++;
// Track whether EDNS / ECS were added to the query; both flags are passed to
// fixUpResponse() when the answer comes back.
220 bool ednsAdded
= false;
221 bool ecsAdded
= false;
222 /* if the query is small, allocate a bit more
223 memory to be able to spoof the content,
224 or to add ECS without allocating a new buffer */
225 size_t querySize
= qlen
<= 4096 ? qlen
+ 512 : qlen
;
226 char queryBuffer
[querySize
];
227 const char* query
= queryBuffer
;
// Read the query body itself (qlen bytes) from the client.
228 readn2WithTimeout(ci
.fd
, queryBuffer
, qlen
, g_tcpRecvTimeout
);
231 std::shared_ptr
<DnsCryptQuery
> dnsCryptQuery
= 0;
// DNSCrypt-enabled listener: hand the buffer to handleDnsCryptQuery(); any
// response it produced is sent straight back to the client.
233 if (ci
.cs
->dnscryptCtx
) {
234 dnsCryptQuery
= std::make_shared
<DnsCryptQuery
>();
235 uint16_t decryptedQueryLen
= 0;
236 vector
<uint8_t> response
;
237 bool decrypted
= handleDnsCryptQuery(ci
.cs
->dnscryptCtx
, queryBuffer
, qlen
, dnsCryptQuery
, &decryptedQueryLen
, true, response
);
240 if (response
.size() > 0) {
241 sendResponseToClient(ci
.fd
, reinterpret_cast<char*>(response
.data()), (uint16_t) response
.size());
// From here on the query length is the decrypted length.
245 qlen
= decryptedQueryLen
;
248 struct dnsheader
* dh
= (struct dnsheader
*) query
;
// Packets with QR set are counted as non-compliant, packets with qdcount == 0
// as empty queries.
250 if(dh
->qr
) { // don't respond to responses
251 g_stats
.nonCompliantQueries
++;
255 if(dh
->qdcount
== 0) {
256 g_stats
.emptyQueries
++;
// Remember the original header flags: they are restored on self-answered
// responses and passed to fixUpResponse() for downstream answers.
264 const uint16_t* flags
= getFlagsFromDNSHeader(dh
);
265 uint16_t origFlags
= *flags
;
266 uint16_t qtype
, qclass
;
267 unsigned int consumed
= 0;
// Parse the question name starting right after the DNS header.
268 DNSName
qname(query
, qlen
, sizeof(dnsheader
), false, &qtype
, &qclass
, &consumed
);
269 DNSQuestion
dq(&qname
, qtype
, qclass
, &dest
, &ci
.remote
, (dnsheader
*)query
, querySize
, qlen
, true);
271 dq
.uniqueId
= uuidGenerator();
276 /* we need this one to be accurate ("real") for the protobuf message */
277 struct timespec queryRealTime
;
280 gettime(&queryRealTime
, true);
// Apply dynamic blocks and rules to the question.
// NOTE(review): the branch taken when processQuery() returns false is not visible.
282 if (!processQuery(localDynBlockNMG
, localDynBlockSMT
, localRulactions
, blockFilter
, dq
, poolname
, &delayMsec
, now
)) {
// Self-answered path: restore the flags, encrypt if needed, answer the client
// directly and count it in selfAnswered.
286 if(dq
.dh
->qr
) { // something turned it into a response
287 restoreFlags(dh
, origFlags
);
289 if (!encryptResponse(queryBuffer
, &dq
.len
, dq
.size
, true, dnsCryptQuery
)) {
293 sendResponseToClient(ci
.fd
, query
, dq
.len
);
294 g_stats
.selfAnswered
++;
// Pick the pool, then (under the Lua mutex) let the policy choose a downstream
// server and take the pool's packet cache.
298 std::shared_ptr
<ServerPool
> serverPool
= getPool(*localPools
, poolname
);
299 std::shared_ptr
<DNSDistPacketCache
> packetCache
= nullptr;
301 std::lock_guard
<std::mutex
> lock(g_luamutex
);
302 ds
= localPolicy
->policy(serverPool
->servers
, &dq
);
303 packetCache
= serverPool
->packetCache
;
// EDNS Client Subnet is handled only when both the question and the selected
// server have useECS set. When handleEDNSClientSubnet() filled largerQuery, the
// query pointer, length and size switch to that copy.
306 if (dq
.useECS
&& ds
&& ds
->useECS
) {
307 uint16_t newLen
= dq
.len
;
308 handleEDNSClientSubnet(queryBuffer
, dq
.size
, consumed
, &newLen
, largerQuery
, &ednsAdded
, &ecsAdded
, ci
.remote
, dq
.ecsOverride
, dq
.ecsPrefixLength
);
309 if (largerQuery
.empty() == false) {
310 query
= largerQuery
.c_str();
311 dq
.len
= (uint16_t) largerQuery
.size();
312 dq
.size
= largerQuery
.size();
// Packet cache lookup. allowExpired is g_staleCacheEntriesTTL only when no
// downstream server was selected, otherwise 0.
318 uint32_t cacheKey
= 0;
319 if (packetCache
&& !dq
.skipCache
) {
320 char cachedResponse
[4096];
321 uint16_t cachedResponseSize
= sizeof cachedResponse
;
322 uint32_t allowExpired
= ds
? 0 : g_staleCacheEntriesTTL
;
323 if (packetCache
->get(dq
, (uint16_t) consumed
, dq
.dh
->id
, cachedResponse
, &cachedResponseSize
, &cacheKey
, allowExpired
)) {
// Cache hit: encrypt if needed and answer from the cache.
325 if (!encryptResponse(cachedResponse
, &cachedResponseSize
, sizeof cachedResponse
, true, dnsCryptQuery
)) {
329 sendResponseToClient(ci
.fd
, cachedResponse
, cachedResponseSize
);
333 g_stats
.cacheMisses
++;
// Reuse this thread's connection to the selected downstream, or open a new one
// when none is recorded for that address.
342 if(sockets
.count(ds
->remote
) == 0) {
343 dsock
=sockets
[ds
->remote
]=setupTCPDownstream(ds
);
346 dsock
=sockets
[ds
->remote
];
// Number of downstream failures for this query; compared with ds->retries below.
352 uint16_t downstream_failures
=0;
355 sockets
.erase(ds
->remote
);
// Give up once the failure count exceeds ds->retries (only when retries > 0).
359 if (ds
->retries
> 0 && downstream_failures
> ds
->retries
) {
360 vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds
->getName(), downstream_failures
);
363 sockets
.erase(ds
->remote
);
// Send the length prefix; on failure forget the socket, reconnect and count
// the failure.
367 if(!sendNonBlockingMsgLen(dsock
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
)) {
368 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
371 sockets
.erase(ds
->remote
);
372 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
373 downstream_failures
++;
// Send the query itself: plain write when no source interface is set,
// sendMsgWithTimeout() with explicit addresses otherwise.
378 if (ds
->sourceItf
== 0) {
379 writen2WithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
);
382 sendMsgWithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
);
// A runtime_error while sending gets the same reconnect-and-count handling.
385 catch(const runtime_error
& e
) {
386 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
389 sockets
.erase(ds
->remote
);
390 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
391 downstream_failures
++;
// isXFR / xfrStarted drive the SOA counting done after the response is sent.
395 bool xfrStarted
= false;
396 bool isXFR
= (dq
.qtype
== QType::AXFR
|| dq
.qtype
== QType::IXFR
);
// Read the length of the downstream's response; failure again leads to a
// reconnect and a counted failure.
403 if(!getNonBlockingMsgLen(dsock
, &rlen
, ds
->tcpRecvTimeout
)) {
404 vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds
->getName());
407 sockets
.erase(ds
->remote
);
408 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
409 downstream_failures
++;
// Size the response buffer; extra room for DNSCrypt padding and MAC is added
// only when rlen plus that room still fits in 16 bits.
416 size_t responseSize
= rlen
;
417 uint16_t addRoom
= 0;
419 if (dnsCryptQuery
&& (UINT16_MAX
- rlen
) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
) {
420 addRoom
= DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
;
423 responseSize
+= addRoom
;
424 char answerbuffer
[responseSize
];
425 readn2WithTimeout(dsock
, answerbuffer
, rlen
, ds
->tcpRecvTimeout
);
426 char* response
= answerbuffer
;
427 uint16_t responseLen
= rlen
;
429 /* might be false for {A,I}XFR */
// Responses shorter than a DNS header are checked for (handling not visible).
434 if (rlen
< sizeof(dnsheader
)) {
// The response must match the question that was asked of this downstream.
438 if (!responseContentMatches(response
, responseLen
, qname
, qtype
, qclass
, ds
->remote
)) {
// fixUpResponse() receives the original flags and the EDNS/ECS flags; it may
// redirect `response` (rewrittenResponse is passed as storage).
442 if (!fixUpResponse(&response
, &responseLen
, &responseSize
, qname
, origFlags
, ednsAdded
, ecsAdded
, rewrittenResponse
, addRoom
)) {
446 dh
= (struct dnsheader
*) response
;
447 DNSResponse
dr(&qname
, qtype
, qclass
, &dest
, &ci
.remote
, dh
, responseSize
, responseLen
, true, &queryRealTime
);
// The response carries the same unique id as its question.
449 dr
.uniqueId
= dq
.uniqueId
;
451 if (!processResponse(localRespRulactions
, dr
, &delayMsec
)) {
// Store the response in the packet cache; the last argument is true for
// ServFail answers.
455 if (packetCache
&& !dq
.skipCache
) {
456 packetCache
->insert(cacheKey
, qname
, qtype
, qclass
, response
, responseLen
, true, dh
->rcode
== RCode::ServFail
);
460 if (!encryptResponse(response
, &responseLen
, responseSize
, true, dnsCryptQuery
)) {
464 if (!sendResponseToClient(ci
.fd
, response
, responseLen
)) {
// For XFR messages with rcode 0 and a non-empty answer section, SOA records
// are counted (the actions taken on a count of 1 or 0 are not visible).
468 if (isXFR
&& dh
->rcode
== 0 && dh
->ancount
!= 0) {
469 if (xfrStarted
== false) {
471 if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 1) {
475 else if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 0) {
// Record the elapsed time between `now` and the answer (DiffTime scaled by
// 1e6 -- presumably microseconds) in the response ring, under its mutex.
481 struct timespec answertime
;
482 gettime(&answertime
);
483 unsigned int udiff
= 1000000.0*DiffTime(now
,answertime
);
485 std::lock_guard
<std::mutex
> lock(g_rings
.respMutex
);
486 g_rings
.respRing
.push_back({answertime
, ci
.remote
, qname
, dq
.qtype
, (unsigned int)udiff
, (unsigned int)responseLen
, *dh
, ds
->remote
});
// Empty the rewrite buffer once this response has been handled.
490 rewrittenResponse
.clear();
497 vinfolog("Closing TCP client connection with %s", ci
.remote
.toStringWithPort());
500 if (ds
&& outstanding
) {
509 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
510 they will hand off to worker threads & spawn more of them if required
// Acceptor thread for one listening socket (`p` is cast to ClientState*):
// accepts TCP connections on cs->tcpFD, applies the ACL and the queued-connection
// limit, and hands an accepted connection to a worker by writing a
// ConnectionInfo pointer to that worker's pipe.
// NOTE(review): fragmentary listing -- the loop, braces, cleanup (close/delete)
// and the increment of d_queued are not visible here.
512 void* tcpAcceptorThread(void* p
)
514 ClientState
* cs
= (ClientState
*) p
;
// Accepted peers use the same address family as the listening address.
517 remote
.sin4
.sin_family
= cs
->local
.sin4
.sin_family
;
// Add a worker thread to the shared collection.
519 g_tcpclientthreads
->addTCPClientThread();
521 auto acl
= g_ACL
.getLocal();
// Allocated here; the pointer itself is what gets written to the worker pipe.
526 ci
= new ConnectionInfo
;
529 ci
->fd
= SAccept(cs
->tcpFD
, remote
);
// Peers not matched by the ACL are dropped (logged below).
531 if(!acl
->match(remote
)) {
536 vinfolog("Dropped TCP connection from %s because of ACL", remote
.toStringWithPort());
// Queue limit: only enforced when g_maxTCPQueuedConnections is non-zero.
540 if(g_maxTCPQueuedConnections
> 0 && g_tcpclientthreads
->d_queued
>= g_maxTCPQueuedConnections
) {
544 vinfolog("Dropping TCP connection from %s because we have too many queued already", remote
.toStringWithPort());
548 vinfolog("Got TCP connection from %s", remote
.toStringWithPort());
// Pick a worker pipe and write the ConnectionInfo pointer (not the object) to it.
551 int pipe
= g_tcpclientthreads
->getThread();
553 writen2WithTimeout(pipe
, &ci
, sizeof(ci
), 0);
// Presumably the rollback of the queue counter when the hand-off failed -- confirm.
556 --g_tcpclientthreads
->d_queued
;
562 catch(std::exception
& e
) {
563 errlog("While reading a TCP question: %s", e
.what());
// Presumably guards closing the accepted descriptor on error -- the action is not visible.
564 if(ci
&& ci
->fd
>= 0)
// Reads a 32-bit message length from `fd` with readn2() and checks it against
// a 10,000,000 upper bound.
// NOTE(review): fragmentary listing -- the declaration of `raw`, the store into
// `*len` (presumably a network-to-host conversion, mirroring the htonl() used by
// putMsgLen32) and the return statements are not visible; confirm.
575 bool getMsgLen32(int fd
, uint32_t* len
)
579 size_t ret
= readn2(fd
, &raw
, sizeof raw
);
// Anything other than a complete read of the length is checked for here.
580 if(ret
!= sizeof raw
)
583 if(*len
> 10000000) // arbitrary 10MB limit
// Writes `len` to `fd` as a 32-bit length in network byte order using writen2().
// Returns true only when all sizeof(raw) bytes were written.
// NOTE(review): fragmentary listing -- braces and any exception handling around
// these statements are not visible.
591 bool putMsgLen32(int fd
, uint32_t len
)
594 uint32_t raw
= htonl(len
);
595 size_t ret
= writen2(fd
, &raw
, sizeof raw
);
596 return ret
==sizeof raw
;