/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
#include <arpa/inet.h>
#include <unistd.h>

#include "dnsdist-ecs.hh"
#include "dnsparser.hh"
#include "ednsoptions.hh"
/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
   An answer might theoretically consist of multiple messages (for example, in the case of AXFR).

   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been setup.
   This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
   So whenever an answer comes in, we know where it needs to go.
*/
50 static int setupTCPDownstream(shared_ptr
<DownstreamState
> ds
, uint16_t& downstreamFailures
)
53 vinfolog("TCP connecting to downstream %s (%d)", ds
->remote
.toStringWithPort(), downstreamFailures
);
54 int sock
= SSocket(ds
->remote
.sin4
.sin_family
, SOCK_STREAM
, 0);
56 if (!IsAnyAddress(ds
->sourceAddr
)) {
57 SSetsockopt(sock
, SOL_SOCKET
, SO_REUSEADDR
, 1);
58 #ifdef IP_BIND_ADDRESS_NO_PORT
59 SSetsockopt(sock
, SOL_IP
, IP_BIND_ADDRESS_NO_PORT
, 1);
61 SBind(sock
, ds
->sourceAddr
);
65 if (!ds
->tcpFastOpen
) {
66 SConnectWithTimeout(sock
, ds
->remote
, ds
->tcpConnectTimeout
);
69 SConnectWithTimeout(sock
, ds
->remote
, ds
->tcpConnectTimeout
);
70 #endif /* MSG_FASTOPEN */
73 catch(const std::runtime_error
& e
) {
74 /* don't leak our file descriptor if SConnect() (for example) throws */
77 if (downstreamFailures
> ds
->retries
) {
81 } while(downstreamFailures
<= ds
->retries
);
/* hard cap on connections queued towards the worker threads; enforced in tcpAcceptorThread() */
uint64_t g_maxTCPQueuedConnections{1000};
/* maximum number of queries per TCP connection (0 = unlimited) */
size_t g_maxTCPQueriesPerConn{0};
/* maximum lifetime of a client connection in seconds (0 = unlimited) */
size_t g_maxTCPConnectionDuration{0};
/* maximum simultaneous connections per client address (0 = unlimited) */
size_t g_maxTCPConnectionsPerClient{0};
/* guards tcpClientsCount below */
static std::mutex tcpClientsCountMutex;
/* per-client connection count; comparator ignores the source port on purpose */
static std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> tcpClientsCount;
/* presumably feeds TCPClientCollection's d_useSinglePipe — not referenced in this chunk, confirm */
bool g_useTCPSinglePipe{false};
/* how often (seconds) a worker prunes dead downstream sockets; 0 disables the scan */
std::atomic<uint16_t> g_downstreamTCPCleanupInterval{60};

/* worker entry point, defined below; forward-declared for addTCPClientThread() */
void* tcpClientThread(int pipefd);
104 static void decrementTCPClientCount(const ComboAddress
& client
)
106 if (g_maxTCPConnectionsPerClient
) {
107 std::lock_guard
<std::mutex
> lock(tcpClientsCountMutex
);
108 tcpClientsCount
[client
]--;
109 if (tcpClientsCount
[client
] == 0) {
110 tcpClientsCount
.erase(client
);
/* Spawns one more TCP worker thread (or, in single-pipe mode, attaches to
   the shared pipe) and registers the writing end of its pipe in
   d_tcpclientthreads so the acceptor can feed it connections.
   NOTE(review): fragmentary block — the extraction dropped the returns,
   the try block around thread creation and most closing braces.  The
   surviving statements are kept token-for-token; only comments added. */
void TCPClientCollection::addTCPClientThread()
int pipefds[2] = { -1, -1};
vinfolog("Adding TCP Client thread");
/* single-pipe mode: every worker reads from the same pre-created pipe */
if (d_useSinglePipe) {
pipefds[0] = d_singlePipe[0];
pipefds[1] = d_singlePipe[1];
if (pipe(pipefds) < 0) {
errlog("Error creating the TCP thread communication pipe: %s", strerror(errno));
/* the writing end must never block the acceptor thread */
if (!setNonBlocking(pipefds[1])) {
errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
std::lock_guard<std::mutex> lock(d_mutex);
/* capacity is fixed up-front; presumably growing would race with readers — confirm */
if (d_numthreads >= d_tcpclientthreads.capacity()) {
warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
if (!d_useSinglePipe) {
thread t1(tcpClientThread, pipefds[0]);
catch(const std::runtime_error& e) {
/* the thread creation failed, don't leak */
errlog("Error creating a TCP thread: %s", e.what());
if (!d_useSinglePipe) {
d_tcpclientthreads.push_back(pipefds[1]);
171 static bool getNonBlockingMsgLen(int fd
, uint16_t* len
, int timeout
)
175 size_t ret
= readn2WithTimeout(fd
, &raw
, sizeof raw
, timeout
);
176 if(ret
!= sizeof raw
)
185 static bool sendResponseToClient(int fd
, const char* response
, uint16_t responseLen
)
187 return sendSizeAndMsgWithTimeout(fd
, responseLen
, response
, g_tcpSendTimeout
, nullptr, nullptr, 0, 0, 0);
/* Returns true when the connection opened at 'start' has lived for at
   least 'maxConnectionDuration' seconds.  When the limit is set but not
   yet reached, 'remainingTime' is updated with the seconds left; when the
   limit is 0 (disabled) 'remainingTime' is left untouched and false is
   returned.
   NOTE(review): braces and returns were lost to extraction and restored. */
static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
{
  if (maxConnectionDuration) {
    time_t elapsed = time(NULL) - start;
    if (elapsed >= maxConnectionDuration) {
      return true;
    }
    remainingTime = maxConnectionDuration - elapsed;
  }
  return false;
}
202 void cleanupClosedTCPConnections(std::map
<ComboAddress
,int>& sockets
)
204 for(auto it
= sockets
.begin(); it
!= sockets
.end(); ) {
205 if (isTCPSocketUsable(it
->second
)) {
210 it
= sockets
.erase(it
);
/* the collection of TCP worker threads; populated on demand by
   tcpAcceptorThread() via addTCPClientThread() */
std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
/* Worker thread: receives accepted client file descriptors over 'pipefd'
   and serves DNS-over-TCP queries on them, proxying to downstream servers
   through a per-thread cache of downstream connections.
   NOTE(review): this block is fragmentary — the extraction dropped many
   lines (the outer for(;;) loops, try blocks, most closing braces, and
   declarations such as 'dest', 'qlen', 'poolname', 'delayMsec', 'now',
   'largerQuery', 'dsock', 'rlen', 'dhCopy' and 'socketFlags').  The
   surviving statements below are kept token-for-token; only comments were
   added.  Do not expect this text to compile as-is; confirm against the
   upstream file before relying on any detail. */
void* tcpClientThread(int pipefd)
/* we get launched with a pipe on which we receive file descriptors from clients that we own
   from that point on */
bool outstanding = false;
time_t lastTCPCleanup = time(nullptr);
/* thread-local handles on the atomically-swappable global configuration */
auto localPolicy = g_policy.getLocal();
auto localRulactions = g_rulactions.getLocal();
auto localRespRulactions = g_resprulactions.getLocal();
auto localCacheHitRespRulactions = g_cachehitresprulactions.getLocal();
auto localDynBlockNMG = g_dynblockNMG.getLocal();
auto localDynBlockSMT = g_dynblockSMT.getLocal();
auto localPools = g_pools.getLocal();
boost::uuids::random_generator uuidGenerator;
/* when the answer is encrypted in place, we need to get a copy
   of the original header before encryption to fill the ring buffer */
/* per-thread cache of established downstream connections, keyed by backend address */
map<ComboAddress,int> sockets;
/* quirky but deliberate: citmp is the pointer read from the pipe, ci a local copy */
ConnectionInfo* citmp, ci;
readn2(pipefd, &citmp, sizeof(citmp));
catch(const std::runtime_error& e) {
throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
g_tcpclientthreads->decrementQueuedCount();
vector<uint8_t> rewrittenResponse;
shared_ptr<DownstreamState> ds;
memset(&dest, 0, sizeof(dest));
dest.sin4.sin_family = ci.remote.sin4.sin_family;
socklen_t len = dest.getSocklen();
size_t queriesCount = 0;
time_t connectionStartTime = time(NULL);
if (!setNonBlocking(ci.fd))
/* recover the local address the client connected to */
if (getsockname(ci.fd, (sockaddr*)&dest, &len)) {
unsigned int remainingTime = 0;
/* read the 2-byte length prefix of the next query */
if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
if (g_maxTCPQueriesPerConn && queriesCount > g_maxTCPQueriesPerConn) {
vinfolog("Terminating TCP connection from %s because it reached the maximum number of queries per conn (%d / %d)", ci.remote.toStringWithPort(), queriesCount, g_maxTCPQueriesPerConn);
if (maxConnectionDurationReached(g_maxTCPConnectionDuration, connectionStartTime, remainingTime)) {
vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", ci.remote.toStringWithPort());
/* too short to even hold a DNS header */
if (qlen < sizeof(dnsheader)) {
g_stats.nonCompliantQueries++;
bool ednsAdded = false;
bool ecsAdded = false;
/* if the query is small, allocate a bit more
   memory to be able to spoof the content,
   or to add ECS without allocating a new buffer */
size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
char queryBuffer[querySize];
const char* query = queryBuffer;
readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout, remainingTime);
std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
/* DNSCrypt: decrypt in place; a non-empty 'response' must be sent back directly */
if (ci.cs->dnscryptCtx) {
dnsCryptQuery = std::make_shared<DnsCryptQuery>();
uint16_t decryptedQueryLen = 0;
vector<uint8_t> response;
bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);
if (response.size() > 0) {
sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), (uint16_t) response.size());
qlen = decryptedQueryLen;
struct dnsheader* dh = (struct dnsheader*) query;
if(dh->qr) { // don't respond to responses
g_stats.nonCompliantQueries++;
if(dh->qdcount == 0) {
g_stats.emptyQueries++;
/* keep the original flags so we can restore them on self-answered/fixed-up packets */
const uint16_t* flags = getFlagsFromDNSHeader(dh);
uint16_t origFlags = *flags;
uint16_t qtype, qclass;
unsigned int consumed = 0;
DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, (dnsheader*)query, querySize, qlen, true);
dq.uniqueId = uuidGenerator();
/* we need this one to be accurate ("real") for the protobuf message */
struct timespec queryRealTime;
gettime(&queryRealTime, true);
/* run the rule chain; it may drop, reroute or self-answer the query */
if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, dq, poolname, &delayMsec, now)) {
if(dq.dh->qr) { // something turned it into a response
restoreFlags(dh, origFlags);
if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
sendResponseToClient(ci.fd, query, dq.len);
g_stats.selfAnswered++;
/* pick the server pool, then a backend via the (possibly per-pool) LB policy */
std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
auto policy = localPolicy->policy;
if (serverPool->policy != nullptr) {
policy = serverPool->policy->policy;
/* Lua-based policies run under the global Lua lock */
std::lock_guard<std::mutex> lock(g_luamutex);
ds = policy(serverPool->servers, &dq);
packetCache = serverPool->packetCache;
if (dq.useECS && ds && ds->useECS) {
uint16_t newLen = dq.len;
handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength);
if (largerQuery.empty() == false) {
/* adding ECS did not fit in queryBuffer's spare room; switch to the larger copy */
query = largerQuery.c_str();
dq.len = (uint16_t) largerQuery.size();
dq.size = largerQuery.size();
uint32_t cacheKey = 0;
if (packetCache && !dq.skipCache) {
char cachedResponse[4096];
uint16_t cachedResponseSize = sizeof cachedResponse;
/* serve stale cache entries only when no backend is available */
uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, (dnsheader*) cachedResponse, sizeof cachedResponse, cachedResponseSize, true, &queryRealTime);
dr.uniqueId = dq.uniqueId;
if (!processResponse(localCacheHitRespRulactions, dr, &delayMsec)) {
if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery, nullptr, nullptr)) {
sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
g_stats.cacheMisses++;
/* no backend selected: either answer ServFail or drop, per configuration */
if (g_servFailOnNoPolicy) {
restoreFlags(dh, origFlags);
dq.dh->rcode = RCode::ServFail;
if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery, nullptr, nullptr)) {
sendResponseToClient(ci.fd, query, dq.len);
uint16_t downstreamFailures=0;
bool freshConn = true;
#endif /* MSG_FASTOPEN */
/* reuse a cached downstream connection to this backend if we have one */
if(sockets.count(ds->remote) == 0) {
dsock=setupTCPDownstream(ds, downstreamFailures);
sockets[ds->remote]=dsock;
dsock=sockets[ds->remote];
#endif /* MSG_FASTOPEN */
sockets.erase(ds->remote);
if (ds->retries > 0 && downstreamFailures > ds->retries) {
vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstreamFailures);
sockets.erase(ds->remote);
if (ds->tcpFastOpen && freshConn) {
socketFlags |= MSG_FASTOPEN;
#endif /* MSG_FASTOPEN */
sendSizeAndMsgWithTimeout(dsock, dq.len, query, ds->tcpSendTimeout, &ds->remote, &ds->sourceAddr, ds->sourceItf, 0, socketFlags);
catch(const runtime_error& e) {
/* the write failed: drop the dead connection, open a fresh one and retry */
vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
sockets.erase(ds->remote);
downstreamFailures++;
dsock=setupTCPDownstream(ds, downstreamFailures);
sockets[ds->remote]=dsock;
#endif /* MSG_FASTOPEN */
bool xfrStarted = false;
bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
/* read the answer(s); XFRs may span multiple messages */
if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
sockets.erase(ds->remote);
downstreamFailures++;
dsock=setupTCPDownstream(ds, downstreamFailures);
sockets[ds->remote]=dsock;
#endif /* MSG_FASTOPEN */
size_t responseSize = rlen;
uint16_t addRoom = 0;
/* reserve room for DNSCrypt padding/MAC, guarding against uint16_t overflow */
if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
responseSize += addRoom;
char answerbuffer[responseSize];
readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
char* response = answerbuffer;
uint16_t responseLen = rlen;
/* might be false for {A,I}XFR */
if (rlen < sizeof(dnsheader)) {
if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
/* restore the client's flags and undo any EDNS/ECS we added to the query */
if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom)) {
dh = (struct dnsheader*) response;
DNSResponse dr(&qname, qtype, qclass, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
dr.uniqueId = dq.uniqueId;
if (!processResponse(localRespRulactions, dr, &delayMsec)) {
if (packetCache && !dq.skipCache) {
packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode);
if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery, &dh, &dhCopy)) {
if (!sendResponseToClient(ci.fd, response, responseLen)) {
/* XFR bookkeeping: SOA record counts mark the start/end of the transfer */
if (dh->rcode == 0 && dh->ancount != 0) {
if (xfrStarted == false) {
if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
/* Don't reuse the TCP connection after an {A,I}XFR */
sockets.erase(ds->remote);
/* account latency and push the answer into the in-memory ring buffer */
struct timespec answertime;
gettime(&answertime);
unsigned int udiff = 1000000.0*DiffTime(now,answertime);
std::lock_guard<std::mutex> lock(g_rings.respMutex);
g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote});
rewrittenResponse.clear();
vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
if (ds && outstanding) {
decrementTCPClientCount(ci.remote);
/* periodically prune downstream connections that were closed on us */
if (g_downstreamTCPCleanupInterval > 0 && (connectionStartTime > (lastTCPCleanup + g_downstreamTCPCleanupInterval))) {
cleanupClosedTCPConnections(sockets);
lastTCPCleanup = time(nullptr);
/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
   they will hand off to worker threads & spawn more of them if required
*/
/* Accepts incoming TCP client connections on the ClientState's listening
   socket, applies the ACL and the queue/per-client limits, then hands each
   connection off to a worker thread over its pipe.
   NOTE(review): fragmentary block — the extraction dropped the accept
   loop, try blocks, the 'remote' declaration and most closing braces.
   Surviving statements are kept token-for-token; only comments added. */
void* tcpAcceptorThread(void* p)
ClientState* cs = (ClientState*) p;
bool tcpClientCountIncremented = false;
remote.sin4.sin_family = cs->local.sin4.sin_family;
/* make sure at least one worker thread exists before accepting */
g_tcpclientthreads->addTCPClientThread();
auto acl = g_ACL.getLocal();
bool queuedCounterIncremented = false;
ConnectionInfo* ci = nullptr;
tcpClientCountIncremented = false;
ci = new ConnectionInfo;
ci->fd = SAccept(cs->tcpFD, remote);
if(!acl->match(remote)) {
vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
/* refuse when too many connections are already queued for the workers */
if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
/* enforce the per-client-address connection cap */
if (g_maxTCPConnectionsPerClient) {
std::lock_guard<std::mutex> lock(tcpClientsCountMutex);
if (tcpClientsCount[remote] >= g_maxTCPConnectionsPerClient) {
vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort());
tcpClientsCount[remote]++;
tcpClientCountIncremented = true;
vinfolog("Got TCP connection from %s", remote.toStringWithPort());
/* pick a worker and pass it the ConnectionInfo pointer through its pipe */
int pipe = g_tcpclientthreads->getThread();
queuedCounterIncremented = true;
writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
g_tcpclientthreads->decrementQueuedCount();
queuedCounterIncremented = false;
if(tcpClientCountIncremented) {
decrementTCPClientCount(remote);
catch(std::exception& e) {
errlog("While reading a TCP question: %s", e.what());
if(ci && ci->fd >= 0)
if(tcpClientCountIncremented) {
decrementTCPClientCount(remote);
/* undo the queued-count bump if the hand-off never happened */
if (queuedCounterIncremented) {
g_tcpclientthreads->decrementQueuedCount();
732 bool getMsgLen32(int fd
, uint32_t* len
)
736 size_t ret
= readn2(fd
, &raw
, sizeof raw
);
737 if(ret
!= sizeof raw
)
740 if(*len
> 10000000) // arbitrary 10MB limit
748 bool putMsgLen32(int fd
, uint32_t len
)
751 uint32_t raw
= htonl(len
);
752 size_t ret
= writen2(fd
, &raw
, sizeof raw
);
753 return ret
==sizeof raw
;