2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
23 #include "dnsdist-ecs.hh"
24 #include "dnsparser.hh"
25 #include "ednsoptions.hh"
35 /* TCP: the grand design.
36 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
37 An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
40 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
41 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
42 to guarantee performance.
44 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
45 So whenever an answer comes in, we know where it needs to go.
// Opens a TCP (SOCK_STREAM) socket towards the downstream server `ds`:
//   1. logs the target address,
//   2. creates the socket in the address family of ds->remote,
//   3. when ds->sourceAddr is not the "any" address, sets SO_REUSEADDR and binds to it,
//   4. connects to ds->remote.
// `ds` is taken by value, i.e. the shared_ptr is copied for the duration of the call.
// NOTE(review): the lines following SConnect are not visible in this listing; the
// function presumably returns `sock` (callers store the result as a socket) -- confirm.
50 static int setupTCPDownstream(shared_ptr
<DownstreamState
> ds
)
52 vinfolog("TCP connecting to downstream %s", ds
->remote
.toStringWithPort());
53 int sock
= SSocket(ds
->remote
.sin4
.sin_family
, SOCK_STREAM
, 0);
// Only bind when a specific source address has been configured for this downstream.
54 if (!IsAnyAddress(ds
->sourceAddr
)) {
55 SSetsockopt(sock
, SOL_SOCKET
, SO_REUSEADDR
, 1);
56 SBind(sock
, ds
->sourceAddr
);
58 SConnect(sock
, ds
->remote
);
// Upper bound on the number of accepted TCP connections waiting for a worker
// (initialised to 1000). tcpAcceptorThread compares it against
// g_tcpclientthreads->d_queued and only applies the limit when it is > 0.
70 uint64_t g_maxTCPQueuedConnections
{1000};
// Forward declaration of the worker entry point; addTCPClientThread() starts it on a
// new thread and passes it one end of a pipe (pipefds[0]).
71 void* tcpClientThread(int pipefd
);
73 // Should not be called simultaneously!
// Adds one TCP worker thread to the collection.
// Visible steps: refuse (with a warning) once d_numthreads has reached the capacity of
// d_tcpclientthreads, create a pipe, make pipefds[1] non-blocking, store pipefds[1] in
// d_tcpclientthreads and start tcpClientThread on pipefds[0].
// NOTE(review): the pipe() call itself, the early return after the warning and any
// cleanup before unixDie() are not visible in this listing -- confirm against the full file.
74 void TCPClientCollection::addTCPClientThread()
// The vector is never grown past its capacity here; presumably so that it is not
// reallocated while other threads read from it -- TODO confirm.
76 if (d_numthreads
>= d_tcpclientthreads
.capacity()) {
77 warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads
.load(), d_tcpclientthreads
.capacity());
81 vinfolog("Adding TCP Client thread");
// Both descriptors start at -1 (no descriptor) until the pipe is created.
83 int pipefds
[2] = { -1, -1};
85 unixDie("Creating pipe");
// pipefds[1] is the end stored in d_tcpclientthreads below; it is switched to non-blocking mode.
87 if (!setNonBlocking(pipefds
[1])) {
90 unixDie("Setting pipe non-blocking");
93 d_tcpclientthreads
.push_back(pipefds
[1]);
// The worker thread receives the other end of the pipe.
95 thread
t1(tcpClientThread
, pipefds
[0]);
// Reads a length prefix of sizeof(raw) bytes from `fd`, waiting at most `timeout`
// (passed through to readn2WithTimeout), and reports it through `len`.
// NOTE(review): the declaration of `raw`, the store into *len and the return
// statements are not visible here. Presumably `raw` is a uint16_t, a short read yields
// false and *len receives the host-order value -- confirm against the full file.
99 static bool getNonBlockingMsgLen(int fd
, uint16_t* len
, int timeout
)
103 size_t ret
= readn2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
// Anything other than a complete prefix is treated specially (branch body not visible).
104 if(ret
!= sizeof raw
)
// Writes `len` to `fd` as a 2-byte length prefix in network byte order (htons),
// passing `timeout` to writen2WithTimeout.
// Returns true exactly when all sizeof(raw) bytes were written.
// NOTE(review): any surrounding try/catch is not visible in this listing.
113 static bool putNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
)
116 uint16_t raw
= htons(len
);
117 size_t ret
= writen2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
118 return ret
== sizeof raw
;
// Sends `len` as a 2-byte network-order length prefix on `fd`, using one of two paths:
//   - putNonBlockingMsgLen(fd, len, timeout), or
//   - sendMsgWithTimeout(...) with the explicit destination `dest`, source address
//     `local` and interface index `localItf`.
// Returns true when the full prefix was sent.
// NOTE(review): the condition selecting the first path is not visible in this listing;
// presumably it is `localItf == 0`, mirroring the `ds->sourceItf == 0` test used when
// the query body is sent -- confirm.
124 static bool sendNonBlockingMsgLen(int fd
, uint16_t len
, int timeout
, ComboAddress
& dest
, ComboAddress
& local
, unsigned int localItf
)
128 return putNonBlockingMsgLen(fd
, len
, timeout
);
130 uint16_t raw
= htons(len
);
// `ret` is signed here (ssize_t) while sizeof yields an unsigned value; a negative
// result converts to a huge unsigned number and therefore compares unequal.
131 ssize_t ret
= sendMsgWithTimeout(fd
, (char*) &raw
, sizeof raw
, timeout
, dest
, local
, localItf
);
132 return ret
== sizeof raw
;
// Sends one DNS message to a TCP client on `fd`: first the 2-byte length prefix
// (putNonBlockingMsgLen), then `responseLen` bytes of `response`, both using
// g_tcpSendTimeout.
// NOTE(review): the return statements are not visible here; presumably a failed
// prefix write returns false. The result of the payload write is not inspected on
// the visible line -- confirm whether a short write is reported to the caller.
138 static bool sendResponseToClient(int fd
, const char* response
, uint16_t responseLen
)
140 if (!putNonBlockingMsgLen(fd
, responseLen
, g_tcpSendTimeout
))
143 writen2WithTimeout(fd
, response
, responseLen
, g_tcpSendTimeout
);
// Process-wide collection of TCP worker threads. In this file it is used to add
// workers (addTCPClientThread), pick a worker pipe (getThread) and track the number
// of queued connections (d_queued).
147 std::shared_ptr
<TCPClientCollection
> g_tcpclientthreads
;
// TCP worker thread entry point.
// `pipefd` is the read end of the pipe created by addTCPClientThread(). The acceptor
// writes a ConnectionInfo pointer into the other end and this thread reads that pointer
// (see `readn2(pipefd, &citmp, sizeof(citmp))` below).
// For each client connection the visible code reads length-prefixed queries, optionally
// decrypts them (DNSCrypt), applies the query rules, may answer from rules or from the
// packet cache, otherwise forwards the query over a per-thread downstream TCP socket,
// post-processes the answer and relays it to the client.
// NOTE(review): this listing is fragmentary -- loop headers, braces, `break`/`continue`,
// several declarations (qlen, rlen, poolname, delayMsec, now, largerQuery, dsock) and all
// cleanup code are not visible. The comments below only describe the visible statements.
149 void* tcpClientThread(int pipefd
)
151 /* we get launched with a pipe on which we receive file descriptors from clients that we own
152 from that point on */
154 bool outstanding
= false;
155 blockfilter_t blockFilter
= 0;
// The optional Lua variable "blockFilter" is read while holding g_luamutex.
// NOTE(review): `candidate` is a boost::optional; the check guarding `*candidate`
// is not visible here -- confirm it exists.
158 std::lock_guard
<std::mutex
> lock(g_luamutex
);
159 auto candidate
= g_lua
.readVariable
<boost::optional
<blockfilter_t
> >("blockFilter");
161 blockFilter
= *candidate
;
// Per-thread handles obtained through getLocal() on the shared global state.
164 auto localPolicy
= g_policy
.getLocal();
165 auto localRulactions
= g_rulactions
.getLocal();
166 auto localRespRulactions
= g_resprulactions
.getLocal();
167 auto localDynBlockNMG
= g_dynblockNMG
.getLocal();
168 auto localDynBlockSMT
= g_dynblockSMT
.getLocal();
169 auto localPools
= g_pools
.getLocal();
// Generator for the per-query unique id (assigned to dq.uniqueId below).
171 boost::uuids::random_generator uuidGenerator
;
// Downstream TCP sockets opened by this thread, keyed by the downstream's address.
174 map
<ComboAddress
,int> sockets
;
// `citmp` receives the pointer read from the pipe; `ci` is a ConnectionInfo held by
// value (how it is filled from `citmp` is not visible in this listing).
176 ConnectionInfo
* citmp
, ci
;
179 readn2(pipefd
, &citmp
, sizeof(citmp
));
// A pipe read failure is rethrown with the pipe fd and its blocking mode added to the message.
181 catch(const std::runtime_error
& e
) {
182 throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd
) + ") in " + std::string(isNonBlocking(pipefd
) ? "non-blocking" : "blocking") + " mode: " + e
.what());
// One connection has been taken from the queue.
185 --g_tcpclientthreads
->d_queued
;
191 vector
<uint8_t> rewrittenResponse
;
192 shared_ptr
<DownstreamState
> ds
;
193 if (!setNonBlocking(ci
.fd
))
// Read the length prefix of the next query from the client.
201 if(!getNonBlockingMsgLen(ci
.fd
, &qlen
, g_tcpRecvTimeout
))
// A query shorter than a DNS header is counted as non-compliant.
207 if (qlen
< sizeof(dnsheader
)) {
208 g_stats
.nonCompliantQueries
++;
212 bool ednsAdded
= false;
213 bool ecsAdded
= false;
214 /* if the query is small, allocate a bit more
215 memory to be able to spoof the content,
216 or to add ECS without allocating a new buffer */
217 size_t querySize
= qlen
<= 4096 ? qlen
+ 512 : qlen
;
// NOTE(review): `queryBuffer` is a variable-length array, which is a compiler
// extension rather than standard C++, and its size is driven by the client-supplied
// length (qlen is filled through a uint16_t*, so at most 65535 bytes on the stack).
218 char queryBuffer
[querySize
];
219 const char* query
= queryBuffer
;
// NOTE(review): the return value of this read is not inspected on the visible line.
220 readn2WithTimeout(ci
.fd
, queryBuffer
, qlen
, g_tcpRecvTimeout
);
223 std::shared_ptr
<DnsCryptQuery
> dnsCryptQuery
= 0;
// DNSCrypt listeners: let handleDnsCryptQuery process the buffer. When it produced a
// response of its own, that response is sent to the client; the query length is then
// replaced by the decrypted length.
225 if (ci
.cs
->dnscryptCtx
) {
226 dnsCryptQuery
= std::make_shared
<DnsCryptQuery
>();
227 uint16_t decryptedQueryLen
= 0;
228 vector
<uint8_t> response
;
229 bool decrypted
= handleDnsCryptQuery(ci
.cs
->dnscryptCtx
, queryBuffer
, qlen
, dnsCryptQuery
, &decryptedQueryLen
, true, response
);
232 if (response
.size() > 0) {
233 sendResponseToClient(ci
.fd
, reinterpret_cast<char*>(response
.data()), (uint16_t) response
.size());
237 qlen
= decryptedQueryLen
;
// View the start of the query buffer as a DNS header.
240 struct dnsheader
* dh
= (struct dnsheader
*) query
;
242 if(dh
->qr
) { // don't respond to responses
243 g_stats
.nonCompliantQueries
++;
// Queries without a question section are counted separately.
247 if(dh
->qdcount
== 0) {
248 g_stats
.emptyQueries
++;
// Remember the original header flags; they are passed to restoreFlags() and
// fixUpResponse() further down.
256 const uint16_t* flags
= getFlagsFromDNSHeader(dh
);
257 uint16_t origFlags
= *flags
;
258 uint16_t qtype
, qclass
;
259 unsigned int consumed
= 0;
// Parse the question name starting right after the header; qtype, qclass and the
// number of consumed bytes are filled in by the DNSName constructor.
260 DNSName
qname(query
, qlen
, sizeof(dnsheader
), false, &qtype
, &qclass
, &consumed
);
261 DNSQuestion
dq(&qname
, qtype
, qclass
, &ci
.cs
->local
, &ci
.remote
, (dnsheader
*)query
, querySize
, qlen
, true);
263 dq
.uniqueId
= uuidGenerator();
268 /* we need this one to be accurate ("real") for the protobuf message */
269 struct timespec queryRealTime
;
272 gettime(&queryRealTime
, true);
// Apply dynamic blocks and query rules; the handling of a `false` result is not
// visible in this listing.
274 if (!processQuery(localDynBlockNMG
, localDynBlockSMT
, localRulactions
, blockFilter
, dq
, poolname
, &delayMsec
, now
)) {
// Self-answered path: restore the original flags, encrypt if needed, send the
// buffer back to the client and count it.
278 if(dq
.dh
->qr
) { // something turned it into a response
279 restoreFlags(dh
, origFlags
);
281 if (!encryptResponse(queryBuffer
, &dq
.len
, dq
.size
, true, dnsCryptQuery
)) {
285 sendResponseToClient(ci
.fd
, query
, dq
.len
);
286 g_stats
.selfAnswered
++;
// Select the server pool, then (under g_luamutex) let the policy pick a downstream
// and fetch the pool's packet cache.
290 std::shared_ptr
<ServerPool
> serverPool
= getPool(*localPools
, poolname
);
291 std::shared_ptr
<DNSDistPacketCache
> packetCache
= nullptr;
293 std::lock_guard
<std::mutex
> lock(g_luamutex
);
294 ds
= localPolicy
->policy(serverPool
->servers
, &dq
);
295 packetCache
= serverPool
->packetCache
;
// EDNS Client Subnet: only when both the question and the selected downstream ask
// for it. When handleEDNSClientSubnet filled `largerQuery`, the query pointer and
// the recorded length/size are switched over to that string.
298 if (dq
.useECS
&& ds
&& ds
->useECS
) {
299 uint16_t newLen
= dq
.len
;
300 handleEDNSClientSubnet(queryBuffer
, dq
.size
, consumed
, &newLen
, largerQuery
, &ednsAdded
, &ecsAdded
, ci
.remote
, dq
.ecsOverride
, dq
.ecsPrefixLength
);
301 if (largerQuery
.empty() == false) {
302 query
= largerQuery
.c_str();
303 dq
.len
= (uint16_t) largerQuery
.size();
304 dq
.size
= largerQuery
.size();
// Packet cache lookup. g_staleCacheEntriesTTL is passed as `allowExpired` only when
// no downstream was selected; otherwise 0 is passed.
310 uint32_t cacheKey
= 0;
311 if (packetCache
&& !dq
.skipCache
) {
312 char cachedResponse
[4096];
313 uint16_t cachedResponseSize
= sizeof cachedResponse
;
314 uint32_t allowExpired
= ds
? 0 : g_staleCacheEntriesTTL
;
315 if (packetCache
->get(dq
, (uint16_t) consumed
, dq
.dh
->id
, cachedResponse
, &cachedResponseSize
, &cacheKey
, allowExpired
)) {
317 if (!encryptResponse(cachedResponse
, &cachedResponseSize
, sizeof cachedResponse
, true, dnsCryptQuery
)) {
321 sendResponseToClient(ci
.fd
, cachedResponse
, cachedResponseSize
);
325 g_stats
.cacheMisses
++;
// Reuse this thread's socket to the chosen downstream, or open one on first use.
// NOTE(review): `ds` is dereferenced here; the null check that presumably precedes
// this point is not visible in this listing -- confirm.
334 if(sockets
.count(ds
->remote
) == 0) {
335 dsock
=sockets
[ds
->remote
]=setupTCPDownstream(ds
);
338 dsock
=sockets
[ds
->remote
];
// Number of consecutive failures talking to the downstream for this query.
344 uint16_t downstream_failures
=0;
347 sockets
.erase(ds
->remote
);
// Give up once the failure count exceeds ds->retries (a value of 0 disables this check).
351 if (ds
->retries
> 0 && downstream_failures
> ds
->retries
) {
352 vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds
->getName(), downstream_failures
);
355 sockets
.erase(ds
->remote
);
// Send the length prefix; on failure drop the cached socket, reconnect and count the failure.
359 if(!sendNonBlockingMsgLen(dsock
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
)) {
360 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
363 sockets
.erase(ds
->remote
);
364 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
365 downstream_failures
++;
// Send the query body: plain write when no source interface is set, otherwise
// sendMsgWithTimeout with the explicit source address and interface.
370 if (ds
->sourceItf
== 0) {
371 writen2WithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
);
374 sendMsgWithTimeout(dsock
, query
, dq
.len
, ds
->tcpSendTimeout
, ds
->remote
, ds
->sourceAddr
, ds
->sourceItf
);
// A runtime_error while sending is handled like a failed prefix: reconnect and count it.
377 catch(const runtime_error
& e
) {
378 vinfolog("Downstream connection to %s died on us, getting a new one!", ds
->getName());
381 sockets
.erase(ds
->remote
);
382 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
383 downstream_failures
++;
// Zone transfers (AXFR/IXFR) are flagged; `xfrStarted` and the SOA record counts
// below are used for them (the statements acting on those counts are not visible).
387 bool xfrStarted
= false;
388 bool isXFR
= (dq
.qtype
== QType::AXFR
|| dq
.qtype
== QType::IXFR
);
// Read the length prefix of the downstream answer; on failure reconnect and count it.
395 if(!getNonBlockingMsgLen(dsock
, &rlen
, ds
->tcpRecvTimeout
)) {
396 vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds
->getName());
399 sockets
.erase(ds
->remote
);
400 sockets
[ds
->remote
]=dsock
=setupTCPDownstream(ds
);
401 downstream_failures
++;
// Size the answer buffer; for DNSCrypt clients extra room for padding and MAC is
// reserved when it fits below UINT16_MAX.
408 size_t responseSize
= rlen
;
409 uint16_t addRoom
= 0;
411 if (dnsCryptQuery
&& (UINT16_MAX
- rlen
) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
) {
412 addRoom
= DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE
;
415 responseSize
+= addRoom
;
// NOTE(review): another variable-length array (compiler extension), sized from the
// length announced by the downstream; the read result below is not inspected on the
// visible line.
416 char answerbuffer
[responseSize
];
417 readn2WithTimeout(dsock
, answerbuffer
, rlen
, ds
->tcpRecvTimeout
);
418 char* response
= answerbuffer
;
419 uint16_t responseLen
= rlen
;
421 /* might be false for {A,I}XFR */
// Validation and post-processing of the answer; what happens when one of these
// checks fails is not visible in this listing.
426 if (rlen
< sizeof(dnsheader
)) {
430 if (!responseContentMatches(response
, responseLen
, qname
, qtype
, qclass
, ds
->remote
)) {
434 if (!fixUpResponse(&response
, &responseLen
, &responseSize
, qname
, origFlags
, ednsAdded
, ecsAdded
, rewrittenResponse
, addRoom
)) {
// From here on `dh` refers to the header of the response, not of the query.
438 dh
= (struct dnsheader
*) response
;
439 DNSResponse
dr(&qname
, qtype
, qclass
, &ci
.cs
->local
, &ci
.remote
, dh
, responseSize
, responseLen
, true, &queryRealTime
);
441 dr
.uniqueId
= dq
.uniqueId
;
443 if (!processResponse(localRespRulactions
, dr
, &delayMsec
)) {
// Store the answer in the packet cache under the key computed during the lookup;
// the last argument tells the cache whether the answer is a ServFail.
447 if (packetCache
&& !dq
.skipCache
) {
448 packetCache
->insert(cacheKey
, qname
, qtype
, qclass
, response
, responseLen
, true, dh
->rcode
== RCode::ServFail
);
452 if (!encryptResponse(response
, &responseLen
, responseSize
, true, dnsCryptQuery
)) {
456 if (!sendResponseToClient(ci
.fd
, response
, responseLen
)) {
// Zone transfer bookkeeping: only successful, non-empty answers are inspected, and
// the number of SOA records in the answer section is compared against 1 and 0.
460 if (isXFR
&& dh
->rcode
== 0 && dh
->ancount
!= 0) {
461 if (xfrStarted
== false) {
463 if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 1) {
467 else if (getRecordsOfTypeCount(response
, responseLen
, 1, QType::SOA
) == 0) {
// Compute the elapsed time since `now`, scaled by 1e6 (presumably microseconds,
// assuming DiffTime returns seconds -- confirm), and record the answer in the
// response ring while holding its mutex.
473 struct timespec answertime
;
474 gettime(&answertime
);
475 unsigned int udiff
= 1000000.0*DiffTime(now
,answertime
);
477 std::lock_guard
<std::mutex
> lock(g_rings
.respMutex
);
478 g_rings
.respRing
.push_back({answertime
, ci
.remote
, qname
, dq
.qtype
, (unsigned int)udiff
, (unsigned int)responseLen
, *dh
, ds
->remote
});
// Reset the scratch buffer that was handed to fixUpResponse().
482 rewrittenResponse
.clear();
489 vinfolog("Closing TCP client connection with %s", ci
.remote
.toStringWithPort());
// Final bookkeeping when a downstream was selected and `outstanding` is set
// (the body of this branch is not visible in this listing).
492 if (ds
&& outstanding
) {
501 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
502 they will hand off to worker threads & spawn more of them if required
// Acceptor thread entry point. `p` is cast to the ClientState describing the
// listening socket (cs->tcpFD, cs->local).
// Visible behaviour: add a worker thread, then accept connections; a connection is
// dropped when the peer does not match the ACL or when the number of queued
// connections has reached g_maxTCPQueuedConnections (limit only applied when > 0);
// otherwise the ConnectionInfo pointer is written into a worker's pipe.
// NOTE(review): loop headers, braces, the d_queued increment and the close()/delete
// cleanup are not visible in this listing -- confirm against the full file.
504 void* tcpAcceptorThread(void* p
)
506 ClientState
* cs
= (ClientState
*) p
;
// `remote` takes the address family of the listening address before being passed to SAccept.
509 remote
.sin4
.sin_family
= cs
->local
.sin4
.sin_family
;
511 g_tcpclientthreads
->addTCPClientThread();
513 auto acl
= g_ACL
.getLocal();
// NOTE(review): raw owning `new`. The pointer value itself is sent through the pipe
// below, so ownership presumably passes to the worker on success; the release on the
// error/drop paths is not visible here -- confirm there is no leak.
518 ci
= new ConnectionInfo
;
521 ci
->fd
= SAccept(cs
->tcpFD
, remote
);
523 if(!acl
->match(remote
)) {
528 vinfolog("Dropped TCP connection from %s because of ACL", remote
.toStringWithPort());
532 if(g_maxTCPQueuedConnections
> 0 && g_tcpclientthreads
->d_queued
>= g_maxTCPQueuedConnections
) {
536 vinfolog("Dropping TCP connection from %s because we have too many queued already", remote
.toStringWithPort());
540 vinfolog("Got TCP connection from %s", remote
.toStringWithPort());
// Pick a worker pipe and hand over the ConnectionInfo by writing the pointer
// (sizeof(ci) bytes) into it, with a timeout argument of 0.
543 int pipe
= g_tcpclientthreads
->getThread();
545 writen2WithTimeout(pipe
, &ci
, sizeof(ci
), 0);
// Presumably undoes the queue accounting when the hand-off failed -- the enclosing
// condition is not visible; confirm.
548 --g_tcpclientthreads
->d_queued
;
// NOTE(review): the log text says "reading a TCP question" although this is the
// accept path; left unchanged because it is a runtime string.
554 catch(std::exception
& e
) {
555 errlog("While reading a TCP question: %s", e
.what());
556 if(ci
&& ci
->fd
>= 0)
// Reads a length prefix of sizeof(raw) bytes from `fd` with readn2 and reports it
// through `len`. Lengths above 10000000 are singled out by the final check.
// NOTE(review): the declaration of `raw`, the conversion stored into *len and the
// return statements are not visible here. Presumably `raw` is a uint32_t converted
// with ntohl, and both a short read and an over-limit length return false -- confirm.
567 bool getMsgLen32(int fd
, uint32_t* len
)
571 size_t ret
= readn2(fd
, &raw
, sizeof raw
);
572 if(ret
!= sizeof raw
)
575 if(*len
> 10000000) // arbitrary 10MB limit
// Writes `len` to `fd` as a 4-byte length prefix in network byte order (htonl)
// using writen2.
// Returns true exactly when all sizeof(raw) bytes were written.
// NOTE(review): any surrounding try/catch is not visible in this listing.
583 bool putMsgLen32(int fd
, uint32_t len
)
586 uint32_t raw
= htonl(len
);
587 size_t ret
= writen2(fd
, &raw
, sizeof raw
);
588 return ret
==sizeof raw
;