// pdns/dnsdist-tcp.cc
/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
#include "dnsdist.hh"
#include "dnsdist-ecs.hh"
#include "dnsparser.hh"
#include "ednsoptions.hh"
#include "dolog.hh"
#include "lock.hh"
#include "gettime.hh"
#include <thread>
#include <atomic>

using std::thread;
using std::atomic;

/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
   An answer might theoretically consist of multiple messages (for example, in the case of AXFR);
   initially we will not go there.

   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
   This symmetry is broken by head-of-line blocking within TCP, though, necessitating additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, forward messages to/from them and never queue.
   That way, whenever an answer comes in, we know where it needs to go.

   Let's start naively.
*/

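/* Open a TCP connection to a downstream server: bind to the configured source
   address if there is one, connect, and only then make the socket non-blocking
   so that subsequent reads and writes go through the timeout helpers below. */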
static int setupTCPDownstream(shared_ptr<DownstreamState> ds)
{
  vinfolog("TCP connecting to downstream %s", ds->remote.toStringWithPort());
  int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
  if (!IsAnyAddress(ds->sourceAddr)) {
    SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
    SBind(sock, ds->sourceAddr);
  }
  SConnect(sock, ds->remote);
  setNonBlocking(sock);
  return sock;
}

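/* Describes an accepted client connection: the client socket, the client's
   address and the listening frontend. The acceptor allocates one of these and
   passes a pointer to it over a worker pipe; the worker owns and deletes it. */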
struct ConnectionInfo
{
  int fd;
  ComboAddress remote;
  ClientState* cs;
};

uint64_t g_maxTCPQueuedConnections{0};
void* tcpClientThread(int pipefd);

// Should not be called simultaneously!
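/* Each worker thread gets its own pipe: the non-blocking write end is stored in
   d_tcpclientthreads so the acceptor can pass it ConnectionInfo pointers, and the
   read end is handed to the freshly spawned, detached worker thread. */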
void TCPClientCollection::addTCPClientThread()
{
  if (d_numthreads >= d_tcpclientthreads.capacity()) {
    warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
    return;
  }

  vinfolog("Adding TCP Client thread");

  int pipefds[2] = { -1, -1};
  if(pipe(pipefds) < 0)
    unixDie("Creating pipe");

  if (!setNonBlocking(pipefds[1])) {
    close(pipefds[0]);
    close(pipefds[1]);
    unixDie("Setting pipe non-blocking");
  }

  d_tcpclientthreads.push_back(pipefds[1]);
  ++d_numthreads;
  thread t1(tcpClientThread, pipefds[0]);
  t1.detach();
}

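/* DNS over TCP prefixes every message with a two-byte length field in network
   byte order (RFC 1035, section 4.2.2). These helpers read and write that
   prefix with a timeout, returning false on error or timeout. */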
static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
try
{
  uint16_t raw;
  size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
  if(ret != sizeof raw)
    return false;
  *len = ntohs(raw);
  return true;
}
catch(...) {
  return false;
}

static bool putNonBlockingMsgLen(int fd, uint16_t len, int timeout)
try
{
  uint16_t raw = htons(len);
  size_t ret = writen2WithTimeout(fd, &raw, sizeof raw, timeout);
  return ret == sizeof raw;
}
catch(...) {
  return false;
}

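/* Same as putNonBlockingMsgLen(), except that when a specific outgoing
   interface is configured we have to go through sendMsgWithTimeout() so the
   length prefix leaves via that interface as well. */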
static bool sendNonBlockingMsgLen(int fd, uint16_t len, int timeout, ComboAddress& dest, ComboAddress& local, unsigned int localItf)
try
{
  if (localItf == 0)
    return putNonBlockingMsgLen(fd, len, timeout);

  uint16_t raw = htons(len);
  ssize_t ret = sendMsgWithTimeout(fd, (char*) &raw, sizeof raw, timeout, dest, local, localItf);
  return ret == sizeof raw;
}
catch(...) {
  return false;
}

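/* Send a response back to the client: the two-byte length prefix first, then
   the payload itself, both bounded by the TCP send timeout. */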
static bool sendResponseToClient(int fd, const char* response, uint16_t responseLen)
{
  if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
    return false;

  writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout);
  return true;
}

std::shared_ptr<TCPClientCollection> g_tcpclientthreads;

void* tcpClientThread(int pipefd)
{
  /* we get launched with a pipe on which we receive pointers to ConnectionInfo
     objects describing client connections that we own from that point on */

  bool outstanding = false;
  blockfilter_t blockFilter = 0;

  {
    std::lock_guard<std::mutex> lock(g_luamutex);
    auto candidate = g_lua.readVariable<boost::optional<blockfilter_t> >("blockFilter");
    if(candidate)
      blockFilter = *candidate;
  }

  auto localPolicy = g_policy.getLocal();
  auto localRulactions = g_rulactions.getLocal();
  auto localRespRulactions = g_resprulactions.getLocal();
  auto localDynBlockNMG = g_dynblockNMG.getLocal();
  auto localDynBlockSMT = g_dynblockSMT.getLocal();
  auto localPools = g_pools.getLocal();
#ifdef HAVE_PROTOBUF
  boost::uuids::random_generator uuidGenerator;
#endif

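  /* Per-thread cache of TCP connections to downstream servers, keyed by backend
     address, so that each worker reuses a single connection per backend. */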
  map<ComboAddress,int> sockets;
  for(;;) {
    ConnectionInfo* citmp, ci;

    try {
      readn2(pipefd, &citmp, sizeof(citmp));
    }
    catch(const std::runtime_error& e) {
      throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
    }

    --g_tcpclientthreads->d_queued;
    ci=*citmp;
    delete citmp;

    uint16_t qlen, rlen;
    string largerQuery;
    vector<uint8_t> rewrittenResponse;
    shared_ptr<DownstreamState> ds;
    if (!setNonBlocking(ci.fd))
      goto drop;

    try {
      for(;;) {
        ds = nullptr;
        outstanding = false;

        if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
          break;

        ci.cs->queries++;
        g_stats.queries++;

        if (qlen < sizeof(dnsheader)) {
          g_stats.nonCompliantQueries++;
          break;
        }

        bool ednsAdded = false;
        bool ecsAdded = false;
        /* if the query is small, allocate a bit more
           memory to be able to spoof the content,
           or to add ECS without allocating a new buffer */
        size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
        char queryBuffer[querySize];
        const char* query = queryBuffer;
        readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);

#ifdef HAVE_DNSCRYPT
        std::shared_ptr<DnsCryptQuery> dnsCryptQuery = nullptr;

        if (ci.cs->dnscryptCtx) {
          dnsCryptQuery = std::make_shared<DnsCryptQuery>();
          uint16_t decryptedQueryLen = 0;
          vector<uint8_t> response;
          bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);

          if (!decrypted) {
            if (response.size() > 0) {
              sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), (uint16_t) response.size());
            }
            break;
          }
          qlen = decryptedQueryLen;
        }
#endif
        struct dnsheader* dh = (struct dnsheader*) query;

        if(dh->qr) { // don't respond to responses
          g_stats.nonCompliantQueries++;
          goto drop;
        }

        if(dh->qdcount == 0) {
          g_stats.emptyQueries++;
          goto drop;
        }

        if (dh->rd) {
          g_stats.rdQueries++;
        }

        const uint16_t* flags = getFlagsFromDNSHeader(dh);
        uint16_t origFlags = *flags;
        uint16_t qtype, qclass;
        unsigned int consumed = 0;
        DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
        DNSQuestion dq(&qname, qtype, qclass, &ci.cs->local, &ci.remote, (dnsheader*)query, querySize, qlen, true);
#ifdef HAVE_PROTOBUF
        dq.uniqueId = uuidGenerator();
#endif

        string poolname;
        int delayMsec=0;
        struct timespec now;
        gettime(&now, true);

        if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, blockFilter, dq, poolname, &delayMsec, now)) {
          goto drop;
        }

        if(dq.dh->qr) { // something turned it into a response
          restoreFlags(dh, origFlags);
#ifdef HAVE_DNSCRYPT
          if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
            goto drop;
          }
#endif
          sendResponseToClient(ci.fd, query, dq.len);
          g_stats.selfAnswered++;
          goto drop;
        }

        std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
        std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
        {
          std::lock_guard<std::mutex> lock(g_luamutex);
          ds = localPolicy->policy(serverPool->servers, &dq);
          packetCache = serverPool->packetCache;
        }

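        /* Add or override EDNS Client Subnet when both the query (dq.useECS) and
           the selected backend (ds->useECS) want it. If the rewritten query no
           longer fits in the existing buffer, handleEDNSClientSubnet() returns it
           in largerQuery and we switch to that copy. */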
        if (dq.useECS && ds && ds->useECS) {
          uint16_t newLen = dq.len;
          handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength);
          if (largerQuery.empty() == false) {
            query = largerQuery.c_str();
            dq.len = (uint16_t) largerQuery.size();
            dq.size = largerQuery.size();
          } else {
            dq.len = newLen;
          }
        }

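        /* Check the pool's packet cache. On a hit the cached answer is (optionally
           dnscrypt-encrypted and) sent straight back to the client; on a miss the
           cache key is kept so the eventual response can be stored under it. Stale
           entries are only served when no backend is available. */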
        uint32_t cacheKey = 0;
        if (packetCache && !dq.skipCache) {
          char cachedResponse[4096];
          uint16_t cachedResponseSize = sizeof cachedResponse;
          uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
          if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
#ifdef HAVE_DNSCRYPT
            if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery)) {
              goto drop;
            }
#endif
            sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
            g_stats.cacheHits++;
            goto drop;
          }
          g_stats.cacheMisses++;
        }

        if(!ds) {
          g_stats.noPolicy++;
          break;
        }

        int dsock = -1;
        if(sockets.count(ds->remote) == 0) {
          dsock=sockets[ds->remote]=setupTCPDownstream(ds);
        }
        else
          dsock=sockets[ds->remote];

        ds->queries++;
        ds->outstanding++;
        outstanding = true;

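        /* Send the query to the downstream server. If the cached connection turns
           out to be dead, tear it down, open a fresh one and retry, giving up after
           ds->retries consecutive failures. */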
        uint16_t downstream_failures=0;
      retry:;
        if (dsock < 0) {
          sockets.erase(ds->remote);
          break;
        }

        if (ds->retries > 0 && downstream_failures > ds->retries) {
          vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstream_failures);
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          break;
        }

        if(!sendNonBlockingMsgLen(dsock, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          goto retry;
        }

        try {
          if (ds->sourceItf == 0) {
            writen2WithTimeout(dsock, query, dq.len, ds->tcpSendTimeout);
          }
          else {
            sendMsgWithTimeout(dsock, query, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf);
          }
        }
        catch(const runtime_error& e) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          goto retry;
        }

        bool xfrStarted = false;
        bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
        if (isXFR) {
          dq.skipCache = true;
        }

      getpacket:;

        if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
          vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          if(xfrStarted) {
            goto drop;
          }
          goto retry;
        }

        size_t responseSize = rlen;
        uint16_t addRoom = 0;
#ifdef HAVE_DNSCRYPT
        if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
          addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
        }
#endif
        responseSize += addRoom;
        char answerbuffer[responseSize];
        readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
        char* response = answerbuffer;
        uint16_t responseLen = rlen;
        if (outstanding) {
          /* might be false for {A,I}XFR */
          --ds->outstanding;
          outstanding = false;
        }

        if (rlen < sizeof(dnsheader)) {
          break;
        }

        if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
          break;
        }

        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom)) {
          break;
        }

        dh = (struct dnsheader*) response;
        DNSResponse dr(&qname, qtype, qclass, &ci.cs->local, &ci.remote, dh, responseSize, responseLen, true, &now);
#ifdef HAVE_PROTOBUF
        dr.uniqueId = dq.uniqueId;
#endif
        if (!processResponse(localRespRulactions, dr, &delayMsec)) {
          break;
        }

        if (packetCache && !dq.skipCache) {
          packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
        }

#ifdef HAVE_DNSCRYPT
        if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery)) {
          goto drop;
        }
#endif
        if (!sendResponseToClient(ci.fd, response, responseLen)) {
          break;
        }

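        /* AXFR/IXFR answers can span several DNS messages: keep reading from the
           downstream and forwarding to the client until the transfer is closed by
           the final SOA record. */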
        if (isXFR && dh->rcode == 0 && dh->ancount != 0) {
          if (xfrStarted == false) {
            xfrStarted = true;
            if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
              goto getpacket;
            }
          }
          else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
            goto getpacket;
          }
        }

        g_stats.responses++;
        struct timespec answertime;
        gettime(&answertime);
        unsigned int udiff = 1000000.0*DiffTime(now,answertime);
        {
          std::lock_guard<std::mutex> lock(g_rings.respMutex);
          g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dq.dh, ds->remote});
        }

        largerQuery.clear();
        rewrittenResponse.clear();
      }
    }
    catch(...){}

  drop:;

    vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
    close(ci.fd);
    ci.fd=-1;
    if (ds && outstanding) {
      outstanding = false;
      --ds->outstanding;
    }
  }
  return 0;
}


/* spawn as many of these as required; they call Accept on a socket on which they will accept queries,
   and they will hand off to worker threads & spawn more of them if required
*/
void* tcpAcceptorThread(void* p)
{
  ClientState* cs = (ClientState*) p;

  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  g_tcpclientthreads->addTCPClientThread();

  auto acl = g_ACL.getLocal();
  for(;;) {
    ConnectionInfo* ci;
    try {
      ci=nullptr;
      ci = new ConnectionInfo;
      ci->cs = cs;
      ci->fd = -1;
      ci->fd = SAccept(cs->tcpFD, remote);

      if(!acl->match(remote)) {
        g_stats.aclDrops++;
        close(ci->fd);
        delete ci;
        ci=nullptr;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
      }

      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->d_queued >= g_maxTCPQueuedConnections) {
        close(ci->fd);
        delete ci;
        ci=nullptr;
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
        continue;
      }

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

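      /* Hand the connection off to a worker thread by writing the ConnectionInfo
         pointer to that thread's pipe; the worker takes ownership and deletes it.
         If no worker pipe is available, drop the connection. */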
      ci->remote = remote;
      int pipe = g_tcpclientthreads->getThread();
      if (pipe >= 0) {
        writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
      }
      else {
        --g_tcpclientthreads->d_queued;
        close(ci->fd);
        delete ci;
        ci=nullptr;
      }
    }
    catch(std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(ci && ci->fd >= 0)
        close(ci->fd);
      delete ci;
    }
    catch(...){}
  }

  return 0;
}

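/* 32-bit length-prefix helpers (network byte order, with an arbitrary 10 MB
   sanity cap). These are not used for the DNS-over-TCP framing above, which is
   16-bit; callers elsewhere in dnsdist use them for larger length-prefixed
   messages. */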
bool getMsgLen32(int fd, uint32_t* len)
try
{
  uint32_t raw;
  size_t ret = readn2(fd, &raw, sizeof raw);
  if(ret != sizeof raw)
    return false;
  *len = ntohl(raw);
  if(*len > 10000000) // arbitrary 10MB limit
    return false;
  return true;
}
catch(...) {
  return false;
}

bool putMsgLen32(int fd, uint32_t len)
try
{
  uint32_t raw = htonl(len);
  size_t ret = writen2(fd, &raw, sizeof raw);
  return ret==sizeof raw;
}
catch(...) {
  return false;
}