]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-tcp.cc
Merge pull request #4356 from rgacogne/auth-nocachelookup-tsig
[thirdparty/pdns.git] / pdns / dnsdist-tcp.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsdist.hh"
23 #include "dnsdist-ecs.hh"
24 #include "dnsparser.hh"
25 #include "ednsoptions.hh"
26 #include "dolog.hh"
27 #include "lock.hh"
28 #include "gettime.hh"
29 #include <thread>
30 #include <atomic>
31
32 using std::thread;
33 using std::atomic;
34
35 /* TCP: the grand design.
36 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
37 An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
38 we will not go there.
39
40 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
41 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
42 to guarantee performance.
43
44 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
45 So whenever an answer comes in, we know where it needs to go.
46
47 Let's start naively.
48 */
49
50 static int setupTCPDownstream(shared_ptr<DownstreamState> ds)
51 {
52 vinfolog("TCP connecting to downstream %s", ds->remote.toStringWithPort());
53 int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
54 if (!IsAnyAddress(ds->sourceAddr)) {
55 SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
56 SBind(sock, ds->sourceAddr);
57 }
58 SConnect(sock, ds->remote);
59 setNonBlocking(sock);
60 return sock;
61 }
62
63 struct ConnectionInfo
64 {
65 int fd;
66 ComboAddress remote;
67 ClientState* cs;
68 };
69
/* Upper bound on the number of connections waiting for a worker thread;
   the acceptor only enforces it when it is greater than zero. */
uint64_t g_maxTCPQueuedConnections = 0;

/* Worker thread entry point, defined further down in this file. */
void* tcpClientThread(int pipefd);
72
73 // Should not be called simultaneously!
74 void TCPClientCollection::addTCPClientThread()
75 {
76 if (d_numthreads >= d_tcpclientthreads.capacity()) {
77 warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
78 return;
79 }
80
81 vinfolog("Adding TCP Client thread");
82
83 int pipefds[2] = { -1, -1};
84 if(pipe(pipefds) < 0)
85 unixDie("Creating pipe");
86
87 if (!setNonBlocking(pipefds[1])) {
88 close(pipefds[0]);
89 close(pipefds[1]);
90 unixDie("Setting pipe non-blocking");
91 }
92
93 d_tcpclientthreads.push_back(pipefds[1]);
94 ++d_numthreads;
95 thread t1(tcpClientThread, pipefds[0]);
96 t1.detach();
97 }
98
99 static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
100 try
101 {
102 uint16_t raw;
103 size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
104 if(ret != sizeof raw)
105 return false;
106 *len = ntohs(raw);
107 return true;
108 }
109 catch(...) {
110 return false;
111 }
112
113 static bool putNonBlockingMsgLen(int fd, uint16_t len, int timeout)
114 try
115 {
116 uint16_t raw = htons(len);
117 size_t ret = writen2WithTimeout(fd, &raw, sizeof raw, timeout);
118 return ret == sizeof raw;
119 }
120 catch(...) {
121 return false;
122 }
123
124 static bool sendNonBlockingMsgLen(int fd, uint16_t len, int timeout, ComboAddress& dest, ComboAddress& local, unsigned int localItf)
125 try
126 {
127 if (localItf == 0)
128 return putNonBlockingMsgLen(fd, len, timeout);
129
130 uint16_t raw = htons(len);
131 ssize_t ret = sendMsgWithTimeout(fd, (char*) &raw, sizeof raw, timeout, dest, local, localItf);
132 return ret == sizeof raw;
133 }
134 catch(...) {
135 return false;
136 }
137
138 static bool sendResponseToClient(int fd, const char* response, uint16_t responseLen)
139 {
140 if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
141 return false;
142
143 writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout);
144 return true;
145 }
146
147 std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
148
149 void* tcpClientThread(int pipefd)
150 {
151 /* we get launched with a pipe on which we receive file descriptors from clients that we own
152 from that point on */
153
154 bool outstanding = false;
155 blockfilter_t blockFilter = 0;
156
157 {
158 std::lock_guard<std::mutex> lock(g_luamutex);
159 auto candidate = g_lua.readVariable<boost::optional<blockfilter_t> >("blockFilter");
160 if(candidate)
161 blockFilter = *candidate;
162 }
163
164 auto localPolicy = g_policy.getLocal();
165 auto localRulactions = g_rulactions.getLocal();
166 auto localRespRulactions = g_resprulactions.getLocal();
167 auto localDynBlockNMG = g_dynblockNMG.getLocal();
168 auto localDynBlockSMT = g_dynblockSMT.getLocal();
169 auto localPools = g_pools.getLocal();
170 #ifdef HAVE_PROTOBUF
171 boost::uuids::random_generator uuidGenerator;
172 #endif
173
174 map<ComboAddress,int> sockets;
175 for(;;) {
176 ConnectionInfo* citmp, ci;
177
178 try {
179 readn2(pipefd, &citmp, sizeof(citmp));
180 }
181 catch(const std::runtime_error& e) {
182 throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
183 }
184
185 --g_tcpclientthreads->d_queued;
186 ci=*citmp;
187 delete citmp;
188
189 uint16_t qlen, rlen;
190 string largerQuery;
191 vector<uint8_t> rewrittenResponse;
192 shared_ptr<DownstreamState> ds;
193 if (!setNonBlocking(ci.fd))
194 goto drop;
195
196 try {
197 for(;;) {
198 ds = nullptr;
199 outstanding = false;
200
201 if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
202 break;
203
204 ci.cs->queries++;
205 g_stats.queries++;
206
207 if (qlen < sizeof(dnsheader)) {
208 g_stats.nonCompliantQueries++;
209 break;
210 }
211
212 bool ednsAdded = false;
213 bool ecsAdded = false;
214 /* if the query is small, allocate a bit more
215 memory to be able to spoof the content,
216 or to add ECS without allocating a new buffer */
217 size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
218 char queryBuffer[querySize];
219 const char* query = queryBuffer;
220 readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);
221
222 #ifdef HAVE_DNSCRYPT
223 std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;
224
225 if (ci.cs->dnscryptCtx) {
226 dnsCryptQuery = std::make_shared<DnsCryptQuery>();
227 uint16_t decryptedQueryLen = 0;
228 vector<uint8_t> response;
229 bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);
230
231 if (!decrypted) {
232 if (response.size() > 0) {
233 sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), (uint16_t) response.size());
234 }
235 break;
236 }
237 qlen = decryptedQueryLen;
238 }
239 #endif
240 struct dnsheader* dh = (struct dnsheader*) query;
241
242 if(dh->qr) { // don't respond to responses
243 g_stats.nonCompliantQueries++;
244 goto drop;
245 }
246
247 if(dh->qdcount == 0) {
248 g_stats.emptyQueries++;
249 goto drop;
250 }
251
252 if (dh->rd) {
253 g_stats.rdQueries++;
254 }
255
256 const uint16_t* flags = getFlagsFromDNSHeader(dh);
257 uint16_t origFlags = *flags;
258 uint16_t qtype, qclass;
259 unsigned int consumed = 0;
260 DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
261 DNSQuestion dq(&qname, qtype, qclass, &ci.cs->local, &ci.remote, (dnsheader*)query, querySize, qlen, true);
262 #ifdef HAVE_PROTOBUF
263 dq.uniqueId = uuidGenerator();
264 #endif
265
266 string poolname;
267 int delayMsec=0;
268 struct timespec now;
269 gettime(&now, true);
270
271 if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, blockFilter, dq, poolname, &delayMsec, now)) {
272 goto drop;
273 }
274
275 if(dq.dh->qr) { // something turned it into a response
276 restoreFlags(dh, origFlags);
277 #ifdef HAVE_DNSCRYPT
278 if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
279 goto drop;
280 }
281 #endif
282 sendResponseToClient(ci.fd, query, dq.len);
283 g_stats.selfAnswered++;
284 goto drop;
285 }
286
287 std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
288 std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
289 {
290 std::lock_guard<std::mutex> lock(g_luamutex);
291 ds = localPolicy->policy(serverPool->servers, &dq);
292 packetCache = serverPool->packetCache;
293 }
294
295 if (ds && ds->useECS) {
296 uint16_t newLen = dq.len;
297 handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote);
298 if (largerQuery.empty() == false) {
299 query = largerQuery.c_str();
300 dq.len = (uint16_t) largerQuery.size();
301 dq.size = largerQuery.size();
302 } else {
303 dq.len = newLen;
304 }
305 }
306
307 uint32_t cacheKey = 0;
308 if (packetCache && !dq.skipCache) {
309 char cachedResponse[4096];
310 uint16_t cachedResponseSize = sizeof cachedResponse;
311 uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
312 if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
313 #ifdef HAVE_DNSCRYPT
314 if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery)) {
315 goto drop;
316 }
317 #endif
318 sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
319 g_stats.cacheHits++;
320 goto drop;
321 }
322 g_stats.cacheMisses++;
323 }
324
325 if(!ds) {
326 g_stats.noPolicy++;
327 break;
328 }
329
330 int dsock = -1;
331 if(sockets.count(ds->remote) == 0) {
332 dsock=sockets[ds->remote]=setupTCPDownstream(ds);
333 }
334 else
335 dsock=sockets[ds->remote];
336
337 ds->queries++;
338 ds->outstanding++;
339 outstanding = true;
340
341 uint16_t downstream_failures=0;
342 retry:;
343 if (dsock < 0) {
344 sockets.erase(ds->remote);
345 break;
346 }
347
348 if (ds->retries > 0 && downstream_failures > ds->retries) {
349 vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstream_failures);
350 close(dsock);
351 dsock=-1;
352 sockets.erase(ds->remote);
353 break;
354 }
355
356 if(!sendNonBlockingMsgLen(dsock, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) {
357 vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
358 close(dsock);
359 dsock=-1;
360 sockets.erase(ds->remote);
361 sockets[ds->remote]=dsock=setupTCPDownstream(ds);
362 downstream_failures++;
363 goto retry;
364 }
365
366 try {
367 if (ds->sourceItf == 0) {
368 writen2WithTimeout(dsock, query, dq.len, ds->tcpSendTimeout);
369 }
370 else {
371 sendMsgWithTimeout(dsock, query, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf);
372 }
373 }
374 catch(const runtime_error& e) {
375 vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
376 close(dsock);
377 dsock=-1;
378 sockets.erase(ds->remote);
379 sockets[ds->remote]=dsock=setupTCPDownstream(ds);
380 downstream_failures++;
381 goto retry;
382 }
383
384 bool xfrStarted = false;
385 bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
386 if (isXFR) {
387 dq.skipCache = true;
388 }
389
390 getpacket:;
391
392 if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
393 vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
394 close(dsock);
395 dsock=-1;
396 sockets.erase(ds->remote);
397 sockets[ds->remote]=dsock=setupTCPDownstream(ds);
398 downstream_failures++;
399 if(xfrStarted) {
400 goto drop;
401 }
402 goto retry;
403 }
404
405 size_t responseSize = rlen;
406 uint16_t addRoom = 0;
407 #ifdef HAVE_DNSCRYPT
408 if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
409 addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
410 }
411 #endif
412 responseSize += addRoom;
413 char answerbuffer[responseSize];
414 readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
415 char* response = answerbuffer;
416 uint16_t responseLen = rlen;
417 --ds->outstanding;
418 outstanding = false;
419
420 if (rlen < sizeof(dnsheader)) {
421 break;
422 }
423
424 if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
425 break;
426 }
427
428 if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom)) {
429 break;
430 }
431
432 dh = (struct dnsheader*) response;
433 DNSResponse dr(&qname, qtype, qclass, &ci.cs->local, &ci.remote, dh, responseSize, responseLen, true, &now);
434 #ifdef HAVE_PROTOBUF
435 dr.uniqueId = dq.uniqueId;
436 #endif
437 if (!processResponse(localRespRulactions, dr, &delayMsec)) {
438 break;
439 }
440
441 if (packetCache && !dq.skipCache) {
442 packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
443 }
444
445 #ifdef HAVE_DNSCRYPT
446 if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery)) {
447 goto drop;
448 }
449 #endif
450 if (!sendResponseToClient(ci.fd, response, responseLen)) {
451 break;
452 }
453
454 if (isXFR && dh->rcode == 0 && dh->ancount != 0) {
455 if (xfrStarted == false) {
456 xfrStarted = true;
457 if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
458 goto getpacket;
459 }
460 }
461 else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
462 goto getpacket;
463 }
464 }
465
466 g_stats.responses++;
467 struct timespec answertime;
468 gettime(&answertime);
469 unsigned int udiff = 1000000.0*DiffTime(now,answertime);
470 {
471 std::lock_guard<std::mutex> lock(g_rings.respMutex);
472 g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dq.dh, ds->remote});
473 }
474
475 largerQuery.clear();
476 rewrittenResponse.clear();
477 }
478 }
479 catch(...){}
480
481 drop:;
482
483 vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
484 close(ci.fd);
485 ci.fd=-1;
486 if (ds && outstanding) {
487 outstanding = false;
488 --ds->outstanding;
489 }
490 }
491 return 0;
492 }
493
494
495 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
496 they will hand off to worker threads & spawn more of them if required
497 */
498 void* tcpAcceptorThread(void* p)
499 {
500 ClientState* cs = (ClientState*) p;
501
502 ComboAddress remote;
503 remote.sin4.sin_family = cs->local.sin4.sin_family;
504
505 g_tcpclientthreads->addTCPClientThread();
506
507 auto acl = g_ACL.getLocal();
508 for(;;) {
509 ConnectionInfo* ci;
510 try {
511 ci=0;
512 ci = new ConnectionInfo;
513 ci->cs = cs;
514 ci->fd = -1;
515 ci->fd = SAccept(cs->tcpFD, remote);
516
517 if(!acl->match(remote)) {
518 g_stats.aclDrops++;
519 close(ci->fd);
520 delete ci;
521 ci=0;
522 vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
523 continue;
524 }
525
526 if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->d_queued >= g_maxTCPQueuedConnections) {
527 close(ci->fd);
528 delete ci;
529 ci=nullptr;
530 vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
531 continue;
532 }
533
534 vinfolog("Got TCP connection from %s", remote.toStringWithPort());
535
536 ci->remote = remote;
537 int pipe = g_tcpclientthreads->getThread();
538 if (pipe >= 0) {
539 writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
540 }
541 else {
542 --g_tcpclientthreads->d_queued;
543 close(ci->fd);
544 delete ci;
545 ci=nullptr;
546 }
547 }
548 catch(std::exception& e) {
549 errlog("While reading a TCP question: %s", e.what());
550 if(ci && ci->fd >= 0)
551 close(ci->fd);
552 delete ci;
553 }
554 catch(...){}
555 }
556
557 return 0;
558 }
559
560
561 bool getMsgLen32(int fd, uint32_t* len)
562 try
563 {
564 uint32_t raw;
565 size_t ret = readn2(fd, &raw, sizeof raw);
566 if(ret != sizeof raw)
567 return false;
568 *len = ntohl(raw);
569 if(*len > 10000000) // arbitrary 10MB limit
570 return false;
571 return true;
572 }
573 catch(...) {
574 return false;
575 }
576
577 bool putMsgLen32(int fd, uint32_t len)
578 try
579 {
580 uint32_t raw = htonl(len);
581 size_t ret = writen2(fd, &raw, sizeof raw);
582 return ret==sizeof raw;
583 }
584 catch(...) {
585 return false;
586 }