]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-tcp.cc
dnsdist: Fix destination address reporting
[thirdparty/pdns.git] / pdns / dnsdist-tcp.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsdist.hh"
23 #include "dnsdist-ecs.hh"
24 #include "dnsparser.hh"
25 #include "ednsoptions.hh"
26 #include "dolog.hh"
27 #include "lock.hh"
28 #include "gettime.hh"
29 #include <thread>
30 #include <atomic>
31
32 using std::thread;
33 using std::atomic;
34
35 /* TCP: the grand design.
36 We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
37 An answer might theoretically consist of multiple messages (for example, in the case of AXFR), initially
38 we will not go there.
39
40 In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been setup.
41 This symmetry is broken because of head-of-line blocking within TCP though, necessitating additional connections
42 to guarantee performance.
43
44 So the idea is to have a 'pool' of available downstream connections, and forward messages to/from them and never queue.
45 So whenever an answer comes in, we know where it needs to go.
46
47 Let's start naively.
48 */
49
50 static int setupTCPDownstream(shared_ptr<DownstreamState> ds)
51 {
52 vinfolog("TCP connecting to downstream %s", ds->remote.toStringWithPort());
53 int sock = SSocket(ds->remote.sin4.sin_family, SOCK_STREAM, 0);
54 if (!IsAnyAddress(ds->sourceAddr)) {
55 SSetsockopt(sock, SOL_SOCKET, SO_REUSEADDR, 1);
56 SBind(sock, ds->sourceAddr);
57 }
58 SConnect(sock, ds->remote);
59 setNonBlocking(sock);
60 return sock;
61 }
62
/* Everything a TCP worker thread needs to know about an accepted client
   connection. A heap-allocated instance is handed over the worker's pipe
   by tcpAcceptorThread(); the worker then owns the fd and the object. */
struct ConnectionInfo
{
  int fd;              // accepted client socket, owned by the worker after handoff
  ComboAddress remote; // client address and port
  ClientState* cs;     // listening-socket state this connection arrived on
};

/* Upper bound on connections queued towards the worker threads;
   0 disables the check (see tcpAcceptorThread()). */
uint64_t g_maxTCPQueuedConnections{1000};
void* tcpClientThread(int pipefd);
72
73 // Should not be called simultaneously!
74 void TCPClientCollection::addTCPClientThread()
75 {
76 if (d_numthreads >= d_tcpclientthreads.capacity()) {
77 warnlog("Adding a new TCP client thread would exceed the vector capacity (%d/%d), skipping", d_numthreads.load(), d_tcpclientthreads.capacity());
78 return;
79 }
80
81 vinfolog("Adding TCP Client thread");
82
83 int pipefds[2] = { -1, -1};
84 if(pipe(pipefds) < 0)
85 unixDie("Creating pipe");
86
87 if (!setNonBlocking(pipefds[1])) {
88 close(pipefds[0]);
89 close(pipefds[1]);
90 unixDie("Setting pipe non-blocking");
91 }
92
93 d_tcpclientthreads.push_back(pipefds[1]);
94 ++d_numthreads;
95 thread t1(tcpClientThread, pipefds[0]);
96 t1.detach();
97 }
98
99 static bool getNonBlockingMsgLen(int fd, uint16_t* len, int timeout)
100 try
101 {
102 uint16_t raw;
103 size_t ret = readn2WithTimeout(fd, &raw, sizeof raw, timeout);
104 if(ret != sizeof raw)
105 return false;
106 *len = ntohs(raw);
107 return true;
108 }
109 catch(...) {
110 return false;
111 }
112
113 static bool putNonBlockingMsgLen(int fd, uint16_t len, int timeout)
114 try
115 {
116 uint16_t raw = htons(len);
117 size_t ret = writen2WithTimeout(fd, &raw, sizeof raw, timeout);
118 return ret == sizeof raw;
119 }
120 catch(...) {
121 return false;
122 }
123
124 static bool sendNonBlockingMsgLen(int fd, uint16_t len, int timeout, ComboAddress& dest, ComboAddress& local, unsigned int localItf)
125 try
126 {
127 if (localItf == 0)
128 return putNonBlockingMsgLen(fd, len, timeout);
129
130 uint16_t raw = htons(len);
131 ssize_t ret = sendMsgWithTimeout(fd, (char*) &raw, sizeof raw, timeout, dest, local, localItf);
132 return ret == sizeof raw;
133 }
134 catch(...) {
135 return false;
136 }
137
138 static bool sendResponseToClient(int fd, const char* response, uint16_t responseLen)
139 {
140 if (!putNonBlockingMsgLen(fd, responseLen, g_tcpSendTimeout))
141 return false;
142
143 writen2WithTimeout(fd, response, responseLen, g_tcpSendTimeout);
144 return true;
145 }
146
147 std::shared_ptr<TCPClientCollection> g_tcpclientthreads;
148
/* TCP worker thread main loop: receives ownership of accepted client
   connections over 'pipefd' (raw ConnectionInfo pointers written by
   tcpAcceptorThread()), then serves any number of DNS queries on each
   connection, proxying them to a downstream server over a per-backend
   pooled TCP connection. Uses 'retry'/'getpacket'/'drop' labels as a
   small state machine for downstream reconnects, XFR streaming and
   connection teardown. */
void* tcpClientThread(int pipefd)
{
  /* we get launched with a pipe on which we receive file descriptors from clients that we own
     from that point on */

  bool outstanding = false;
  blockfilter_t blockFilter = 0;

  /* snapshot the Lua block filter once, under the global Lua lock */
  {
    std::lock_guard<std::mutex> lock(g_luamutex);
    auto candidate = g_lua.readVariable<boost::optional<blockfilter_t> >("blockFilter");
    if(candidate)
      blockFilter = *candidate;
  }

  /* thread-local views of the global configuration state */
  auto localPolicy = g_policy.getLocal();
  auto localRulactions = g_rulactions.getLocal();
  auto localRespRulactions = g_resprulactions.getLocal();
  auto localDynBlockNMG = g_dynblockNMG.getLocal();
  auto localDynBlockSMT = g_dynblockSMT.getLocal();
  auto localPools = g_pools.getLocal();
#ifdef HAVE_PROTOBUF
  boost::uuids::random_generator uuidGenerator;
#endif

  /* pool of established downstream TCP connections, one per backend address */
  map<ComboAddress,int> sockets;
  for(;;) {
    ConnectionInfo* citmp, ci;

    try {
      /* block until the acceptor hands us a new client connection */
      readn2(pipefd, &citmp, sizeof(citmp));
    }
    catch(const std::runtime_error& e) {
      throw std::runtime_error("Error reading from TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode: " + e.what());
    }

    --g_tcpclientthreads->d_queued;
    ci=*citmp;
    delete citmp;  // we now own the connection described by 'ci'

    uint16_t qlen, rlen;
    string largerQuery;
    vector<uint8_t> rewrittenResponse;
    shared_ptr<DownstreamState> ds;
    ComboAddress dest;
    memset(&dest, 0, sizeof(dest));
    dest.sin4.sin_family = ci.remote.sin4.sin_family;
    socklen_t len = dest.getSocklen();
    if (!setNonBlocking(ci.fd))
      goto drop;

    /* determine the local address the client actually connected to;
       fall back to the listening socket's address on failure */
    if (getsockname(ci.fd, (sockaddr*)&dest, &len)) {
      dest = ci.cs->local;
    }

    try {
      /* serve any number of queries over this one client connection */
      for(;;) {
        ds = nullptr;
        outstanding = false;

        if(!getNonBlockingMsgLen(ci.fd, &qlen, g_tcpRecvTimeout))
          break;

        ci.cs->queries++;
        g_stats.queries++;

        if (qlen < sizeof(dnsheader)) {
          g_stats.nonCompliantQueries++;
          break;
        }

        bool ednsAdded = false;
        bool ecsAdded = false;
        /* if the query is small, allocate a bit more
           memory to be able to spoof the content,
           or to add ECS without allocating a new buffer */
        size_t querySize = qlen <= 4096 ? qlen + 512 : qlen;
        char queryBuffer[querySize];
        const char* query = queryBuffer;
        readn2WithTimeout(ci.fd, queryBuffer, qlen, g_tcpRecvTimeout);

#ifdef HAVE_DNSCRYPT
        std::shared_ptr<DnsCryptQuery> dnsCryptQuery = 0;

        if (ci.cs->dnscryptCtx) {
          dnsCryptQuery = std::make_shared<DnsCryptQuery>();
          uint16_t decryptedQueryLen = 0;
          vector<uint8_t> response;
          bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, queryBuffer, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);

          if (!decrypted) {
            /* not a valid DNSCrypt query; a certificate response may
               have been generated for us to send back */
            if (response.size() > 0) {
              sendResponseToClient(ci.fd, reinterpret_cast<char*>(response.data()), (uint16_t) response.size());
            }
            break;
          }
          qlen = decryptedQueryLen;
        }
#endif
        struct dnsheader* dh = (struct dnsheader*) query;

        if(dh->qr) { // don't respond to responses
          g_stats.nonCompliantQueries++;
          goto drop;
        }

        if(dh->qdcount == 0) {
          g_stats.emptyQueries++;
          goto drop;
        }

        if (dh->rd) {
          g_stats.rdQueries++;
        }

        const uint16_t* flags = getFlagsFromDNSHeader(dh);
        uint16_t origFlags = *flags;
        uint16_t qtype, qclass;
        unsigned int consumed = 0;
        DNSName qname(query, qlen, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
        DNSQuestion dq(&qname, qtype, qclass, &dest, &ci.remote, (dnsheader*)query, querySize, qlen, true);
#ifdef HAVE_PROTOBUF
        dq.uniqueId = uuidGenerator();
#endif

        string poolname;
        int delayMsec=0;
        /* we need this one to be accurate ("real") for the protobuf message */
        struct timespec queryRealTime;
        struct timespec now;
        gettime(&now);
        gettime(&queryRealTime, true);

        /* run the rule chain; a false return means the query was dropped */
        if (!processQuery(localDynBlockNMG, localDynBlockSMT, localRulactions, blockFilter, dq, poolname, &delayMsec, now)) {
          goto drop;
        }

        if(dq.dh->qr) { // something turned it into a response
          restoreFlags(dh, origFlags);
#ifdef HAVE_DNSCRYPT
          if (!encryptResponse(queryBuffer, &dq.len, dq.size, true, dnsCryptQuery)) {
            goto drop;
          }
#endif
          sendResponseToClient(ci.fd, query, dq.len);
          g_stats.selfAnswered++;
          goto drop;
        }

        /* pick a backend from the selected pool, under the Lua lock
           since the policy may be a Lua function */
        std::shared_ptr<ServerPool> serverPool = getPool(*localPools, poolname);
        std::shared_ptr<DNSDistPacketCache> packetCache = nullptr;
        {
          std::lock_guard<std::mutex> lock(g_luamutex);
          ds = localPolicy->policy(serverPool->servers, &dq);
          packetCache = serverPool->packetCache;
        }

        if (dq.useECS && ds && ds->useECS) {
          uint16_t newLen = dq.len;
          handleEDNSClientSubnet(queryBuffer, dq.size, consumed, &newLen, largerQuery, &ednsAdded, &ecsAdded, ci.remote, dq.ecsOverride, dq.ecsPrefixLength);
          if (largerQuery.empty() == false) {
            /* the EDNS Client Subnet option did not fit into the
               original buffer, use the enlarged copy instead */
            query = largerQuery.c_str();
            dq.len = (uint16_t) largerQuery.size();
            dq.size = largerQuery.size();
          } else {
            dq.len = newLen;
          }
        }

        uint32_t cacheKey = 0;
        if (packetCache && !dq.skipCache) {
          char cachedResponse[4096];
          uint16_t cachedResponseSize = sizeof cachedResponse;
          /* with no backend available we are allowed to serve stale entries */
          uint32_t allowExpired = ds ? 0 : g_staleCacheEntriesTTL;
          if (packetCache->get(dq, (uint16_t) consumed, dq.dh->id, cachedResponse, &cachedResponseSize, &cacheKey, allowExpired)) {
#ifdef HAVE_DNSCRYPT
            if (!encryptResponse(cachedResponse, &cachedResponseSize, sizeof cachedResponse, true, dnsCryptQuery)) {
              goto drop;
            }
#endif
            sendResponseToClient(ci.fd, cachedResponse, cachedResponseSize);
            g_stats.cacheHits++;
            goto drop;
          }
          g_stats.cacheMisses++;
        }

        if(!ds) {
          g_stats.noPolicy++;
          break;
        }

        /* reuse an existing downstream connection if we have one pooled */
        int dsock = -1;
        if(sockets.count(ds->remote) == 0) {
          dsock=sockets[ds->remote]=setupTCPDownstream(ds);
        }
        else
          dsock=sockets[ds->remote];

        ds->queries++;
        ds->outstanding++;
        outstanding = true;

        uint16_t downstream_failures=0;
      retry:;
        if (dsock < 0) {
          sockets.erase(ds->remote);
          break;
        }

        if (ds->retries > 0 && downstream_failures > ds->retries) {
          vinfolog("Downstream connection to %s failed %d times in a row, giving up.", ds->getName(), downstream_failures);
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          break;
        }

        /* send the length prefix; on failure replace the pooled
           connection and retry */
        if(!sendNonBlockingMsgLen(dsock, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf)) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          goto retry;
        }

        try {
          if (ds->sourceItf == 0) {
            writen2WithTimeout(dsock, query, dq.len, ds->tcpSendTimeout);
          }
          else {
            sendMsgWithTimeout(dsock, query, dq.len, ds->tcpSendTimeout, ds->remote, ds->sourceAddr, ds->sourceItf);
          }
        }
        catch(const runtime_error& e) {
          vinfolog("Downstream connection to %s died on us, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          goto retry;
        }

        bool xfrStarted = false;
        bool isXFR = (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR);
        if (isXFR) {
          dq.skipCache = true;
        }

      getpacket:;

        if(!getNonBlockingMsgLen(dsock, &rlen, ds->tcpRecvTimeout)) {
          vinfolog("Downstream connection to %s died on us phase 2, getting a new one!", ds->getName());
          close(dsock);
          dsock=-1;
          sockets.erase(ds->remote);
          sockets[ds->remote]=dsock=setupTCPDownstream(ds);
          downstream_failures++;
          /* we cannot retransmit an XFR once the transfer has started */
          if(xfrStarted) {
            goto drop;
          }
          goto retry;
        }

        size_t responseSize = rlen;
        uint16_t addRoom = 0;
#ifdef HAVE_DNSCRYPT
        if (dnsCryptQuery && (UINT16_MAX - rlen) > (uint16_t) DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE) {
          addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
        }
#endif
        responseSize += addRoom;
        char answerbuffer[responseSize];
        readn2WithTimeout(dsock, answerbuffer, rlen, ds->tcpRecvTimeout);
        char* response = answerbuffer;
        uint16_t responseLen = rlen;
        if (outstanding) {
          /* might be false for {A,I}XFR */
          --ds->outstanding;
          outstanding = false;
        }

        if (rlen < sizeof(dnsheader)) {
          break;
        }

        if (!responseContentMatches(response, responseLen, qname, qtype, qclass, ds->remote)) {
          break;
        }

        if (!fixUpResponse(&response, &responseLen, &responseSize, qname, origFlags, ednsAdded, ecsAdded, rewrittenResponse, addRoom)) {
          break;
        }

        dh = (struct dnsheader*) response;
        DNSResponse dr(&qname, qtype, qclass, &dest, &ci.remote, dh, responseSize, responseLen, true, &queryRealTime);
#ifdef HAVE_PROTOBUF
        dr.uniqueId = dq.uniqueId;
#endif
        if (!processResponse(localRespRulactions, dr, &delayMsec)) {
          break;
        }

        if (packetCache && !dq.skipCache) {
          packetCache->insert(cacheKey, qname, qtype, qclass, response, responseLen, true, dh->rcode == RCode::ServFail);
        }

#ifdef HAVE_DNSCRYPT
        if (!encryptResponse(response, &responseLen, responseSize, true, dnsCryptQuery)) {
          goto drop;
        }
#endif
        if (!sendResponseToClient(ci.fd, response, responseLen)) {
          break;
        }

        /* for XFR answers, keep streaming messages until the closing
           SOA record is seen */
        if (isXFR && dh->rcode == 0 && dh->ancount != 0) {
          if (xfrStarted == false) {
            xfrStarted = true;
            if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 1) {
              goto getpacket;
            }
          }
          else if (getRecordsOfTypeCount(response, responseLen, 1, QType::SOA) == 0) {
            goto getpacket;
          }
        }

        g_stats.responses++;
        struct timespec answertime;
        gettime(&answertime);
        unsigned int udiff = 1000000.0*DiffTime(now,answertime);
        /* record the exchange in the in-memory response ring buffer */
        {
          std::lock_guard<std::mutex> lock(g_rings.respMutex);
          g_rings.respRing.push_back({answertime, ci.remote, qname, dq.qtype, (unsigned int)udiff, (unsigned int)responseLen, *dh, ds->remote});
        }

        largerQuery.clear();
        rewrittenResponse.clear();
      }
    }
    catch(...){}

  drop:;

    /* tear down the client connection and fix up the outstanding
       counter if we still owe the backend a decrement */
    vinfolog("Closing TCP client connection with %s", ci.remote.toStringWithPort());
    close(ci.fd);
    ci.fd=-1;
    if (ds && outstanding) {
      outstanding = false;
      --ds->outstanding;
    }
  }
  return 0;
}
507
508
509 /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and
510 they will hand off to worker threads & spawn more of them if required
511 */
/* TCP acceptor loop for one listening socket: accepts connections,
   applies the ACL and queue limits, then hands ownership of a
   heap-allocated ConnectionInfo (and its fd) to a worker thread by
   writing the raw pointer over the worker's pipe. */
void* tcpAcceptorThread(void* p)
{
  ClientState* cs = (ClientState*) p;

  ComboAddress remote;
  remote.sin4.sin_family = cs->local.sin4.sin_family;

  /* make sure at least one worker exists to consume our handoffs */
  g_tcpclientthreads->addTCPClientThread();

  auto acl = g_ACL.getLocal();
  for(;;) {
    ConnectionInfo* ci;
    try {
      ci=0;
      ci = new ConnectionInfo;
      ci->cs = cs;
      ci->fd = -1;
      ci->fd = SAccept(cs->tcpFD, remote);

      /* drop connections from addresses the ACL does not allow */
      if(!acl->match(remote)) {
        g_stats.aclDrops++;
        close(ci->fd);
        delete ci;
        ci=0;
        vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort());
        continue;
      }

      /* refuse new connections when the worker queue is already full */
      if(g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->d_queued >= g_maxTCPQueuedConnections) {
        close(ci->fd);
        delete ci;
        ci=nullptr;
        vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort());
        continue;
      }

      vinfolog("Got TCP connection from %s", remote.toStringWithPort());

      ci->remote = remote;
      int pipe = g_tcpclientthreads->getThread();
      if (pipe >= 0) {
        /* ownership of 'ci' (and its fd) transfers to the worker here */
        writen2WithTimeout(pipe, &ci, sizeof(ci), 0);
      }
      else {
        /* no worker available: undo the queued count and drop */
        --g_tcpclientthreads->d_queued;
        close(ci->fd);
        delete ci;
        ci=nullptr;
      }
    }
    catch(std::exception& e) {
      errlog("While reading a TCP question: %s", e.what());
      if(ci && ci->fd >= 0)
        close(ci->fd);
      delete ci;
    }
    catch(...){}
  }

  return 0;
}
573
574
575 bool getMsgLen32(int fd, uint32_t* len)
576 try
577 {
578 uint32_t raw;
579 size_t ret = readn2(fd, &raw, sizeof raw);
580 if(ret != sizeof raw)
581 return false;
582 *len = ntohl(raw);
583 if(*len > 10000000) // arbitrary 10MB limit
584 return false;
585 return true;
586 }
587 catch(...) {
588 return false;
589 }
590
591 bool putMsgLen32(int fd, uint32_t len)
592 try
593 {
594 uint32_t raw = htonl(len);
595 size_t ret = writen2(fd, &raw, sizeof raw);
596 return ret==sizeof raw;
597 }
598 catch(...) {
599 return false;
600 }