]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-tcp.cc
Merge pull request #4764 from rgacogne/dnsdist-tcp-workers-vect-race
[thirdparty/pdns.git] / pdns / dnsdist-tcp.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsdist.hh"
23 #include "dnsdist-ecs.hh"
24 #include "dnsparser.hh"
25 #include "ednsoptions.hh"
26 #include "dolog.hh"
27 #include "lock.hh"
28 #include "gettime.hh"
29 #include <thread>
30 #include <atomic>
31
32 using std::thread;
33 using std::atomic;
34
35 /* TCP: the grand design.
36 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
37 An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
38 we will not go there.
39
40 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been setup.
41 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
42 to guarantee performance.
43
44 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
45 So whenever an answer comes in, we know where it needs to go.
46
47 Let's start naively.
48 */
49
50 static int setupTCPDownstream(shared_ptr<DownstreamState> ds)
51 {
52 vinfolog("TCP connecting to downstream %s", ds->remote.toStringWithPort());
53 int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
54 try {
55 if (!IsAnyAddress(ds->sourceAddr)) {
56 SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
57 SBind(sock, ds->sourceAddr);
58 }
59 SConnect(sock, ds->remote);
60 setNonBlocking(sock);
61 }
62 catch(const std::runtime_error& e) {
63 /* don't leak our file descriptor if SConnect() (for example) throws */
64 close(sock);
65 throw;
66 }
67 return sock;
68 }
69
70 struct ConnectionInfo
71 {
72 int fd;
73 ComboAddress remote;
74 ClientState* cs;
75 };
76
/* Maximum number of accepted connections allowed to wait for a worker thread:
   tcpAcceptorThread() drops new connections once the queued count reaches it,
   and skips the check entirely when the value is 0. */
uint64_t g_maxTCPQueuedConnections = 1000;

/* Entry point of the TCP worker threads, defined further down in this file;
   declared here because TCPClientCollection::addTCPClientThread() spawns it. */
void* tcpClientThread(int pipefd);
79
80 void TCPClientCollection::addTCPClientThread()
81 {
82 vinfolog("Adding TCP Client thread");
83
84 int pipefds[2] = { -1, -1};
85 if (pipe(pipefds) < 0) {
86 errlog("Error creating the TCP thread communication pipe: %s", strerror(errno));
87 return;
88 }
89
90 if (!setNonBlocking(pipefds[1])) {
91 close(pipefds[0]);
92 close(pipefds[1]);
93 errlog("Error setting the TCP thread communication pipe non-blocking: %s", strerror(errno));
94 return;
95 }
96
97 {
98 std::lock_guard<std::mutex> lock(d_mutex);
99
100 if (d_numthreads >= d_tcpclientthreads.capacity()) {
101 warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
102 close(pipefds[0]);
103 close(pipefds[1]);
104 return;
105 }
106
107 try {
108 thread t1(tcpClientThread, pipefds[0]);
109 t1.detach();
110 }
111 catch(const std::runtime_error& e) {
112 /* the thread creation failed, don't leak */
113 errlog("Error creating a TCP thread: %s", e.what());
114 close(pipefds[0]);
115 close(pipefds[1]);
116 return;
117 }
118
119 d_tcpclientthreads.push_back(pipefds[1]);
120 }
121
122 ++d_numthreads;
123 }
124
125 static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
126 try
127 {
128 uint16_t raw;
129 size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
130 if(ret != sizeof raw)
131 return false;
132 *len = ntohs(raw);
133 return true;
134 }
135 catch(...) {
136 return false;
137 }
138
139 static bool putNonBlockingMsgLen(int fd, uint16_t len, int timeout)
140 try
141 {
142 uint16_t raw = htons(len);
143 size_t ret = writen2WithTimeout(fd, &raw, sizeof raw, timeout);
144 return ret == sizeof raw;
145 }
146 catch(...) {
147 return false;
148 }
149
150 static bool sendNonBlockingMsgLen(int fd, uint16_t len, int timeout, ComboAddress& dest, ComboAddress& local, unsigned int localItf)
151 try
152 {
153 if (localItf == 0)
154 return putNonBlockingMsgLen(fd, len, timeout);
155
156 uint16_t raw = htons(len);
157 ssize_t ret = sendMsgWithTimeout(fd, (char*) &raw, sizeof raw, timeout, dest, local, localItf);
158 return ret == sizeof raw;
159 }
160 catch(...) {
161 return false;
162 }
163
164 static bool sendResponseToClient(int fd, const char* response, uint16_t responseLen)
165 {
166 if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
167 return false;
168
169 writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout);
170 return true;
171 }
172
173 std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
174
175 void* tcpClientThread(int pipefd)
176 {
177 /* we get launched with a pipe on which we receive file descriptors from clients that we own
178 from that point on */
179
180 bool outstanding = false;
181 blockfilter_t blockFilter = 0;
182
183 {
184 std::lock_guard<std::mutex> lock(g_luamutex);
185 auto candidate = g_lua.readVariable<boost::optional<blockfilter_t> >("blockFilter");
186 if(candidate)
187 blockFilter = *candidate;
188 }
189
190 auto localPolicy = g_policy.getLocal();
191 auto localRulactions = g_rulactions.getLocal();
192 auto localRespRulactions = g_resprulactions.getLocal();
193 auto localDynBlockNMG = g_dynblockNMG.getLocal();
194 auto localDynBlockSMT = g_dynblockSMT.getLocal();
195 auto localPools = g_pools.getLocal();
196 #ifdef HAVE_PROTOBUF
197 boost::uuids::random_generator uuidGenerator;
198 #endif
199
200 map<ComboAddress,int> sockets;
201 for(;;) {
202 ConnectionInfo* citmp, ci;
203
204 try {
205 readn2(pipefd, &citmp, sizeof(citmp));
206 }
207 catch(const std::runtime_error& e) {
208 throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
209 }
210
211 g_tcpclientthreads->decrementQueuedCount();
212 ci=*citmp;
213 delete citmp;
214
215 uint16_t qlen, rlen;
216 string largerQuery;
217 vector<uint8_t> rewrittenResponse;
218 shared_ptr<DownstreamState> ds;
219 ComboAddress dest;
220 memset(&dest, 0, sizeof(dest));
221 dest.sin4.sin_family = ci.remote.sin4.sin_family;
222 socklen_t len = dest.getSocklen();
223 if (!setNonBlocking(ci.fd))
224 goto drop;
225
226 if (getsockname(ci.fd, (sockaddr*)&dest, &len)) {
227 dest = ci.cs->local;
228 }
229
230 try {
231 for(;;) {
232 ds = nullptr;
233 outstanding = false;
234
235 if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
236 break;
237
238 ci.cs->queries++;
239 g_stats.queries++;
240
241 if (qlen < sizeof(dnsheader)) {
242 g_stats.nonCompliantQueries++;
243 break;
244 }
245
246 bool ednsAdded = false;
247 bool ecsAdded = false;
248 /* if the query is small, allocate a bit more
249 memory to be able to spoof the content,
250 or to add ECS without allocating a new buffer */
251 size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
252 char queryBuffer[querySize];
253 const char* query = queryBuffer;
254 readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);
255
256 #ifdef HAVE_DNSCRYPT
257 std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
258
259 if (ci.cs->dnscryptCtx) {
260 dnsCryptQuery = std::make_shared<DnsCryptQuery>();
261 uint16_t decryptedQueryLen = 0;
262 vector<uint8_t> response;
263 bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);
264
265 if (!decrypted) {
266 if (response.size() > 0) {
267 sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), (uint16_t) response.size());
268 }
269 break;
270 }
271 qlen = decryptedQueryLen;
272 }
273 #endif
274 struct dnsheader* dh = (struct dnsheader*) query;
275
276 if(dh->qr) { // don't respond to responses
277 g_stats.nonCompliantQueries++;
278 goto drop;
279 }
280
281 if(dh->qdcount == 0) {
282 g_stats.emptyQueries++;
283 goto drop;
284 }
285
286 if (dh->rd) {
287 g_stats.rdQueries++;
288 }
289
290 const uint16_t* flags = getFlagsFromDNSHeader(dh);
291 uint16_t origFlags = *flags;
292 uint16_t qtype, qclass;
293 unsigned int consumed = 0;
294 DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
295 DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, (dnsheader*)query, querySize, qlen, true);
296 #ifdef HAVE_PROTOBUF
297 dq.uniqueId = uuidGenerator();
298 #endif
299
300 string poolname;
301 int delayMsec=0;
302 /* we need this one to be accurate ("real") for the protobuf message */
303 struct timespec queryRealTime;
304 struct timespec now;
305 gettime(&now);
306 gettime(&queryRealTime, true);
307
308 if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, blockFilter, dq, poolname, &delayMsec, now)) {
309 goto drop;
310 }
311
312 if(dq.dh->qr) { // something turned it into a response
313 restoreFlags(dh, origFlags);
314 #ifdef HAVE_DNSCRYPT
315 if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
316 goto drop;
317 }
318 #endif
319 sendResponseToClient(ci.fd, query, dq.len);
320 g_stats.selfAnswered++;
321 goto drop;
322 }
323
324 std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
325 std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
326 {
327 std::lock_guard<std::mutex> lock(g_luamutex);
328 ds = localPolicy->policy(serverPool->servers, &dq);
329 packetCache = serverPool->packetCache;
330 }
331
332 if (dq.useECS && ds && ds->useECS) {
333 uint16_t newLen = dq.len;
334 handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength);
335 if (largerQuery.empty() == false) {
336 query = largerQuery.c_str();
337 dq.len = (uint16_t) largerQuery.size();
338 dq.size = largerQuery.size();
339 } else {
340 dq.len = newLen;
341 }
342 }
343
344 uint32_t cacheKey = 0;
345 if (packetCache && !dq.skipCache) {
346 char cachedResponse[4096];
347 uint16_t cachedResponseSize = sizeof cachedResponse;
348 uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
349 if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
350 #ifdef HAVE_DNSCRYPT
351 if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery)) {
352 goto drop;
353 }
354 #endif
355 sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
356 g_stats.cacheHits++;
357 goto drop;
358 }
359 g_stats.cacheMisses++;
360 }
361
362 if(!ds) {
363 g_stats.noPolicy++;
364
365 if (g_servFailOnNoPolicy) {
366 restoreFlags(dh, origFlags);
367 dq.dh->rcode = RCode::ServFail;
368 dq.dh->qr = true;
369
370 #ifdef HAVE_DNSCRYPT
371 if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
372 goto drop;
373 }
374 #endif
375 sendResponseToClient(ci.fd, query, dq.len);
376 }
377
378 break;
379 }
380
381 int dsock = -1;
382 if(sockets.count(ds->remote) == 0) {
383 dsock=setupTCPDownstream(ds);
384 sockets[ds->remote]=dsock;
385 }
386 else
387 dsock=sockets[ds->remote];
388
389 ds->queries++;
390 ds->outstanding++;
391 outstanding = true;
392
393 uint16_t downstream_failures=0;
394 retry:;
395 if (dsock < 0) {
396 sockets.erase(ds->remote);
397 break;
398 }
399
400 if (ds->retries > 0 && downstream_failures > ds->retries) {
401 vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstream_failures);
402 close(dsock);
403 dsock=-1;
404 sockets.erase(ds->remote);
405 break;
406 }
407
408 if(!sendNonBlockingMsgLen(dsock, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) {
409 vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
410 close(dsock);
411 dsock=-1;
412 sockets.erase(ds->remote);
413 dsock=setupTCPDownstream(ds);
414 sockets[ds->remote]=dsock;
415 downstream_failures++;
416 goto retry;
417 }
418
419 try {
420 if (ds->sourceItf == 0) {
421 writen2WithTimeout(dsock, query, dq.len, ds->tcpSendTimeout);
422 }
423 else {
424 sendMsgWithTimeout(dsock, query, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf);
425 }
426 }
427 catch(const runtime_error& e) {
428 vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
429 close(dsock);
430 dsock=-1;
431 sockets.erase(ds->remote);
432 dsock=setupTCPDownstream(ds);
433 sockets[ds->remote]=dsock;
434 downstream_failures++;
435 goto retry;
436 }
437
438 bool xfrStarted = false;
439 bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
440 if (isXFR) {
441 dq.skipCache = true;
442 }
443
444 getpacket:;
445
446 if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
447 vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
448 close(dsock);
449 dsock=-1;
450 sockets.erase(ds->remote);
451 dsock=setupTCPDownstream(ds);
452 sockets[ds->remote]=dsock;
453 downstream_failures++;
454 if(xfrStarted) {
455 goto drop;
456 }
457 goto retry;
458 }
459
460 size_t responseSize = rlen;
461 uint16_t addRoom = 0;
462 #ifdef HAVE_DNSCRYPT
463 if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
464 addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
465 }
466 #endif
467 responseSize += addRoom;
468 char answerbuffer[responseSize];
469 readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
470 char* response = answerbuffer;
471 uint16_t responseLen = rlen;
472 if (outstanding) {
473 /* might be false for {A,I}XFR */
474 --ds->outstanding;
475 outstanding = false;
476 }
477
478 if (rlen < sizeof(dnsheader)) {
479 break;
480 }
481
482 if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
483 break;
484 }
485
486 if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom)) {
487 break;
488 }
489
490 dh = (struct dnsheader*) response;
491 DNSResponse dr(&qname, qtype, qclass, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
492 #ifdef HAVE_PROTOBUF
493 dr.uniqueId = dq.uniqueId;
494 #endif
495 if (!processResponse(localRespRulactions, dr, &delayMsec)) {
496 break;
497 }
498
499 if (packetCache && !dq.skipCache) {
500 packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
501 }
502
503 #ifdef HAVE_DNSCRYPT
504 if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery)) {
505 goto drop;
506 }
507 #endif
508 if (!sendResponseToClient(ci.fd, response, responseLen)) {
509 break;
510 }
511
512 if (isXFR && dh->rcode == 0 && dh->ancount != 0) {
513 if (xfrStarted == false) {
514 xfrStarted = true;
515 if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
516 goto getpacket;
517 }
518 }
519 else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
520 goto getpacket;
521 }
522 }
523
524 g_stats.responses++;
525 struct timespec answertime;
526 gettime(&answertime);
527 unsigned int udiff = 1000000.0*DiffTime(now,answertime);
528 {
529 std::lock_guard<std::mutex> lock(g_rings.respMutex);
530 g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote});
531 }
532
533 largerQuery.clear();
534 rewrittenResponse.clear();
535 }
536 }
537 catch(...){}
538
539 drop:;
540
541 vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
542 close(ci.fd);
543 ci.fd=-1;
544 if (ds && outstanding) {
545 outstanding = false;
546 --ds->outstanding;
547 }
548 }
549 return 0;
550 }
551
552
553 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
554 they will hand off to worker threads & spawn more of them if required
555 */
556 void* tcpAcceptorThread(void* p)
557 {
558 ClientState* cs = (ClientState*) p;
559
560 ComboAddress remote;
561 remote.sin4.sin_family = cs->local.sin4.sin_family;
562
563 g_tcpclientthreads->addTCPClientThread();
564
565 auto acl = g_ACL.getLocal();
566 for(;;) {
567 bool queuedCounterIncremented = false;
568 ConnectionInfo* ci = nullptr;
569 try {
570 ci = new ConnectionInfo;
571 ci->cs = cs;
572 ci->fd = -1;
573 ci->fd = SAccept(cs->tcpFD, remote);
574
575 if(!acl->match(remote)) {
576 g_stats.aclDrops++;
577 close(ci->fd);
578 delete ci;
579 ci=nullptr;
580 vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
581 continue;
582 }
583
584 if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) {
585 close(ci->fd);
586 delete ci;
587 ci=nullptr;
588 vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
589 continue;
590 }
591
592 vinfolog("Got TCP connection from %s", remote.toStringWithPort());
593
594 ci->remote = remote;
595 int pipe = g_tcpclientthreads->getThread();
596 if (pipe >= 0) {
597 queuedCounterIncremented = true;
598 writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
599 }
600 else {
601 g_tcpclientthreads->decrementQueuedCount();
602 queuedCounterIncremented = false;
603 close(ci->fd);
604 delete ci;
605 ci=nullptr;
606 }
607 }
608 catch(std::exception& e) {
609 errlog("While reading a TCP question: %s", e.what());
610 if(ci && ci->fd >= 0)
611 close(ci->fd);
612 delete ci;
613 ci = nullptr;
614 if (queuedCounterIncremented) {
615 g_tcpclientthreads->decrementQueuedCount();
616 }
617 }
618 catch(...){}
619 }
620
621 return 0;
622 }
623
624
625 bool getMsgLen32(int fd, uint32_t* len)
626 try
627 {
628 uint32_t raw;
629 size_t ret = readn2(fd, &raw, sizeof raw);
630 if(ret != sizeof raw)
631 return false;
632 *len = ntohl(raw);
633 if(*len > 10000000) // arbitrary 10MB limit
634 return false;
635 return true;
636 }
637 catch(...) {
638 return false;
639 }
640
641 bool putMsgLen32(int fd, uint32_t len)
642 try
643 {
644 uint32_t raw = htonl(len);
645 size_t ret = writen2(fd, &raw, sizeof raw);
646 return ret==sizeof raw;
647 }
648 catch(...) {
649 return false;
650 }