/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsparser.hh"
#include "ednsoptions.hh"
#include "dolog.hh"
#include "lock.hh"
#include "gettime.hh"
#include <thread>
#include <atomic>

using std::thread;
using std::atomic;

/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. A message is at most 65535 bytes.
   An answer might consist of multiple messages (for example in the case of AXFR), but initially
   we will not go there.

   In a sense there is a strong symmetry between UDP and TCP once a connection to a downstream has been set up.
   This symmetry is broken by head-of-line blocking within TCP, though, which necessitates additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, forward messages to and from them,
   and never queue, so that whenever an answer comes in we know where it needs to go.

   Let's start naively.
*/

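/* Open a TCP connection to a downstream backend: bind to the configured source address
   if one is set, connect to the backend, then switch the socket to non-blocking mode. */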
static int setupTCPDownstream(shared_ptr<DownstreamState> ds)
{
  vinfolog("TCP connecting to downstream %s", ds->remote.toStringWithPort());
  int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
  if (!IsAnyAddress(ds->sourceAddr)) {
    SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
    SBind(sock, ds->sourceAddr);
  }
  SConnect(sock, ds->remote);
  setNonBlocking(sock);
  return sock;
}

struct ConnectionInfo
{
  int fd;
  ComboAddress remote;
  ClientState* cs;
};

uint64_t g_maxTCPQueuedConnections{1000};
void* tcpClientThread(int pipefd);

// Should not be called simultaneously from multiple threads!
void TCPClientCollection::addTCPClientThread()
{
  if (d_numthreads >= d_tcpclientthreads.capacity()) {
    warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
    return;
  }

  vinfolog("Adding TCP Client thread");

  int pipefds[2] = { -1, -1};
  if(pipe(pipefds) < 0)
    unixDie("Creating pipe");

  if (!setNonBlocking(pipefds[1])) {
    close(pipefds[0]);
    close(pipefds[1]);
    unixDie("Setting pipe non-blocking");
  }

  d_tcpclientthreads.push_back(pipefds[1]);
  ++d_numthreads;
  thread t1(tcpClientThread, pipefds[0]);
  t1.detach();
}

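/* DNS over TCP frames every message with a 16-bit length prefix in network byte order
   (RFC 1035, section 4.2.2); these helpers read and write that prefix with a timeout. */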
static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
try
{
  uint16_t raw;
  size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
  if(ret != sizeof raw)
    return false;
  *len = ntohs(raw);
  return true;
}
catch(...) {
  return false;
}

static bool putNonBlockingMsgLen(int fd, uint16_t len, int timeout)
try
{
  uint16_t raw = htons(len);
  size_t ret = writen2WithTimeout(fd, &raw, sizeof raw, timeout);
  return ret == sizeof raw;
}
catch(...) {
  return false;
}

static bool sendNonBlockingMsgLen(int fd, uint16_t len, int timeout, ComboAddress& dest, ComboAddress& local, unsigned int localItf)
try
{
  if (localItf == 0)
    return putNonBlockingMsgLen(fd, len, timeout);

  uint16_t raw = htons(len);
  ssize_t ret = sendMsgWithTimeout(fd, (char*) &raw, sizeof raw, timeout, dest, local, localItf);
  return ret == sizeof raw;
}
catch(...) {
  return false;
}

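/* Send a response back to the client: write the 16-bit length prefix, then the payload.
   Returns false if the length prefix could not be written; a failed payload write throws. */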
static bool sendResponseToClient(int fd, const char* response, uint16_t responseLen)
{
  if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
    return false;

  writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout);
  return true;
}

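/* The collection of TCP worker threads; each worker is reachable through the write end
   of the pipe registered by addTCPClientThread(). */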
std::shared_ptr<TCPClientCollection> g_tcpclientthreads;

void* tcpClientThread(int pipefd)
{
  /* we get launched with a pipe on which we receive pointers to ConnectionInfo objects,
     whose file descriptors we own from that point on */

  bool outstanding = false;
  blockfilter_t blockFilter = 0;

  {
    std::lock_guard<std::mutex> lock(g_luamutex);
    auto candidate = g_lua.readVariable<boost::optional<blockfilter_t> >("blockFilter");
    if(candidate)
      blockFilter = *candidate;
  }

  auto localPolicy = g_policy.getLocal();
  auto localRulactions = g_rulactions.getLocal();
  auto localRespRulactions = g_resprulactions.getLocal();
  auto localDynBlockNMG = g_dynblockNMG.getLocal();
  auto localDynBlockSMT = g_dynblockSMT.getLocal();
  auto localPools = g_pools.getLocal();
#ifdef HAVE_PROTOBUF
  boost::uuids::random_generator uuidGenerator;
#endif

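  /* per-thread cache of established downstream TCP connections, keyed by backend address,
     so connections to the backends can be reused across client queries */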
  map<ComboAddress,int> sockets;
  for(;;) {
    ConnectionInfo* citmp{nullptr};
    ConnectionInfo ci;

    try {
      readn2(pipefd, &citmp, sizeof(citmp));
    }
    catch(const std::runtime_error& e) {
      throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
    }

    --g_tcpclientthreads->d_queued;
    ci=*citmp;
    delete citmp;

    uint16_t qlen, rlen;
    string largerQuery;
    vector<uint8_t> rewrittenResponse;
    shared_ptr<DownstreamState> ds;
    if (!setNonBlocking(ci.fd))
      goto drop;

    try {
      for(;;) {
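        /* handle one query per iteration: the client connection is kept open so that
           follow-up queries can be read until the client closes it or an error occurs */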
        ds = nullptr;
        outstanding = false;

        if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
          break;

        ci.cs->queries++;
        g_stats.queries++;

        if (qlen < sizeof(dnsheader)) {
          g_stats.nonCompliantQueries++;
          break;
        }

        bool ednsAdded = false;
        bool ecsAdded = false;
        /* if the query is small, allocate a bit more
           memory to be able to spoof the content,
           or to add ECS without allocating a new buffer */
        size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
        char queryBuffer[querySize];
        const char* query = queryBuffer;
        readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);

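        /* when this frontend is a DNSCrypt listener, decrypt the query in place first; if
           decryption fails, handleDnsCryptQuery() may have prepared a response (for example a
           certificate answer) that we send back before giving up on this query */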
#ifdef HAVE_DNSCRYPT
        std::shared_ptr<DnsCryptQuery> dnsCryptQuery = nullptr;

        if (ci.cs->dnscryptCtx) {
          dnsCryptQuery = std::make_shared<DnsCryptQuery>();
          uint16_t decryptedQueryLen = 0;
          vector<uint8_t> response;
          bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);

          if (!decrypted) {
            if (response.size() > 0) {
              sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), (uint16_t) response.size());
            }
            break;
          }
          qlen = decryptedQueryLen;
        }
#endif
        struct dnsheader* dh = (struct dnsheader*) query;

        if(dh->qr) { // don't respond to responses
          g_stats.nonCompliantQueries++;
          goto drop;
        }

        if(dh->qdcount == 0) {
          g_stats.emptyQueries++;
          goto drop;
        }

        if (dh->rd) {
          g_stats.rdQueries++;
        }

        const uint16_t* flags = getFlagsFromDNSHeader(dh);
        uint16_t origFlags = *flags;
        uint16_t qtype, qclass;
        unsigned int consumed = 0;
        DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
        DNSQuestion dq(&qname, qtype, qclass, &ci.cs->local, &ci.remote, (dnsheader*)query, querySize, qlen, true);
#ifdef HAVE_PROTOBUF
        dq.uniqueId = uuidGenerator();
#endif

        string poolname;
        int delayMsec=0;
        /* we need this one to be accurate ("real") for the protobuf message */
        struct timespec queryRealTime;
        struct timespec now;
        gettime(&now);
        gettime(&queryRealTime, true);

        if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, blockFilter, dq, poolname, &delayMsec, now)) {
          goto drop;
        }

        if(dq.dh->qr) { // something turned it into a response
          restoreFlags(dh, origFlags);
#ifdef HAVE_DNSCRYPT
          if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
            goto drop;
          }
#endif
          sendResponseToClient(ci.fd, query, dq.len);
          g_stats.selfAnswered++;
          goto drop;
        }

        std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
        std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
        {
          std::lock_guard<std::mutex> lock(g_luamutex);
          ds = localPolicy->policy(serverPool->servers, &dq);
          packetCache = serverPool->packetCache;
        }

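        /* if both this query and the selected backend use EDNS Client Subnet, add or override
           the ECS option; this can require a larger buffer, in which case the query is copied
           into largerQuery */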
        if (dq.useECS && ds && ds->useECS) {
          uint16_t newLen = dq.len;
          handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength);
          if (largerQuery.empty() == false) {
            query = largerQuery.c_str();
            dq.len = (uint16_t) largerQuery.size();
            dq.size = largerQuery.size();
          } else {
            dq.len = newLen;
          }
        }

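        /* check the pool's packet cache before contacting a backend; when no backend is
           available, a stale (expired) entry may still be served, controlled by
           g_staleCacheEntriesTTL */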
        uint32_t cacheKey = 0;
        if (packetCache && !dq.skipCache) {
          char cachedResponse[4096];
          uint16_t cachedResponseSize = sizeof cachedResponse;
          uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
          if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
#ifdef HAVE_DNSCRYPT
            if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery)) {
              goto drop;
            }
#endif
            sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
            g_stats.cacheHits++;
            goto drop;
          }
          g_stats.cacheMisses++;
        }

        if(!ds) {
          g_stats.noPolicy++;
          break;
        }

        int dsock = -1;
        if(sockets.count(ds->remote) == 0) {
          dsock=sockets[ds->remote]=setupTCPDownstream(ds);
        }
        else
          dsock=sockets[ds->remote];

        ds->queries++;
        ds->outstanding++;
        outstanding = true;

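        /* forward the query to the backend; if the downstream connection turns out to be dead,
           re-establish it and retry, giving up after ds->retries consecutive failures */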
        uint16_t downstream_failures=0;
      retry:;
        if (dsock < 0) {
          sockets.erase(ds->remote);
          break;
        }

        if (ds->retries > 0 && downstream_failures > ds->retries) {
          vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstream_failures);
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          break;
        }

        if(!sendNonBlockingMsgLen(dsock, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          goto retry;
        }

        try {
          if (ds->sourceItf == 0) {
            writen2WithTimeout(dsock, query, dq.len, ds->tcpSendTimeout);
          }
          else {
            sendMsgWithTimeout(dsock, query, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf);
          }
        }
        catch(const runtime_error& e) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          goto retry;
        }

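        /* AXFR/IXFR answers can span multiple DNS messages: keep reading from the backend until
           the transfer ends with a closing SOA record, and never cache such answers */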
        bool xfrStarted = false;
        bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
        if (isXFR) {
          dq.skipCache = true;
        }

      getpacket:;

        if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
          vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          if(xfrStarted) {
            goto drop;
          }
          goto retry;
        }

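        /* reserve extra room in the answer buffer so that the response can later be
           encrypted in place when the client connection uses DNSCrypt */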
        size_t responseSize = rlen;
        uint16_t addRoom = 0;
#ifdef HAVE_DNSCRYPT
        if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
          addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
        }
#endif
        responseSize += addRoom;
        char answerbuffer[responseSize];
        readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
        char* response = answerbuffer;
        uint16_t responseLen = rlen;
        if (outstanding) {
          /* might be false for {A,I}XFR */
          --ds->outstanding;
          outstanding = false;
        }

        if (rlen < sizeof(dnsheader)) {
          break;
        }

        if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
          break;
        }

        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom)) {
          break;
        }

        dh = (struct dnsheader*) response;
        DNSResponse dr(&qname, qtype, qclass, &ci.cs->local, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
#ifdef HAVE_PROTOBUF
        dr.uniqueId = dq.uniqueId;
#endif
        if (!processResponse(localRespRulactions, dr, &delayMsec)) {
          break;
        }

        if (packetCache && !dq.skipCache) {
          packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
        }

#ifdef HAVE_DNSCRYPT
        if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery)) {
          goto drop;
        }
#endif
        if (!sendResponseToClient(ci.fd, response, responseLen)) {
          break;
        }

        if (isXFR && dh->rcode == 0 && dh->ancount != 0) {
          if (xfrStarted == false) {
            xfrStarted = true;
            if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
              goto getpacket;
            }
          }
          else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
            goto getpacket;
          }
        }

        g_stats.responses++;
        struct timespec answertime;
        gettime(&answertime);
        unsigned int udiff = 1000000.0*DiffTime(now,answertime);
        {
          std::lock_guard<std::mutex> lock(g_rings.respMutex);
          g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote});
        }

        largerQuery.clear();
        rewrittenResponse.clear();
      }
    }
    catch(...){}

  drop:;

    vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
    close(ci.fd);
    ci.fd=-1;
    if (ds && outstanding) {
      outstanding = false;
      --ds->outstanding;
    }
  }
  return 0;
}


/* Spawn as many of these as required; each one calls accept() on a listening socket,
   hands accepted connections off to the TCP worker threads, and spawns more workers
   if required. */
void* tcpAcceptorThread(void* p)
{
  ClientState* cs = (ClientState*) p;

  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  g_tcpclientthreads->addTCPClientThread();

  auto acl = g_ACL.getLocal();
  for(;;) {
    ConnectionInfo* ci;
    try {
      ci = nullptr;
      ci = new ConnectionInfo;
      ci->cs = cs;
      ci->fd = -1;
      ci->fd = SAccept(cs->tcpFD, remote);

      if(!acl->match(remote)) {
        g_stats.aclDrops++;
        close(ci->fd);
        delete ci;
        ci = nullptr;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
      }

      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->d_queued >= g_maxTCPQueuedConnections) {
        close(ci->fd);
        delete ci;
        ci = nullptr;
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
        continue;
      }

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

      ci->remote = remote;
      int pipe = g_tcpclientthreads->getThread();
      if (pipe >= 0) {
        writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
      }
      else {
        --g_tcpclientthreads->d_queued;
        close(ci->fd);
        delete ci;
        ci = nullptr;
      }
    }
    catch(std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(ci && ci->fd >= 0)
        close(ci->fd);
      delete ci;
    }
    catch(...){}
  }

  return 0;
}


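/* Read and write a message length as a 32-bit prefix in network byte order, for messages
   that can exceed 64 KB; lengths above an arbitrary 10 MB limit are rejected. */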
bool getMsgLen32(int fd, uint32_t* len)
try
{
  uint32_t raw;
  size_t ret = readn2(fd, &raw, sizeof raw);
  if(ret != sizeof raw)
    return false;
  *len = ntohl(raw);
  if(*len > 10000000) // arbitrary 10MB limit
    return false;
  return true;
}
catch(...) {
  return false;
}

bool putMsgLen32(int fd, uint32_t len)
try
{
  uint32_t raw = htonl(len);
  size_t ret = writen2(fd, &raw, sizeof raw);
  return ret==sizeof raw;
}
catch(...) {
  return false;
}