2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
23 #include "dnsdist-ecs.hh"
24 #include "dnsparser.hh"
25 #include "ednsoptions.hh"
35 /* TCP: the grand design.
36 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
37 An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
40 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
41 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
42 to guarantee performance.
44 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
45 So whenever an answer comes in, we know where it needs to go.
// Creates a SOCK_STREAM socket in the address family of ds->remote and connects it to
// ds->remote, logging the attempt first.
// NOTE(review): the return statement and any error handling are not visible in this view;
// callers assign the result to their socket variable, so presumably `sock` is returned -- confirm.
50 static int setupTCPDownstream(shared_ptr
<DownstreamState
> ds
)
52 vinfolog("TCP connecting to downstream %s", ds
->remote
.toStringWithPort());
53 int sock
= SSocket(ds
->remote
.sin4
.sin_family
, SOCK_STREAM
, 0);
// Only when IsAnyAddress() reports false for the configured source address:
// set SO_REUSEADDR and bind the socket to that address before connecting.
54 if (!IsAnyAddress(ds
->sourceAddr
)) {
55 SSetsockopt(sock
, SOL_SOCKET
, SO_REUSEADDR
, 1);
56 SBind(sock
, ds
->sourceAddr
);
58 SConnect(sock
, ds
->remote
);
// Upper bound compared against g_tcpclientthreads->d_queued in tcpAcceptorThread();
// that check is only made when this value is greater than 0.
70 uint64_t g_maxTCPQueuedConnections
{0};
// Forward declaration: addTCPClientThread() starts a thread running this function.
71 void* tcpClientThread(int pipefd
);
// Adds one TCP client worker: pipefds[1] is made non-blocking and stored in
// d_tcpclientthreads, and a thread running tcpClientThread is started with pipefds[0].
// NOTE(review): the call that fills pipefds (presumably pipe()) and the statements that
// follow the warning below are not visible in this view -- confirm against the full file.
73 // Should not be called simultaneously!
74 void TCPClientCollection::addTCPClientThread()
// Guard: warn (the message says "skipping") once d_numthreads has reached the vector's capacity.
76 if (d_numthreads
>= d_tcpclientthreads
.capacity()) {
77 warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads
.load(), d_tcpclientthreads
.capacity());
81 vinfolog("Adding TCP Client thread");
83 int pipefds
[2] = { -1, -1};
85 unixDie("Creating pipe");
87 if (!setNonBlocking(pipefds
[1])) {
90 unixDie("Setting pipe non-blocking");
// The descriptor kept in the collection is the one the acceptor later writes to (see getThread()).
93 d_tcpclientthreads
.push_back(pipefds
[1]);
95 thread
t1(tcpClientThread
, pipefds
[0]);
// Reads `sizeof raw` bytes from fd into `raw` via readn2WithTimeout(), using `timeout`,
// and checks that the full amount was read.
// NOTE(review): the declaration of `raw`, the store into *len and the return statements are
// not visible in this view. By symmetry with putNonBlockingMsgLen(), `raw` is presumably a
// network-order uint16_t that is converted with ntohs() -- confirm.
99 static bool getNonBlockingMsgLen(int fd
, uint16_t* len
, int timeout
)
103 size_t ret
= readn2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
104 if(ret
!= sizeof raw
)
// Writes `len` to fd as a 2-byte value in network byte order (htons) using
// writen2WithTimeout() with `timeout`.
// Returns true only when exactly sizeof(uint16_t) bytes were reported written.
113 static bool putNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
)
116 uint16_t raw
= htons(len
);
117 size_t ret
= writen2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
118 return ret
== sizeof raw
;
// Sends the 2-byte network-order length `len` on fd. Two paths are visible:
//  - delegate to putNonBlockingMsgLen(fd, len, timeout), or
//  - send it with sendMsgWithTimeout(), passing dest, local and localItf.
// Returns true when the full 2 bytes were sent.
// NOTE(review): the condition selecting between the two paths is not visible in this view;
// the analogous payload send in tcpClientThread() branches on `ds->sourceItf == 0` -- confirm.
124 static bool sendNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
, ComboAddress
& dest
, ComboAddress
& local
, unsigned int localItf
)
128 return putNonBlockingMsgLen(fd
, len
, timeout
);
130 uint16_t raw
= htons(len
);
// ret is signed here (ssize_t), unlike the size_t used in putNonBlockingMsgLen().
131 ssize_t ret
= sendMsgWithTimeout(fd
, (char*) &raw
, sizeof raw
, timeout
, dest
, local
, localItf
);
132 return ret
== sizeof raw
;
// Sends a response to a client socket: first the 2-byte length prefix via
// putNonBlockingMsgLen(), then `responseLen` bytes of `response` via writen2WithTimeout().
// Both steps use g_tcpSendTimeout.
// NOTE(review): the return statements are not visible in this view; presumably false when
// the length prefix could not be written and true otherwise -- confirm.
138 static bool sendResponseToClient(int fd
, const char* response
, uint16_t responseLen
)
140 if (!putNonBlockingMsgLen(fd
, responseLen
, g_tcpSendTimeout
))
143 writen2WithTimeout(fd
, response
, responseLen
, g_tcpSendTimeout
);
// Process-wide collection of TCP client threads. In this file it is used by
// tcpAcceptorThread() (addTCPClientThread(), getThread(), d_queued) and by
// tcpClientThread() (d_queued).
147 std::shared_ptr
<TCPClientCollection
> g_tcpclientthreads
;
// Worker thread for TCP clients. It reads ConnectionInfo pointers from `pipefd` and, for a
// client connection, performs the steps visible below:
//   1. read a length-prefixed query from the client socket (ci.fd);
//   2. optionally run it through handleDnsCryptQuery();
//   3. apply processQuery(); a query turned into a response is answered directly;
//   4. pick a downstream through the policy and try the pool's packet cache;
//   5. forward the query over a per-thread TCP socket to the downstream;
//   6. read the answer, run fixUpResponse()/processResponse(), insert it into the cache,
//      encrypt it if needed and send it to the client;
//   7. record the response in g_rings.respRing.
// NOTE(review): loops, braces and most return/continue/break statements are not visible in
// this view, so the exact control flow between these steps must be confirmed against the
// full file.
149 void* tcpClientThread(int pipefd
)
151 /* we get launched with a pipe on which we receive file descriptors from clients that we own
152 from that point on */
154 bool outstanding
= false;
155 blockfilter_t blockFilter
= 0;
// Fetch the optional Lua variable "blockFilter" while holding g_luamutex.
158 std::lock_guard
<std::mutex
> lock(g_luamutex
);
159 auto candidate
= g_lua
.readVariable
<boost::optional
<blockfilter_t
> >("blockFilter");
161 blockFilter
= *candidate
;
// getLocal() handles for the global policy, rule-action, dynamic-block and pool objects.
164 auto localPolicy
= g_policy
.getLocal();
165 auto localRulactions
= g_rulactions
.getLocal();
166 auto localRespRulactions
= g_resprulactions
.getLocal();
167 auto localDynBlockNMG
= g_dynblockNMG
.getLocal();
168 auto localDynBlockSMT
= g_dynblockSMT
.getLocal();
169 auto localPools
= g_pools
.getLocal();
171 boost::uuids::random_generator uuidGenerator
;
// Downstream sockets opened by this thread, keyed by the downstream's remote address.
174 map
<ComboAddress
,int> sockets
;
176 ConnectionInfo
* citmp
, ci
;
// The acceptor writes a ConnectionInfo pointer value into the pipe; read exactly one pointer.
179 readn2(pipefd
, &citmp
, sizeof(citmp
));
181 catch(const std::runtime_error
& e
) {
182 throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd
) + ") in " + std::string(isNonBlocking(pipefd
) ? "non-blocking" : "blocking") + " mode: " + e
.what());
// One queued connection has been picked up by this worker.
185 --g_tcpclientthreads
->d_queued
;
191 vector
<uint8_t> rewrittenResponse
;
192 shared_ptr
<DownstreamState
> ds
;
193 if (!setNonBlocking(ci
.fd
))
// Read the 2-byte length prefix of the next query from the client.
201 if(!getNonBlockingMsgLen(ci
.fd
, &qlen
, g_tcpRecvTimeout
))
// A query shorter than a DNS header is counted as non-compliant.
207 if (qlen
< sizeof(dnsheader
)) {
208 g_stats
.nonCompliantQueries
++;
212 bool ednsAdded
= false;
213 bool ecsAdded
= false;
214 /* if the query is small, allocate a bit more
215 memory to be able to spoof the content,
216 or to add ECS without allocating a new buffer */
217 size_t querySize
= qlen
<= 4096 ? qlen
+ 512 : qlen
;
// NOTE(review): runtime-sized array (a compiler extension in C++); its size derives from the
// client-supplied qlen.
218 char queryBuffer
[querySize
];
219 const char* query
= queryBuffer
;
// NOTE(review): the return value of this read is not used here.
220 readn2WithTimeout(ci
.fd
, queryBuffer
, qlen
, g_tcpRecvTimeout
);
223 std::shared_ptr
<DnsCryptQuery
> dnsCryptQuery
= 0;
// DNSCrypt-enabled listener: let handleDnsCryptQuery() process the buffer. A non-empty
// `response` vector is sent straight back to the client.
225 if (ci
.cs
->dnscryptCtx
) {
226 dnsCryptQuery
= std::make_shared
<DnsCryptQuery
>();
227 uint16_t decryptedQueryLen
= 0;
228 vector
<uint8_t> response
;
229 bool decrypted
= handleDnsCryptQuery(ci
.cs
->dnscryptCtx
, queryBuffer
, qlen
, dnsCryptQuery
, &decryptedQueryLen
, true, response
);
232 if (response
.size() > 0) {
233 sendResponseToClient(ci
.fd
, reinterpret_cast<char*>(response
.data()), (uint16_t) response
.size());
237 qlen
= decryptedQueryLen
;
240 struct dnsheader
* dh
= (struct dnsheader
*) query
;
242 if(dh
->qr
) { // don't respond to responses
243 g_stats
.nonCompliantQueries
++;
247 if(dh
->qdcount
== 0) {
248 g_stats
.emptyQueries
++;
// Remember the original header flags so they can be restored on the answer.
256 const uint16_t* flags
= getFlagsFromDNSHeader(dh
);
257 uint16_t origFlags
= *flags
;
258 uint16_t qtype
, qclass
;
259 unsigned int consumed
= 0;
260 DNSName
qname(query
, qlen
, sizeof(dnsheader
), false, &qtype
, &qclass
, &consumed
);
261 DNSQuestion
dq(&qname
, qtype
, qclass
, &ci
.cs
->local
, &ci
.remote
, (dnsheader
*)query
, querySize
, qlen
, true);
263 dq
.uniqueId
= uuidGenerator();
271 if (!processQuery(localDynBlockNMG
, localDynBlockSMT
, localRulactions
, blockFilter
, dq
, poolname
, &delayMsec
, now
)) {
// Self-answered path: restore the flags, encrypt if needed, reply to the client and count it.
275 if(dq
.dh
->qr
) { // something turned it into a response
276 restoreFlags(dh
, origFlags
);
278 if (!encryptResponse(queryBuffer
, &dq
.len
, dq
.size
, true, dnsCryptQuery
)) {
282 sendResponseToClient(ci
.fd
, query
, dq
.len
);
283 g_stats
.selfAnswered
++;
287 std::shared_ptr
<ServerPool
> serverPool
= getPool(*localPools
, poolname
);
288 std::shared_ptr
<DNSDistPacketCache
> packetCache
= nullptr;
// The server selection policy is invoked while holding g_luamutex.
290 std::lock_guard
<std::mutex
> lock(g_luamutex
);
291 ds
= localPolicy
->policy(serverPool
->servers
, &dq
);
292 packetCache
= serverPool
->packetCache
;
// When both the question and the selected downstream have useECS set, call
// handleEDNSClientSubnet(); if it filled `largerQuery`, that buffer is sent instead.
295 if (dq
.useECS
&& ds
&& ds
->useECS
) {
296 uint16_t newLen
= dq
.len
;
297 handleEDNSClientSubnet(queryBuffer
, dq
.size
, consumed
, &newLen
, largerQuery
, &ednsAdded
, &ecsAdded
, ci
.remote
, dq
.ecsOverride
, dq
.ecsPrefixLength
);
298 if (largerQuery
.empty() == false) {
299 query
= largerQuery
.c_str();
300 dq
.len
= (uint16_t) largerQuery
.size();
301 dq
.size
= largerQuery
.size();
// Packet cache lookup. allowExpired is g_staleCacheEntriesTTL only when no downstream was
// selected (ds is null), otherwise 0.
307 uint32_t cacheKey
= 0;
308 if (packetCache
&& !dq
.skipCache
) {
309 char cachedResponse
[4096];
310 uint16_t cachedResponseSize
= sizeof cachedResponse
;
311 uint32_t allowExpired
= ds
? 0 : g_staleCacheEntriesTTL
;
312 if (packetCache
->get(dq
, (uint16_t) consumed
, dq
.dh
->id
, cachedResponse
, &cachedResponseSize
, &cacheKey
, allowExpired
)) {
314 if (!encryptResponse(cachedResponse
, &cachedResponseSize
, sizeof cachedResponse
, true, dnsCryptQuery
)) {
318 sendResponseToClient(ci
.fd
, cachedResponse
, cachedResponseSize
);
322 g_stats
.cacheMisses
++;
// Reuse this thread's socket for ds->remote if one exists, otherwise connect a new one.
331 if(sockets
.count(ds
->remote
) == 0) {
332 dsock
=sockets
[ds
->remote
]=setupTCPDownstream(ds
);
335 dsock
=sockets
[ds
->remote
];
341 uint16_t downstream_failures
=0;
344 sockets
.erase(ds
->remote
);
// Give up once the number of consecutive failures exceeds ds->retries (when retries > 0).
348 if (ds
->retries
> 0 && downstream_failures
> ds
->retries
) {
349 vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds
->getName(), downstream_failures
);
352 sockets
.erase(ds
->remote
);
// Send the 2-byte length prefix to the downstream; on failure drop the cached socket,
// reconnect and count the failure.
356 if(!sendNonBlockingMsgLen(dsock
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
)) {
357 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
360 sockets
.erase(ds
->remote
);
361 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
362 downstream_failures
++;
// Send the query payload: plain write when no source interface is set, sendMsgWithTimeout()
// with the source address/interface otherwise.
367 if (ds
->sourceItf
== 0) {
368 writen2WithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
);
371 sendMsgWithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
);
374 catch(const runtime_error
& e
) {
375 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
378 sockets
.erase(ds
->remote
);
379 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
380 downstream_failures
++;
// AXFR/IXFR answers get special handling further down (xfrStarted and SOA record counting).
384 bool xfrStarted
= false;
385 bool isXFR
= (dq
.qtype
== QType::AXFR
|| dq
.qtype
== QType::IXFR
);
// Read the 2-byte length prefix of the downstream's answer; reconnect on failure.
392 if(!getNonBlockingMsgLen(dsock
, &rlen
, ds
->tcpRecvTimeout
)) {
393 vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds
->getName());
396 sockets
.erase(ds
->remote
);
397 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
398 downstream_failures
++;
// For DNSCrypt queries, reserve extra room in the answer buffer for padding and MAC, but
// only when rlen leaves that much headroom below UINT16_MAX.
405 size_t responseSize
= rlen
;
406 uint16_t addRoom
= 0;
408 if (dnsCryptQuery
&& (UINT16_MAX
- rlen
) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
) {
409 addRoom
= DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
;
412 responseSize
+= addRoom
;
// NOTE(review): runtime-sized array (a compiler extension in C++).
413 char answerbuffer
[responseSize
];
414 readn2WithTimeout(dsock
, answerbuffer
, rlen
, ds
->tcpRecvTimeout
);
415 char* response
= answerbuffer
;
416 uint16_t responseLen
= rlen
;
418 /* might be false for {A,I}XFR */
423 if (rlen
< sizeof(dnsheader
)) {
427 if (!responseContentMatches(response
, responseLen
, qname
, qtype
, qclass
, ds
->remote
)) {
431 if (!fixUpResponse(&response
, &responseLen
, &responseSize
, qname
, origFlags
, ednsAdded
, ecsAdded
, rewrittenResponse
, addRoom
)) {
// fixUpResponse() receives `response` by address, so re-derive the header pointer from it.
435 dh
= (struct dnsheader
*) response
;
436 DNSResponse
dr(&qname
, qtype
, qclass
, &ci
.cs
->local
, &ci
.remote
, dh
, responseSize
, responseLen
, true, &now
);
438 dr
.uniqueId
= dq
.uniqueId
;
440 if (!processResponse(localRespRulactions
, dr
, &delayMsec
)) {
// Store the answer in the packet cache; the last argument flags ServFail answers.
444 if (packetCache
&& !dq
.skipCache
) {
445 packetCache
->insert(cacheKey
, qname
, qtype
, qclass
, response
, responseLen
, true, dh
->rcode
== RCode::ServFail
);
449 if (!encryptResponse(response
, &responseLen
, responseSize
, true, dnsCryptQuery
)) {
453 if (!sendResponseToClient(ci
.fd
, response
, responseLen
)) {
// Zone transfers: for successful, non-empty answers the number of SOA records in the
// answer section is inspected.
// NOTE(review): the actions taken in each branch are not visible in this view.
457 if (isXFR
&& dh
->rcode
== 0 && dh
->ancount
!= 0) {
458 if (xfrStarted
== false) {
460 if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 1) {
464 else if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 0) {
// Record the response in g_rings.respRing under respMutex. udiff is DiffTime(now, answertime)
// scaled by 1e6 (presumably microseconds -- confirm the unit returned by DiffTime()).
470 struct timespec answertime
;
471 gettime(&answertime
);
472 unsigned int udiff
= 1000000.0*DiffTime(now
,answertime
);
474 std::lock_guard
<std::mutex
> lock(g_rings
.respMutex
);
475 g_rings
.respRing
.push_back({answertime
, ci
.remote
, qname
, dq
.qtype
, (unsigned int)udiff
, (unsigned int)responseLen
, *dq
.dh
, ds
->remote
});
479 rewrittenResponse
.clear();
486 vinfolog("Closing TCP client connection with %s", ci
.remote
.toStringWithPort());
489 if (ds
&& outstanding
) {
498 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
499 they will hand off to worker threads & spawn more of them if required
// Acceptor thread for one listening TCP socket. `p` is cast to ClientState*. Visible steps:
// accept a connection on cs->tcpFD, check the peer against the ACL, enforce the optional
// queue limit, then pass the ConnectionInfo pointer to a worker through the pipe returned
// by g_tcpclientthreads->getThread().
// NOTE(review): the accept loop, the try block and the cleanup performed on the drop and
// error paths are not visible in this view -- confirm against the full file.
501 void* tcpAcceptorThread(void* p
)
503 ClientState
* cs
= (ClientState
*) p
;
506 remote
.sin4
.sin_family
= cs
->local
.sin4
.sin_family
;
508 g_tcpclientthreads
->addTCPClientThread();
510 auto acl
= g_ACL
.getLocal();
// NOTE(review): raw owning pointer; it is handed to a worker thread through the pipe below.
// Where it is released is not visible in this view.
515 ci
= new ConnectionInfo
;
518 ci
->fd
= SAccept(cs
->tcpFD
, remote
);
520 if(!acl
->match(remote
)) {
525 vinfolog("Dropped TCP connection from %s because of ACL", remote
.toStringWithPort());
// Queue limit: applies only when g_maxTCPQueuedConnections is non-zero.
529 if(g_maxTCPQueuedConnections
> 0 && g_tcpclientthreads
->d_queued
>= g_maxTCPQueuedConnections
) {
533 vinfolog("Dropping TCP connection from %s because we have too many queued already", remote
.toStringWithPort());
537 vinfolog("Got TCP connection from %s", remote
.toStringWithPort());
// Hand the pointer value itself (sizeof(ci) bytes) to a worker through its pipe.
540 int pipe
= g_tcpclientthreads
->getThread();
542 writen2WithTimeout(pipe
, &ci
, sizeof(ci
), 0);
// NOTE(review): the condition guarding this decrement is not visible in this view.
545 --g_tcpclientthreads
->d_queued
;
551 catch(std::exception
& e
) {
552 errlog("While reading a TCP question: %s", e
.what());
553 if(ci
&& ci
->fd
>= 0)
// Reads a 32-bit length from fd with readn2(), checks that all `sizeof raw` bytes were
// read, and tests *len against an upper bound of 10000000.
// NOTE(review): the declaration of `raw`, the store into *len (presumably ntohl(raw), by
// symmetry with putMsgLen32()) and the return statements are not visible in this view.
564 bool getMsgLen32(int fd
, uint32_t* len
)
568 size_t ret
= readn2(fd
, &raw
, sizeof raw
);
569 if(ret
!= sizeof raw
)
572 if(*len
> 10000000) // arbitrary 10MB limit
// Writes `len` to fd as a 4-byte value in network byte order (htonl) using writen2().
// Returns true only when exactly sizeof(uint32_t) bytes were reported written.
// NOTE(review): any surrounding try/catch is not visible in this view.
580 bool putMsgLen32(int fd
, uint32_t len
)
583 uint32_t raw
= htonl(len
);
584 size_t ret
= writen2(fd
, &raw
, sizeof raw
);
585 return ret
==sizeof raw
;