/*
  PowerDNS Versatile Database Driven Nameserver
  Copyright (C) 2013 - 2015  PowerDNS.COM BV

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License version 2
  as published by the Free Software Foundation

  Additionally, the license of this program contains a special
  exception which allows to distribute the program in binary form when
  it is linked against OpenSSL.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/

#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsparser.hh"
#include "ednsoptions.hh"
#include "dolog.hh"
#include "lock.hh"
#include "gettime.hh"
#include <thread>
#include <atomic>

using std::thread;
using std::atomic;

/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are at most 65535 bytes.
   An answer might consist of multiple messages (for example, in the case of AXFR); apart from the
   XFR handling below, we do not go there.

   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
   This symmetry is broken by head-of-line blocking within TCP though, necessitating additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, to forward messages to and from them,
   and never to queue: whenever an answer comes in, we know where it needs to go.

   Let's start naively.
*/

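/* Opens a TCP socket to the given downstream server, optionally binding it to the
   configured source address first, and returns the connected, non-blocking descriptor. */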
static int setupTCPDownstream(shared_ptr<DownstreamState> ds)
{
  vinfolog("TCP connecting to downstream %s", ds->remote.toStringWithPort());
  int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
  if (!IsAnyAddress(ds->sourceAddr)) {
    SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
    SBind(sock, ds->sourceAddr);
  }
  SConnect(sock, ds->remote);
  setNonBlocking(sock);
  return sock;
}

struct ConnectionInfo
{
  int fd;
  ComboAddress remote;
  ClientState* cs;
};

uint64_t g_maxTCPQueuedConnections{0};
void* tcpClientThread(int pipefd);

// Should not be called simultaneously!
void TCPClientCollection::addTCPClientThread()
{
  if (d_numthreads >= d_tcpclientthreads.capacity()) {
    warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
    return;
  }

  vinfolog("Adding TCP Client thread");

  int pipefds[2] = { -1, -1};
  if(pipe(pipefds) < 0)
    unixDie("Creating pipe");

  if (!setNonBlocking(pipefds[1])) {
    close(pipefds[0]);
    close(pipefds[1]);
    unixDie("Setting pipe non-blocking");
  }

  d_tcpclientthreads.push_back(pipefds[1]);
  ++d_numthreads;
  thread t1(tcpClientThread, pipefds[0]);
  t1.detach();
}

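/* DNS over TCP frames every message with a two-byte length field in network byte order
   (RFC 1035 section 4.2.2). The helpers below read and write that length prefix on
   non-blocking sockets, with a timeout, and return false on any error. */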
static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
try
{
  uint16_t raw;
  size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
  if(ret != sizeof raw)
    return false;
  *len = ntohs(raw);
  return true;
}
catch(...) {
  return false;
}

static bool putNonBlockingMsgLen(int fd, uint16_t len, int timeout)
try
{
  uint16_t raw = htons(len);
  size_t ret = writen2WithTimeout(fd, &raw, sizeof raw, timeout);
  return ret == sizeof raw;
}
catch(...) {
  return false;
}

static bool sendNonBlockingMsgLen(int fd, uint16_t len, int timeout, ComboAddress& dest, ComboAddress& local, unsigned int localItf)
try
{
  if (localItf == 0)
    return putNonBlockingMsgLen(fd, len, timeout);

  uint16_t raw = htons(len);
  ssize_t ret = sendMsgWithTimeout(fd, (char*) &raw, sizeof raw, timeout, dest, local, localItf);
  return ret == sizeof raw;
}
catch(...) {
  return false;
}

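/* Writes the two-byte length prefix followed by the response payload back to the client,
   using g_tcpSendTimeout for both writes. */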
static bool sendResponseToClient(int fd, const char* response, uint16_t responseLen)
{
  if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
    return false;

  writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout);
  return true;
}

std::shared_ptr<TCPClientCollection> g_tcpclientthreads;

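/* Worker thread: receives pointers to ConnectionInfo structures over a pipe from the
   acceptor thread and then owns the client connection. It handles one client at a time,
   reading queries, running the rules, forwarding to a downstream over a per-thread map
   of cached TCP connections (keyed by downstream address), and relaying the answers. */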
void* tcpClientThread(int pipefd)
{
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
     from that point on */

  bool outstanding = false;
  blockfilter_t blockFilter = 0;

  {
    std::lock_guard<std::mutex> lock(g_luamutex);
    auto candidate = g_lua.readVariable<boost::optional<blockfilter_t> >("blockFilter");
    if(candidate)
      blockFilter = *candidate;
  }

  auto localPolicy = g_policy.getLocal();
  auto localRulactions = g_rulactions.getLocal();
  auto localRespRulactions = g_resprulactions.getLocal();
  auto localDynBlockNMG = g_dynblockNMG.getLocal();
  auto localDynBlockSMT = g_dynblockSMT.getLocal();
  auto localPools = g_pools.getLocal();
#ifdef HAVE_PROTOBUF
  boost::uuids::random_generator uuidGenerator;
#endif

  map<ComboAddress,int> sockets;
  for(;;) {
    ConnectionInfo* citmp, ci;

    try {
      readn2(pipefd, &citmp, sizeof(citmp));
    }
    catch(const std::runtime_error& e) {
      throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
    }

    --g_tcpclientthreads->d_queued;
    ci=*citmp;
    delete citmp;

    uint16_t qlen, rlen;
    string largerQuery;
    vector<uint8_t> rewrittenResponse;
    shared_ptr<DownstreamState> ds;
    if (!setNonBlocking(ci.fd))
      goto drop;

    try {
      for(;;) {
        ds = nullptr;
        outstanding = false;

        if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
          break;

        ci.cs->queries++;
        g_stats.queries++;

        if (qlen < sizeof(dnsheader)) {
          g_stats.nonCompliantQueries++;
          break;
        }

        bool ednsAdded = false;
        bool ecsAdded = false;
        /* if the query is small, allocate a bit more
           memory to be able to spoof the content,
           or to add ECS without allocating a new buffer */
        size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
        char queryBuffer[querySize];
        const char* query = queryBuffer;
        readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);

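        /* If this frontend has a DNSCrypt context, decrypt the query in place first.
           When decryption fails but handleDnsCryptQuery() produced a response of its own
           (for example a certificate response), that response is sent back and we stop
           processing this query. */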
#ifdef HAVE_DNSCRYPT
        std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;

        if (ci.cs->dnscryptCtx) {
          dnsCryptQuery = std::make_shared<DnsCryptQuery>();
          uint16_t decryptedQueryLen = 0;
          vector<uint8_t> response;
          bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);

          if (!decrypted) {
            if (response.size() > 0) {
              sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), (uint16_t) response.size());
            }
            break;
          }
          qlen = decryptedQueryLen;
        }
#endif
        struct dnsheader* dh = (struct dnsheader*) query;

        if(dh->qr) {   // don't respond to responses
          g_stats.nonCompliantQueries++;
          goto drop;
        }

        if(dh->qdcount == 0) {
          g_stats.emptyQueries++;
          goto drop;
        }

        if (dh->rd) {
          g_stats.rdQueries++;
        }

        const uint16_t* flags = getFlagsFromDNSHeader(dh);
        uint16_t origFlags = *flags;
        uint16_t qtype, qclass;
        unsigned int consumed = 0;
        DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
        DNSQuestion dq(&qname, qtype, qclass, &ci.cs->local, &ci.remote, (dnsheader*)query, querySize, qlen, true);
#ifdef HAVE_PROTOBUF
        dq.uniqueId = uuidGenerator();
#endif

        string poolname;
        int delayMsec=0;
        struct timespec now;
        gettime(&now, true);

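        /* Run the dynamic block lists, the Lua block filter and the rule/action chain.
           A rule can drop the query (we then give up on this connection) or turn it into
           an answer directly, which is handled just below. */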
        if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, blockFilter, dq, poolname, &delayMsec, now)) {
          goto drop;
        }

        if(dq.dh->qr) { // something turned it into a response
          restoreFlags(dh, origFlags);
#ifdef HAVE_DNSCRYPT
          if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
            goto drop;
          }
#endif
          sendResponseToClient(ci.fd, query, dq.len);
          g_stats.selfAnswered++;
          goto drop;
        }

        std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
        std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
        {
          std::lock_guard<std::mutex> lock(g_luamutex);
          ds = localPolicy->policy(serverPool->servers, &dq);
          packetCache = serverPool->packetCache;
        }

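        /* If the selected downstream wants EDNS Client Subnet, add or update the ECS
           option with the client's address. The rewritten query may not fit in the
           original buffer, in which case it ends up in largerQuery. */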
        if (ds && ds->useECS) {
          uint16_t newLen = dq.len;
          handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote);
          if (largerQuery.empty() == false) {
            query = largerQuery.c_str();
            dq.len = (uint16_t) largerQuery.size();
            dq.size = largerQuery.size();
          } else {
            dq.len = newLen;
          }
        }

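        /* Packet cache lookup: on a hit we send the cached answer right away (encrypting
           it first for DNSCrypt clients). Stale entries are only considered when no
           downstream server is available. */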
        uint32_t cacheKey = 0;
        if (packetCache && !dq.skipCache) {
          char cachedResponse[4096];
          uint16_t cachedResponseSize = sizeof cachedResponse;
          uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
          if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
#ifdef HAVE_DNSCRYPT
            if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery)) {
              goto drop;
            }
#endif
            sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
            g_stats.cacheHits++;
            goto drop;
          }
          g_stats.cacheMisses++;
        }

        if(!ds) {
          g_stats.noPolicy++;
          break;
        }

        int dsock = -1;
        if(sockets.count(ds->remote) == 0) {
          dsock=sockets[ds->remote]=setupTCPDownstream(ds);
        }
        else
          dsock=sockets[ds->remote];

        ds->queries++;
        ds->outstanding++;
        outstanding = true;

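        /* From here on, any failure to talk to the downstream tears down the cached
           connection, opens a fresh one and jumps back to 'retry', allowing up to
           ds->retries consecutive failures for this query. */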
        uint16_t downstream_failures=0;
      retry:;
        if (dsock < 0) {
          sockets.erase(ds->remote);
          break;
        }

        if (ds->retries > 0 && downstream_failures > ds->retries) {
          vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstream_failures);
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          break;
        }

        if(!sendNonBlockingMsgLen(dsock, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          goto retry;
        }

        try {
          if (ds->sourceItf == 0) {
            writen2WithTimeout(dsock, query, dq.len, ds->tcpSendTimeout);
          }
          else {
            sendMsgWithTimeout(dsock, query, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf);
          }
        }
        catch(const runtime_error& e) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          goto retry;
        }

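        /* AXFR/IXFR answers can span several DNS messages: we keep relaying messages
           (the 'getpacket' loop below) until we see the closing SOA record, and we never
           cache zone transfers. */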
        bool xfrStarted = false;
        bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
        if (isXFR) {
          dq.skipCache = true;
        }

      getpacket:;

        if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
          vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          if(xfrStarted) {
            goto drop;
          }
          goto retry;
        }

        size_t responseSize = rlen;
        uint16_t addRoom = 0;
#ifdef HAVE_DNSCRYPT
        if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
          addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
        }
#endif
        responseSize += addRoom;
        char answerbuffer[responseSize];
        readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
        char* response = answerbuffer;
        uint16_t responseLen = rlen;
        --ds->outstanding;
        outstanding = false;

        if (rlen < sizeof(dnsheader)) {
          break;
        }

        if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
          break;
        }

        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom)) {
          break;
        }

        dh = (struct dnsheader*) response;
        DNSResponse dr(&qname, qtype, qclass, &ci.cs->local, &ci.remote, dh, responseSize, responseLen, true, &now);
#ifdef HAVE_PROTOBUF
        dr.uniqueId = dq.uniqueId;
#endif
        if (!processResponse(localRespRulactions, dr)) {
          break;
        }

        if (packetCache && !dq.skipCache) {
          packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
        }

#ifdef HAVE_DNSCRYPT
        if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery)) {
          goto drop;
        }
#endif
        if (!sendResponseToClient(ci.fd, response, responseLen)) {
          break;
        }

        if (isXFR && dh->rcode == 0 && dh->ancount != 0) {
          if (xfrStarted == false) {
            xfrStarted = true;
            if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
              goto getpacket;
            }
          }
          else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
            goto getpacket;
          }
        }

        g_stats.responses++;
        struct timespec answertime;
        gettime(&answertime);
        unsigned int udiff = 1000000.0*DiffTime(now,answertime);
        {
          std::lock_guard<std::mutex> lock(g_rings.respMutex);
          g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dq.dh, ds->remote});
        }

        largerQuery.clear();
        rewrittenResponse.clear();
      }
    }
    catch(...){}

  drop:;

    vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
    close(ci.fd);
    ci.fd=-1;
    if (ds && outstanding) {
      outstanding = false;
      --ds->outstanding;
    }
  }
  return 0;
}


/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
   they will hand off to worker threads & spawn more of them if required
*/
void* tcpAcceptorThread(void* p)
{
  ClientState* cs = (ClientState*) p;

  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  g_tcpclientthreads->addTCPClientThread();

  auto acl = g_ACL.getLocal();
  for(;;) {
    ConnectionInfo* ci;
    try {
      ci=0;
      ci = new ConnectionInfo;
      ci->cs = cs;
      ci->fd = -1;
      ci->fd = SAccept(cs->tcpFD, remote);

      if(!acl->match(remote)) {
        g_stats.aclDrops++;
        close(ci->fd);
        delete ci;
        ci=0;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
      }

      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->d_queued >= g_maxTCPQueuedConnections) {
        close(ci->fd);
        delete ci;
        ci=nullptr;
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
        continue;
      }

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

      ci->remote = remote;
      int pipe = g_tcpclientthreads->getThread();
      if (pipe >= 0) {
        writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
      }
      else {
        --g_tcpclientthreads->d_queued;
        close(ci->fd);
        delete ci;
        ci=nullptr;
      }
    }
    catch(std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(ci && ci->fd >= 0)
        close(ci->fd);
      delete ci;
    }
    catch(...){}
  }

  return 0;
}


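/* 32-bit variants of the length helpers above: blocking, without a timeout, and with an
   arbitrary 10MB sanity limit on the announced length. They are not used by the DNS TCP
   path in this file; they are exported for other parts of dnsdist that exchange
   length-prefixed messages. */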
bool getMsgLen32(int fd, uint32_t* len)
try
{
  uint32_t raw;
  size_t ret = readn2(fd, &raw, sizeof raw);
  if(ret != sizeof raw)
    return false;
  *len = ntohl(raw);
  if(*len > 10000000) // arbitrary 10MB limit
    return false;
  return true;
}
catch(...) {
  return false;
}

bool putMsgLen32(int fd, uint32_t len)
try
{
  uint32_t raw = htonl(len);
  size_t ret = writen2(fd, &raw, sizeof raw);
  return ret==sizeof raw;
}
catch(...) {
  return false;
}