]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist.cc
Merge pull request #14021 from Habbie/auth-lua-join-whitespace
[thirdparty/pdns.git] / pdns / dnsdist.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22
23 #include "config.h"
24
25 #include <fstream>
26 #include <getopt.h>
27 #include <grp.h>
28 #include <limits>
29 #include <netinet/tcp.h>
30 #include <pwd.h>
31 #include <sys/resource.h>
32 #include <unistd.h>
33
34 #if defined (__OpenBSD__) || defined(__NetBSD__)
35 #include <readline/readline.h>
36 #else
37 #include <editline/readline.h>
38 #endif
39
40 #ifdef HAVE_SYSTEMD
41 #include <systemd/sd-daemon.h>
42 #endif
43
44 #include "dnsdist.hh"
45 #include "dnsdist-cache.hh"
46 #include "dnsdist-console.hh"
47 #include "dnsdist-ecs.hh"
48 #include "dnsdist-lua.hh"
49 #include "dnsdist-rings.hh"
50 #include "dnsdist-secpoll.hh"
51 #include "dnsdist-xpf.hh"
52
53 #include "base64.hh"
54 #include "delaypipe.hh"
55 #include "dolog.hh"
56 #include "dnsname.hh"
57 #include "dnsparser.hh"
58 #include "dnswriter.hh"
59 #include "ednsoptions.hh"
60 #include "gettime.hh"
61 #include "lock.hh"
62 #include "misc.hh"
63 #include "sodcrypto.hh"
64 #include "sstuff.hh"
65 #include "threadname.hh"
66
67 /* Known sins:
68
69 Receiver is currently single threaded
70 not *that* bad actually, but now that we are thread safe, might want to scale
71 */
72
73 /* the Rulaction plan
74 Set of Rules, if one matches, it leads to an Action
75 Both rules and actions could conceivably be Lua based.
76 On the C++ side, both could be inherited from a class Rule and a class Action,
77 on the Lua side we can't do that. */
78
79 using std::atomic;
80 using std::thread;
81 bool g_verbose;
82
83 struct DNSDistStats g_stats;
84 MetricDefinitionStorage g_metricDefinitions;
85
86 uint16_t g_maxOutstanding{std::numeric_limits<uint16_t>::max()};
87 bool g_verboseHealthChecks{false};
88 uint32_t g_staleCacheEntriesTTL{0};
89 bool g_syslog{true};
90 bool g_allowEmptyResponse{false};
91
92 GlobalStateHolder<NetmaskGroup> g_ACL;
93 string g_outputBuffer;
94
95 std::vector<std::shared_ptr<TLSFrontend>> g_tlslocals;
96 std::vector<std::shared_ptr<DOHFrontend>> g_dohlocals;
97 std::vector<std::shared_ptr<DNSCryptContext>> g_dnsCryptLocals;
98 #ifdef HAVE_EBPF
99 shared_ptr<BPFFilter> g_defaultBPFFilter;
100 std::vector<std::shared_ptr<DynBPFFilter> > g_dynBPFFilters;
101 #endif /* HAVE_EBPF */
102 std::vector<std::unique_ptr<ClientState>> g_frontends;
103 GlobalStateHolder<pools_t> g_pools;
104 size_t g_udpVectorSize{1};
105
106 bool g_snmpEnabled{false};
107 bool g_snmpTrapsEnabled{false};
108 DNSDistSNMPAgent* g_snmpAgent{nullptr};
109
110 /* UDP: the grand design. Per socket we listen on for incoming queries there is one thread.
111 Then we have a bunch of connected sockets for talking to downstream servers.
112 We send directly to those sockets.
113
114 For the return path, per downstream server we have a thread that listens to responses.
115
116 Per socket there is an array of 2^16 states, when we send out a packet downstream, we note
117 there the original requestor and the original id. The new ID is the offset in the array.
118
119 When an answer comes in on a socket, we look up the offset by the id, and lob it to the
120 original requestor.
121
122 IDs are assigned by atomic increments of the socket offset.
123 */
124
125 GlobalStateHolder<vector<DNSDistRuleAction> > g_rulactions;
126 GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_resprulactions;
127 GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_cachehitresprulactions;
128 GlobalStateHolder<vector<DNSDistResponseRuleAction> > g_selfansweredresprulactions;
129
130 Rings g_rings;
131 QueryCount g_qcount;
132
133 GlobalStateHolder<servers_t> g_dstates;
134 GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
135 GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
136 DNSAction::Action g_dynBlockAction = DNSAction::Action::Drop;
137 int g_tcpRecvTimeout{2};
138 int g_tcpSendTimeout{2};
139 int g_udpTimeout{2};
140
141 bool g_servFailOnNoPolicy{false};
142 bool g_truncateTC{false};
143 bool g_fixupCase{false};
144 bool g_preserveTrailingData{false};
145 bool g_roundrobinFailOnNoServer{false};
146
147 static void truncateTC(char* packet, uint16_t* len, size_t responseSize, unsigned int consumed)
148 try
149 {
150 bool hadEDNS = false;
151 uint16_t payloadSize = 0;
152 uint16_t z = 0;
153
154 if (g_addEDNSToSelfGeneratedResponses) {
155 hadEDNS = getEDNSUDPPayloadSizeAndZ(packet, *len, &payloadSize, &z);
156 }
157
158 *len=static_cast<uint16_t>(sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE);
159 struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
160 dh->ancount = dh->arcount = dh->nscount = 0;
161
162 if (hadEDNS) {
163 addEDNS(dh, *len, responseSize, z & EDNS_HEADER_FLAG_DO, payloadSize, 0);
164 }
165 }
166 catch(...)
167 {
168 g_stats.truncFail++;
169 }
170
171 struct DelayedPacket
172 {
173 int fd;
174 string packet;
175 ComboAddress destination;
176 ComboAddress origDest;
177 void operator()()
178 {
179 ssize_t res;
180 if(origDest.sin4.sin_family == 0) {
181 res = sendto(fd, packet.c_str(), packet.size(), 0, (struct sockaddr*)&destination, destination.getSocklen());
182 }
183 else {
184 res = sendfromto(fd, packet.c_str(), packet.size(), 0, origDest, destination);
185 }
186 if (res == -1) {
187 int err = errno;
188 vinfolog("Error sending delayed response to %s: %s", destination.toStringWithPort(), strerror(err));
189 }
190 }
191 };
192
193 DelayPipe<DelayedPacket>* g_delay = nullptr;
194
195 void doLatencyStats(double udiff)
196 {
197 if(udiff < 1000) ++g_stats.latency0_1;
198 else if(udiff < 10000) ++g_stats.latency1_10;
199 else if(udiff < 50000) ++g_stats.latency10_50;
200 else if(udiff < 100000) ++g_stats.latency50_100;
201 else if(udiff < 1000000) ++g_stats.latency100_1000;
202 else ++g_stats.latencySlow;
203 g_stats.latencySum += udiff / 1000;
204
205 auto doAvg = [](double& var, double n, double weight) {
206 var = (weight -1) * var/weight + n/weight;
207 };
208
209 doAvg(g_stats.latencyAvg100, udiff, 100);
210 doAvg(g_stats.latencyAvg1000, udiff, 1000);
211 doAvg(g_stats.latencyAvg10000, udiff, 10000);
212 doAvg(g_stats.latencyAvg1000000, udiff, 1000000);
213 }
214
215 bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed)
216 {
217 if (responseLen < sizeof(dnsheader)) {
218 return false;
219 }
220
221 const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response);
222 if (dh->qdcount == 0) {
223 if ((dh->rcode != RCode::NoError && dh->rcode != RCode::NXDomain) || g_allowEmptyResponse) {
224 return true;
225 }
226 else {
227 ++g_stats.nonCompliantResponses;
228 return false;
229 }
230 }
231
232 uint16_t rqtype, rqclass;
233 DNSName rqname;
234 try {
235 rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed);
236 }
237 catch(const std::exception& e) {
238 if(responseLen > 0 && static_cast<size_t>(responseLen) > sizeof(dnsheader)) {
239 infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what());
240 }
241 ++g_stats.nonCompliantResponses;
242 return false;
243 }
244
245 if (rqtype != qtype || rqclass != qclass || rqname != qname) {
246 return false;
247 }
248
249 return true;
250 }
251
252 static void restoreFlags(struct dnsheader* dh, uint16_t origFlags)
253 {
254 static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET;
255 static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET;
256 static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask);
257 uint16_t * flags = getFlagsFromDNSHeader(dh);
258 /* clear the flags we are about to restore */
259 *flags &= restoreFlagsMask;
260 /* only keep the flags we want to restore */
261 origFlags &= ~restoreFlagsMask;
262 /* set the saved flags as they were */
263 *flags |= origFlags;
264 }
265
266 static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
267 {
268 restoreFlags(dq.dh, origFlags);
269
270 return addEDNSToQueryTurnedResponse(dq);
271 }
272
273 static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector<uint8_t>& rewrittenResponse, uint16_t addRoom, bool* zeroScope)
274 {
275 if (*responseLen < sizeof(dnsheader)) {
276 return false;
277 }
278
279 struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(*response);
280 restoreFlags(dh, origFlags);
281
282 if (*responseLen == sizeof(dnsheader)) {
283 return true;
284 }
285
286 if(g_fixupCase) {
287 string realname = qname.toDNSString();
288 if (*responseLen >= (sizeof(dnsheader) + realname.length())) {
289 memcpy(*response + sizeof(dnsheader), realname.c_str(), realname.length());
290 }
291 }
292
293 if (ednsAdded || ecsAdded) {
294 uint16_t optStart;
295 size_t optLen = 0;
296 bool last = false;
297
298 const std::string responseStr(*response, *responseLen);
299 int res = locateEDNSOptRR(responseStr, &optStart, &optLen, &last);
300
301 if (res == 0) {
302 if (zeroScope) { // this finds if an EDNS Client Subnet scope was set, and if it is 0
303 size_t optContentStart = 0;
304 uint16_t optContentLen = 0;
305 /* we need at least 4 bytes after the option length (family: 2, source prefix-length: 1, scope prefix-length: 1) */
306 if (isEDNSOptionInOpt(responseStr, optStart, optLen, EDNSOptionCode::ECS, &optContentStart, &optContentLen) && optContentLen >= 4) {
307 /* see if the EDNS Client Subnet SCOPE PREFIX-LENGTH byte in position 3 is set to 0, which is the only thing
308 we care about. */
309 *zeroScope = responseStr.at(optContentStart + 3) == 0;
310 }
311 }
312
313 if (ednsAdded) {
314 /* we added the entire OPT RR,
315 therefore we need to remove it entirely */
316 if (last) {
317 /* simply remove the last AR */
318 *responseLen -= optLen;
319 uint16_t arcount = ntohs(dh->arcount);
320 arcount--;
321 dh->arcount = htons(arcount);
322 }
323 else {
324 /* Removing an intermediary RR could lead to compression error */
325 if (rewriteResponseWithoutEDNS(responseStr, rewrittenResponse) == 0) {
326 *responseLen = rewrittenResponse.size();
327 if (addRoom && (UINT16_MAX - *responseLen) > addRoom) {
328 rewrittenResponse.reserve(*responseLen + addRoom);
329 }
330 *responseSize = rewrittenResponse.capacity();
331 *response = reinterpret_cast<char*>(rewrittenResponse.data());
332 }
333 else {
334 warnlog("Error rewriting content");
335 }
336 }
337 }
338 else {
339 /* the OPT RR was already present, but without ECS,
340 we need to remove the ECS option if any */
341 if (last) {
342 /* nothing after the OPT RR, we can simply remove the
343 ECS option */
344 size_t existingOptLen = optLen;
345 removeEDNSOptionFromOPT(*response + optStart, &optLen, EDNSOptionCode::ECS);
346 *responseLen -= (existingOptLen - optLen);
347 }
348 else {
349 /* Removing an intermediary RR could lead to compression error */
350 if (rewriteResponseWithoutEDNSOption(responseStr, EDNSOptionCode::ECS, rewrittenResponse) == 0) {
351 *responseLen = rewrittenResponse.size();
352 if (addRoom && (UINT16_MAX - *responseLen) > addRoom) {
353 rewrittenResponse.reserve(*responseLen + addRoom);
354 }
355 *responseSize = rewrittenResponse.capacity();
356 *response = reinterpret_cast<char*>(rewrittenResponse.data());
357 }
358 else {
359 warnlog("Error rewriting content");
360 }
361 }
362 }
363 }
364 }
365
366 return true;
367 }
368
369 #ifdef HAVE_DNSCRYPT
370 static bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
371 {
372 if (dnsCryptQuery) {
373 uint16_t encryptedResponseLen = 0;
374
375 /* save the original header before encrypting it in place */
376 if (dh != nullptr && *dh != nullptr && dhCopy != nullptr) {
377 memcpy(dhCopy, *dh, sizeof(dnsheader));
378 *dh = dhCopy;
379 }
380
381 int res = dnsCryptQuery->encryptResponse(response, *responseLen, responseSize, tcp, &encryptedResponseLen);
382 if (res == 0) {
383 *responseLen = encryptedResponseLen;
384 } else {
385 /* dropping response */
386 vinfolog("Error encrypting the response, dropping.");
387 return false;
388 }
389 }
390 return true;
391 }
392 #endif /* HAVE_DNSCRYPT */
393
394 static bool applyRulesToResponse(LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr)
395 {
396 DNSResponseAction::Action action=DNSResponseAction::Action::None;
397 std::string ruleresult;
398 for(const auto& lr : *localRespRulactions) {
399 if(lr.d_rule->matches(&dr)) {
400 lr.d_rule->d_matches++;
401 action=(*lr.d_action)(&dr, &ruleresult);
402 switch(action) {
403 case DNSResponseAction::Action::Allow:
404 return true;
405 break;
406 case DNSResponseAction::Action::Drop:
407 return false;
408 break;
409 case DNSResponseAction::Action::HeaderModify:
410 return true;
411 break;
412 case DNSResponseAction::Action::ServFail:
413 dr.dh->rcode = RCode::ServFail;
414 return true;
415 break;
416 /* non-terminal actions follow */
417 case DNSResponseAction::Action::Delay:
418 dr.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
419 break;
420 case DNSResponseAction::Action::None:
421 break;
422 }
423 }
424 }
425
426 return true;
427 }
428
429 bool processResponse(char** response, uint16_t* responseLen, size_t* responseSize, LocalStateHolder<vector<DNSDistResponseRuleAction> >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector<uint8_t>& rewrittenResponse, bool muted)
430 {
431 if (!applyRulesToResponse(localRespRulactions, dr)) {
432 return false;
433 }
434
435 bool zeroScope = false;
436 if (!fixUpResponse(response, responseLen, responseSize, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, rewrittenResponse, addRoom, dr.useZeroScope ? &zeroScope : nullptr)) {
437 return false;
438 }
439
440 if (dr.packetCache && !dr.skipCache) {
441 if (!dr.useZeroScope) {
442 /* if the query was not suitable for zero-scope, for
443 example because it had an existing ECS entry so the hash is
444 not really 'no ECS', so just insert it for the existing subnet
445 since:
446 - we don't have the correct hash for a non-ECS query
447 - inserting with hash computed before the ECS replacement but with
448 the subnet extracted _after_ the replacement would not work.
449 */
450 zeroScope = false;
451 }
452 // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache
453 dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.origFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, *response, *responseLen, dr.tcp, dr.dh->rcode, dr.tempFailureTTL);
454 }
455
456 #ifdef HAVE_DNSCRYPT
457 if (!muted) {
458 if (!encryptResponse(*response, responseLen, *responseSize, dr.tcp, dr.dnsCryptQuery, nullptr, nullptr)) {
459 return false;
460 }
461 }
462 #endif /* HAVE_DNSCRYPT */
463
464 return true;
465 }
466
467 static bool sendUDPResponse(int origFD, const char* response, const uint16_t responseLen, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote)
468 {
469 if(delayMsec && g_delay) {
470 DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest};
471 g_delay->submit(dp, delayMsec);
472 }
473 else {
474 ssize_t res;
475 if(origDest.sin4.sin_family == 0) {
476 res = sendto(origFD, response, responseLen, 0, reinterpret_cast<const struct sockaddr*>(&origRemote), origRemote.getSocklen());
477 }
478 else {
479 res = sendfromto(origFD, response, responseLen, 0, origDest, origRemote);
480 }
481 if (res == -1) {
482 int err = errno;
483 vinfolog("Error sending response to %s: %s", origRemote.toStringWithPort(), strerror(err));
484 }
485 }
486
487 return true;
488 }
489
490
491 int pickBackendSocketForSending(std::shared_ptr<DownstreamState>& state)
492 {
493 return state->sockets[state->socketsOffset++ % state->sockets.size()];
494 }
495
496 static void pickBackendSocketsReadyForReceiving(const std::shared_ptr<DownstreamState>& state, std::vector<int>& ready)
497 {
498 ready.clear();
499
500 if (state->sockets.size() == 1) {
501 ready.push_back(state->sockets[0]);
502 return ;
503 }
504
505 {
506 std::lock_guard<std::mutex> lock(state->socketsLock);
507 state->mplexer->getAvailableFDs(ready, -1);
508 }
509 }
510
511 // listens on a dedicated socket, lobs answers from downstream servers to original requestors
512 void responderThread(std::shared_ptr<DownstreamState> dss)
513 try {
514 setThreadName("dnsdist/respond");
515 auto localRespRulactions = g_resprulactions.getLocal();
516 char packet[4096 + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE];
517 static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t");
518 /* when the answer is encrypted in place, we need to get a copy
519 of the original header before encryption to fill the ring buffer */
520 dnsheader cleartextDH;
521 vector<uint8_t> rewrittenResponse;
522
523 uint16_t queryId = 0;
524 std::vector<int> sockets;
525 sockets.reserve(dss->sockets.size());
526
527 for(;;) {
528 dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
529 try {
530 pickBackendSocketsReadyForReceiving(dss, sockets);
531 for (const auto& fd : sockets) {
532 ssize_t got = recv(fd, packet, sizeof(packet), 0);
533 char * response = packet;
534 size_t responseSize = sizeof(packet);
535
536 if (got < 0 || static_cast<size_t>(got) < sizeof(dnsheader))
537 continue;
538
539 uint16_t responseLen = static_cast<uint16_t>(got);
540 queryId = dh->id;
541
542 if(queryId >= dss->idStates.size()) {
543 continue;
544 }
545
546 IDState* ids = &dss->idStates[queryId];
547 int64_t usageIndicator = ids->usageIndicator;
548
549 if(!IDState::isInUse(usageIndicator)) {
550 /* the corresponding state is marked as not in use, meaning that:
551 - it was already cleaned up by another thread and the state is gone ;
552 - we already got a response for this query and this one is a duplicate.
553 Either way, we don't touch it.
554 */
555 continue;
556 }
557
558 /* read the potential DOHUnit state as soon as possible, but don't use it
559 until we have confirmed that we own this state by updating usageIndicator */
560 auto du = ids->du;
561 /* setting age to 0 to prevent the maintainer thread from
562 cleaning this IDS while we process the response.
563 */
564 ids->age = 0;
565 int origFD = ids->origFD;
566
567 unsigned int consumed = 0;
568 if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, dss->remote, consumed)) {
569 continue;
570 }
571
572 bool isDoH = du != nullptr;
573 /* atomically mark the state as available, but only if it has not been altered
574 in the meantime */
575 if (ids->tryMarkUnused(usageIndicator)) {
576 /* clear the potential DOHUnit asap, it's ours now
577 and since we just marked the state as unused,
578 someone could overwrite it. */
579 ids->du = nullptr;
580 /* we only decrement the outstanding counter if the value was not
581 altered in the meantime, which would mean that the state has been actively reused
582 and the other thread has not incremented the outstanding counter, so we don't
583 want it to be decremented twice. */
584 --dss->outstanding; // you'd think an attacker could game this, but we're using connected socket
585 } else {
586 /* someone updated the state in the meantime, we can't touch the existing pointer */
587 du = nullptr;
588 /* since the state has been updated, we can't safely access it so let's just drop
589 this response */
590 continue;
591 }
592
593 if(dh->tc && g_truncateTC) {
594 truncateTC(response, &responseLen, responseSize, consumed);
595 }
596
597 dh->id = ids->origID;
598
599 uint16_t addRoom = 0;
600 DNSResponse dr = makeDNSResponseFromIDState(*ids, dh, sizeof(packet), responseLen, false);
601 if (dr.dnsCryptQuery) {
602 addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
603 }
604
605 memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH));
606 if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, ids->cs && ids->cs->muted)) {
607 continue;
608 }
609
610 if (ids->cs && !ids->cs->muted) {
611 if (du) {
612 #ifdef HAVE_DNS_OVER_HTTPS
613 // DoH query
614 du->response = std::string(response, responseLen);
615 if (send(du->rsock, &du, sizeof(du), 0) != sizeof(du)) {
616 /* at this point we have the only remaining pointer on this
617 DOHUnit object since we did set ids->du to nullptr earlier */
618 delete du;
619 }
620 #endif /* HAVE_DNS_OVER_HTTPS */
621 du = nullptr;
622 }
623 else {
624 ComboAddress empty;
625 empty.sin4.sin_family = 0;
626 /* if ids->destHarvested is false, origDest holds the listening address.
627 We don't want to use that as a source since it could be 0.0.0.0 for example. */
628 sendUDPResponse(origFD, response, responseLen, dr.delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote);
629 }
630 }
631
632 ++g_stats.responses;
633
634 double udiff = ids->sentTime.udiff();
635 vinfolog("Got answer from %s, relayed to %s%s, took %f usec", dss->remote.toStringWithPort(), ids->origRemote.toStringWithPort(),
636 isDoH ? " (https)": "", udiff);
637
638 struct timespec ts;
639 gettime(&ts);
640 g_rings.insertResponse(ts, *dr.remote, *dr.qname, dr.qtype, static_cast<unsigned int>(udiff), static_cast<unsigned int>(got), cleartextDH, dss->remote);
641
642 switch (dh->rcode) {
643 case RCode::NXDomain:
644 ++g_stats.frontendNXDomain;
645 break;
646 case RCode::ServFail:
647 ++g_stats.servfailResponses;
648 ++g_stats.frontendServFail;
649 break;
650 case RCode::NoError:
651 ++g_stats.frontendNoError;
652 break;
653 }
654 dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff/128.0;
655
656 doLatencyStats(udiff);
657
658 rewrittenResponse.clear();
659 }
660 }
661 catch(const std::exception& e){
662 vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->remote.toStringWithPort(), queryId, e.what());
663 }
664 }
665 }
666 catch(const std::exception& e)
667 {
668 errlog("UDP responder thread died because of exception: %s", e.what());
669 }
670 catch(const PDNSException& e)
671 {
672 errlog("UDP responder thread died because of PowerDNS exception: %s", e.reason);
673 }
674 catch(...)
675 {
676 errlog("UDP responder thread died because of an exception: %s", "unknown");
677 }
678
679 bool DownstreamState::reconnect()
680 {
681 std::unique_lock<std::mutex> tl(connectLock, std::try_to_lock);
682 if (!tl.owns_lock()) {
683 /* we are already reconnecting */
684 return false;
685 }
686
687 connected = false;
688 for (auto& fd : sockets) {
689 if (fd != -1) {
690 if (sockets.size() > 1) {
691 std::lock_guard<std::mutex> lock(socketsLock);
692 mplexer->removeReadFD(fd);
693 }
694 /* shutdown() is needed to wake up recv() in the responderThread */
695 shutdown(fd, SHUT_RDWR);
696 close(fd);
697 fd = -1;
698 }
699 if (!IsAnyAddress(remote)) {
700 fd = SSocket(remote.sin4.sin_family, SOCK_DGRAM, 0);
701 if (!IsAnyAddress(sourceAddr)) {
702 SSetsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 1);
703 SBind(fd, sourceAddr);
704 }
705 try {
706 SConnect(fd, remote);
707 if (sockets.size() > 1) {
708 std::lock_guard<std::mutex> lock(socketsLock);
709 mplexer->addReadFD(fd, [](int, boost::any) {});
710 }
711 connected = true;
712 }
713 catch(const std::runtime_error& error) {
714 infolog("Error connecting to new server with address %s: %s", remote.toStringWithPort(), error.what());
715 connected = false;
716 break;
717 }
718 }
719 }
720
721 /* if at least one (re-)connection failed, close all sockets */
722 if (!connected) {
723 for (auto& fd : sockets) {
724 if (fd != -1) {
725 if (sockets.size() > 1) {
726 std::lock_guard<std::mutex> lock(socketsLock);
727 mplexer->removeReadFD(fd);
728 }
729 /* shutdown() is needed to wake up recv() in the responderThread */
730 shutdown(fd, SHUT_RDWR);
731 close(fd);
732 fd = -1;
733 }
734 }
735 }
736
737 return connected;
738 }
739 void DownstreamState::hash()
740 {
741 vinfolog("Computing hashes for id=%s and weight=%d", id, weight);
742 auto w = weight;
743 WriteLock wl(&d_lock);
744 hashes.clear();
745 while (w > 0) {
746 std::string uuid = boost::str(boost::format("%s-%d") % id % w);
747 unsigned int wshash = burtleCI((const unsigned char*)uuid.c_str(), uuid.size(), g_hashperturb);
748 hashes.insert(wshash);
749 --w;
750 }
751 }
752
753 void DownstreamState::setId(const boost::uuids::uuid& newId)
754 {
755 id = newId;
756 // compute hashes only if already done
757 if (!hashes.empty()) {
758 hash();
759 }
760 }
761
762 void DownstreamState::setWeight(int newWeight)
763 {
764 if (newWeight < 1) {
765 errlog("Error setting server's weight: downstream weight value must be greater than 0.");
766 return ;
767 }
768 weight = newWeight;
769 if (!hashes.empty()) {
770 hash();
771 }
772 }
773
774 DownstreamState::DownstreamState(const ComboAddress& remote_, const ComboAddress& sourceAddr_, unsigned int sourceItf_, size_t numberOfSockets): remote(remote_), sourceAddr(sourceAddr_), sourceItf(sourceItf_)
775 {
776 pthread_rwlock_init(&d_lock, nullptr);
777 id = getUniqueID();
778 threadStarted.clear();
779
780 mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
781
782 sockets.resize(numberOfSockets);
783 for (auto& fd : sockets) {
784 fd = -1;
785 }
786
787 if (!IsAnyAddress(remote)) {
788 reconnect();
789 idStates.resize(g_maxOutstanding);
790 sw.start();
791 infolog("Added downstream server %s", remote.toStringWithPort());
792 }
793
794 }
795
796 std::mutex g_luamutex;
797 LuaContext g_lua;
798
799 GlobalStateHolder<ServerPolicy> g_policy;
800
801 shared_ptr<DownstreamState> firstAvailable(const NumberedServerVector& servers, const DNSQuestion* dq)
802 {
803 for(auto& d : servers) {
804 if(d.second->isUp() && d.second->qps.check())
805 return d.second;
806 }
807 return leastOutstanding(servers, dq);
808 }
809
810 // get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest
811 shared_ptr<DownstreamState> leastOutstanding(const NumberedServerVector& servers, const DNSQuestion* dq)
812 {
813 if (servers.size() == 1 && servers[0].second->isUp()) {
814 return servers[0].second;
815 }
816
817 vector<pair<tuple<int,int,double>, shared_ptr<DownstreamState>>> poss;
818 /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort,
819 which would suck royally and could even lead to crashes. So first we snapshot on what we sort, and then we sort */
820 poss.reserve(servers.size());
821 for(auto& d : servers) {
822 if(d.second->isUp()) {
823 poss.push_back({make_tuple(d.second->outstanding.load(), d.second->order, d.second->latencyUsec), d.second});
824 }
825 }
826 if(poss.empty())
827 return shared_ptr<DownstreamState>();
828 nth_element(poss.begin(), poss.begin(), poss.end(), [](const decltype(poss)::value_type& a, const decltype(poss)::value_type& b) { return a.first < b.first; });
829 return poss.begin()->second;
830 }
831
832 shared_ptr<DownstreamState> valrandom(unsigned int val, const NumberedServerVector& servers, const DNSQuestion* dq)
833 {
834 vector<pair<int, shared_ptr<DownstreamState>>> poss;
835 int sum = 0;
836 int max = std::numeric_limits<int>::max();
837
838 for(auto& d : servers) { // w=1, w=10 -> 1, 11
839 if(d.second->isUp()) {
840 // Don't overflow sum when adding high weights
841 if(d.second->weight > max - sum) {
842 sum = max;
843 } else {
844 sum += d.second->weight;
845 }
846
847 poss.push_back({sum, d.second});
848 }
849 }
850
851 // Catch poss & sum are empty to avoid SIGFPE
852 if(poss.empty())
853 return shared_ptr<DownstreamState>();
854
855 int r = val % sum;
856 auto p = upper_bound(poss.begin(), poss.end(),r, [](int r_, const decltype(poss)::value_type& a) { return r_ < a.first;});
857 if(p==poss.end())
858 return shared_ptr<DownstreamState>();
859 return p->second;
860 }
861
862 shared_ptr<DownstreamState> wrandom(const NumberedServerVector& servers, const DNSQuestion* dq)
863 {
864 return valrandom(random(), servers, dq);
865 }
866
867 uint32_t g_hashperturb;
868 shared_ptr<DownstreamState> whashed(const NumberedServerVector& servers, const DNSQuestion* dq)
869 {
870 return valrandom(dq->qname->hash(g_hashperturb), servers, dq);
871 }
872
873 shared_ptr<DownstreamState> chashed(const NumberedServerVector& servers, const DNSQuestion* dq)
874 {
875 unsigned int qhash = dq->qname->hash(g_hashperturb);
876 unsigned int sel = std::numeric_limits<unsigned int>::max();
877 unsigned int min = std::numeric_limits<unsigned int>::max();
878 shared_ptr<DownstreamState> ret = nullptr, first = nullptr;
879
880 for (const auto& d: servers) {
881 if (d.second->isUp()) {
882 // make sure hashes have been computed
883 if (d.second->hashes.empty()) {
884 d.second->hash();
885 }
886 {
887 ReadLock rl(&(d.second->d_lock));
888 const auto& server = d.second;
889 // we want to keep track of the last hash
890 if (min > *(server->hashes.begin())) {
891 min = *(server->hashes.begin());
892 first = server;
893 }
894
895 auto hash_it = server->hashes.lower_bound(qhash);
896 if (hash_it != server->hashes.end()) {
897 if (*hash_it < sel) {
898 sel = *hash_it;
899 ret = server;
900 }
901 }
902 }
903 }
904 }
905 if (ret != nullptr) {
906 return ret;
907 }
908 if (first != nullptr) {
909 return first;
910 }
911 return shared_ptr<DownstreamState>();
912 }
913
914 shared_ptr<DownstreamState> roundrobin(const NumberedServerVector& servers, const DNSQuestion* dq)
915 {
916 NumberedServerVector poss;
917
918 for(auto& d : servers) {
919 if(d.second->isUp()) {
920 poss.push_back(d);
921 }
922 }
923
924 const auto *res=&poss;
925 if(poss.empty() && !g_roundrobinFailOnNoServer)
926 res = &servers;
927
928 if(res->empty())
929 return shared_ptr<DownstreamState>();
930
931 static unsigned int counter;
932
933 return (*res)[(counter++) % res->size()].second;
934 }
935
936 ComboAddress g_serverControl{"127.0.0.1:5199"};
937
938 std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName)
939 {
940 std::shared_ptr<ServerPool> pool;
941 pools_t::iterator it = pools.find(poolName);
942 if (it != pools.end()) {
943 pool = it->second;
944 }
945 else {
946 if (!poolName.empty())
947 vinfolog("Creating pool %s", poolName);
948 pool = std::make_shared<ServerPool>();
949 pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool));
950 }
951 return pool;
952 }
953
954 void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy)
955 {
956 std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
957 if (!poolName.empty()) {
958 vinfolog("Setting pool %s server selection policy to %s", poolName, policy->name);
959 } else {
960 vinfolog("Setting default pool server selection policy to %s", policy->name);
961 }
962 pool->policy = policy;
963 }
964
965 void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
966 {
967 std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName);
968 if (!poolName.empty()) {
969 vinfolog("Adding server to pool %s", poolName);
970 } else {
971 vinfolog("Adding server to default pool");
972 }
973 pool->addServer(server);
974 }
975
976 void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server)
977 {
978 std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
979
980 if (!poolName.empty()) {
981 vinfolog("Removing server from pool %s", poolName);
982 }
983 else {
984 vinfolog("Removing server from default pool");
985 }
986
987 pool->removeServer(server);
988 }
989
990 std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName)
991 {
992 pools_t::const_iterator it = pools.find(poolName);
993
994 if (it == pools.end()) {
995 throw std::out_of_range("No pool named " + poolName);
996 }
997
998 return it->second;
999 }
1000
1001 NumberedServerVector getDownstreamCandidates(const pools_t& pools, const std::string& poolName)
1002 {
1003 std::shared_ptr<ServerPool> pool = getPool(pools, poolName);
1004 return pool->getServers();
1005 }
1006
1007 static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent)
1008 {
1009 string result;
1010
1011 std::vector<std::string> addrs;
1012 stringtok(addrs, spoofContent, " ,");
1013
1014 if (addrs.size() == 1) {
1015 try {
1016 ComboAddress spoofAddr(spoofContent);
1017 SpoofAction sa({spoofAddr});
1018 sa(&dq, &result);
1019 }
1020 catch(const PDNSException &e) {
1021 SpoofAction sa(spoofContent); // CNAME then
1022 sa(&dq, &result);
1023 }
1024 } else {
1025 std::vector<ComboAddress> cas;
1026 for (const auto& addr : addrs) {
1027 try {
1028 cas.push_back(ComboAddress(addr));
1029 }
1030 catch (...) {
1031 }
1032 }
1033 SpoofAction sa(cas);
1034 sa(&dq, &result);
1035 }
1036 }
1037
1038 bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop)
1039 {
1040 switch(action) {
1041 case DNSAction::Action::Allow:
1042 return true;
1043 break;
1044 case DNSAction::Action::Drop:
1045 ++g_stats.ruleDrop;
1046 drop = true;
1047 return true;
1048 break;
1049 case DNSAction::Action::Nxdomain:
1050 dq.dh->rcode = RCode::NXDomain;
1051 dq.dh->qr=true;
1052 ++g_stats.ruleNXDomain;
1053 return true;
1054 break;
1055 case DNSAction::Action::Refused:
1056 dq.dh->rcode = RCode::Refused;
1057 dq.dh->qr=true;
1058 ++g_stats.ruleRefused;
1059 return true;
1060 break;
1061 case DNSAction::Action::ServFail:
1062 dq.dh->rcode = RCode::ServFail;
1063 dq.dh->qr=true;
1064 ++g_stats.ruleServFail;
1065 return true;
1066 break;
1067 case DNSAction::Action::Spoof:
1068 spoofResponseFromString(dq, ruleresult);
1069 return true;
1070 break;
1071 case DNSAction::Action::Truncate:
1072 dq.dh->tc = true;
1073 dq.dh->qr = true;
1074 return true;
1075 break;
1076 case DNSAction::Action::HeaderModify:
1077 return true;
1078 break;
1079 case DNSAction::Action::Pool:
1080 dq.poolname=ruleresult;
1081 return true;
1082 break;
1083 case DNSAction::Action::NoRecurse:
1084 dq.dh->rd = false;
1085 return true;
1086 break;
1087 /* non-terminal actions follow */
1088 case DNSAction::Action::Delay:
1089 dq.delayMsec = static_cast<int>(pdns_stou(ruleresult)); // sorry
1090 break;
1091 case DNSAction::Action::None:
1092 /* fall-through */
1093 case DNSAction::Action::NoOp:
1094 break;
1095 }
1096
1097 /* false means that we don't stop the processing */
1098 return false;
1099 }
1100
1101
1102 static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const struct timespec& now)
1103 {
1104 g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.len, *dq.dh);
1105
1106 if(g_qcount.enabled) {
1107 string qname = (*dq.qname).toLogString();
1108 bool countQuery{true};
1109 if(g_qcount.filter) {
1110 std::lock_guard<std::mutex> lock(g_luamutex);
1111 std::tie (countQuery, qname) = g_qcount.filter(&dq);
1112 }
1113
1114 if(countQuery) {
1115 WriteLock wl(&g_qcount.queryLock);
1116 if(!g_qcount.records.count(qname)) {
1117 g_qcount.records[qname] = 0;
1118 }
1119 g_qcount.records[qname]++;
1120 }
1121 }
1122
1123 if(auto got = holders.dynNMGBlock->lookup(*dq.remote)) {
1124 auto updateBlockStats = [&got]() {
1125 ++g_stats.dynBlocked;
1126 got->second.blocks++;
1127 };
1128
1129 if(now < got->second.until) {
1130 DNSAction::Action action = got->second.action;
1131 if (action == DNSAction::Action::None) {
1132 action = g_dynBlockAction;
1133 }
1134 switch (action) {
1135 case DNSAction::Action::NoOp:
1136 /* do nothing */
1137 break;
1138
1139 case DNSAction::Action::Nxdomain:
1140 vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort());
1141 updateBlockStats();
1142
1143 dq.dh->rcode = RCode::NXDomain;
1144 dq.dh->qr=true;
1145 return true;
1146
1147 case DNSAction::Action::Refused:
1148 vinfolog("Query from %s refused because of dynamic block", dq.remote->toStringWithPort());
1149 updateBlockStats();
1150
1151 dq.dh->rcode = RCode::Refused;
1152 dq.dh->qr = true;
1153 return true;
1154
1155 case DNSAction::Action::Truncate:
1156 if(!dq.tcp) {
1157 updateBlockStats();
1158 vinfolog("Query from %s truncated because of dynamic block", dq.remote->toStringWithPort());
1159 dq.dh->tc = true;
1160 dq.dh->qr = true;
1161 return true;
1162 }
1163 else {
1164 vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
1165 }
1166 break;
1167 case DNSAction::Action::NoRecurse:
1168 updateBlockStats();
1169 vinfolog("Query from %s setting rd=0 because of dynamic block", dq.remote->toStringWithPort());
1170 dq.dh->rd = false;
1171 return true;
1172 default:
1173 updateBlockStats();
1174 vinfolog("Query from %s dropped because of dynamic block", dq.remote->toStringWithPort());
1175 return false;
1176 }
1177 }
1178 }
1179
1180 if(auto got = holders.dynSMTBlock->lookup(*dq.qname)) {
1181 auto updateBlockStats = [&got]() {
1182 ++g_stats.dynBlocked;
1183 got->blocks++;
1184 };
1185
1186 if(now < got->until) {
1187 DNSAction::Action action = got->action;
1188 if (action == DNSAction::Action::None) {
1189 action = g_dynBlockAction;
1190 }
1191 switch (action) {
1192 case DNSAction::Action::NoOp:
1193 /* do nothing */
1194 break;
1195 case DNSAction::Action::Nxdomain:
1196 vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
1197 updateBlockStats();
1198
1199 dq.dh->rcode = RCode::NXDomain;
1200 dq.dh->qr=true;
1201 return true;
1202 case DNSAction::Action::Refused:
1203 vinfolog("Query from %s for %s refused because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
1204 updateBlockStats();
1205
1206 dq.dh->rcode = RCode::Refused;
1207 dq.dh->qr=true;
1208 return true;
1209 case DNSAction::Action::Truncate:
1210 if(!dq.tcp) {
1211 updateBlockStats();
1212
1213 vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
1214 dq.dh->tc = true;
1215 dq.dh->qr = true;
1216 return true;
1217 }
1218 else {
1219 vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
1220 }
1221 break;
1222 case DNSAction::Action::NoRecurse:
1223 updateBlockStats();
1224 vinfolog("Query from %s setting rd=0 because of dynamic block", dq.remote->toStringWithPort());
1225 dq.dh->rd = false;
1226 return true;
1227 default:
1228 updateBlockStats();
1229 vinfolog("Query from %s for %s dropped because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString());
1230 return false;
1231 }
1232 }
1233 }
1234
1235 DNSAction::Action action=DNSAction::Action::None;
1236 string ruleresult;
1237 bool drop = false;
1238 for(const auto& lr : *holders.rulactions) {
1239 if(lr.d_rule->matches(&dq)) {
1240 lr.d_rule->d_matches++;
1241 action=(*lr.d_action)(&dq, &ruleresult);
1242 if (processRulesResult(action, dq, ruleresult, drop)) {
1243 break;
1244 }
1245 }
1246 }
1247
1248 if (drop) {
1249 return false;
1250 }
1251
1252 return true;
1253 }
1254
1255 ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const char* request, const size_t requestLen, bool healthCheck)
1256 {
1257 ssize_t result;
1258
1259 if (ss->sourceItf == 0) {
1260 result = send(sd, request, requestLen, 0);
1261 }
1262 else {
1263 struct msghdr msgh;
1264 struct iovec iov;
1265 cmsgbuf_aligned cbuf;
1266 ComboAddress remote(ss->remote);
1267 fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), const_cast<char*>(request), requestLen, &remote);
1268 addCMsgSrcAddr(&msgh, &cbuf, &ss->sourceAddr, ss->sourceItf);
1269 result = sendmsg(sd, &msgh, 0);
1270 }
1271
1272 if (result == -1) {
1273 int savederrno = errno;
1274 vinfolog("Error sending request to backend %s: %d", ss->remote.toStringWithPort(), savederrno);
1275
1276 /* This might sound silly, but on Linux send() might fail with EINVAL
1277 if the interface the socket was bound to doesn't exist anymore.
1278 We don't want to reconnect the real socket if the healthcheck failed,
1279 because it's not using the same socket.
1280 */
1281 if (!healthCheck && (savederrno == EINVAL || savederrno == ENODEV)) {
1282 ss->reconnect();
1283 }
1284 }
1285
1286 return result;
1287 }
1288
1289 static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest)
1290 {
1291 if (msgh->msg_flags & MSG_TRUNC) {
1292 /* message was too large for our buffer */
1293 vinfolog("Dropping message too large for our buffer");
1294 ++g_stats.nonCompliantQueries;
1295 return false;
1296 }
1297
1298 if(!holders.acl->match(remote)) {
1299 vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort());
1300 ++g_stats.aclDrops;
1301 return false;
1302 }
1303
1304 cs.queries++;
1305 ++g_stats.queries;
1306
1307 if (HarvestDestinationAddress(msgh, &dest)) {
1308 /* we don't get the port, only the address */
1309 dest.sin4.sin_port = cs.local.sin4.sin_port;
1310 }
1311 else {
1312 dest.sin4.sin_family = 0;
1313 }
1314
1315 return true;
1316 }
1317
1318 boost::optional<std::vector<uint8_t>> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
1319 {
1320 if (cs.dnscryptCtx) {
1321 #ifdef HAVE_DNSCRYPT
1322 vector<uint8_t> response;
1323 uint16_t decryptedQueryLen = 0;
1324
1325 dnsCryptQuery = std::make_shared<DNSCryptQuery>(cs.dnscryptCtx);
1326
1327 bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, tcp, now, response);
1328
1329 if (!decrypted) {
1330 if (response.size() > 0) {
1331 return response;
1332 }
1333 throw std::runtime_error("Unable to decrypt DNSCrypt query, dropping.");
1334 }
1335
1336 len = decryptedQueryLen;
1337 #endif /* HAVE_DNSCRYPT */
1338 }
1339 return boost::none;
1340 }
1341
1342 bool checkQueryHeaders(const struct dnsheader* dh)
1343 {
1344 if (dh->qr) { // don't respond to responses
1345 ++g_stats.nonCompliantQueries;
1346 return false;
1347 }
1348
1349 if (dh->qdcount == 0) {
1350 ++g_stats.emptyQueries;
1351 return false;
1352 }
1353
1354 if (dh->rd) {
1355 ++g_stats.rdQueries;
1356 }
1357
1358 return true;
1359 }
1360
#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
/* Fills 'outMsg' so that 'response' (of 'responseLen' bytes) gets sent to
   'remote' by a later sendmmsg() call, using 'iov' for the payload.
   If the family of 'dest' is not 0, 'dest' is passed as the source address
   via 'cbuf', otherwise no control message is attached. */
static void queueResponse(const ClientState& cs, const char* response, uint16_t responseLen, const ComboAddress& dest, const ComboAddress& remote, struct mmsghdr& outMsg, struct iovec* iov, cmsgbuf_aligned* cbuf)
{
  outMsg.msg_len = 0;
  fillMSGHdr(&outMsg.msg_hdr, iov, nullptr, 0, const_cast<char*>(response), responseLen, const_cast<ComboAddress*>(&remote));

  if (dest.sin4.sin_family != 0) {
    addCMsgSrcAddr(&outMsg.msg_hdr, cbuf, &dest, 0);
    return;
  }

  outMsg.msg_hdr.msg_control = nullptr;
}
#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
1375
1376 /* self-generated responses or cache hits */
1377 static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit)
1378 {
1379 DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast<dnsheader*>(dq.dh), dq.size, dq.len, dq.tcp, dq.queryTime);
1380
1381 #ifdef HAVE_PROTOBUF
1382 dr.uniqueId = dq.uniqueId;
1383 #endif
1384 dr.qTag = dq.qTag;
1385 dr.delayMsec = dq.delayMsec;
1386
1387 if (!applyRulesToResponse(cacheHit ? holders.cacheHitRespRulactions : holders.selfAnsweredRespRulactions, dr)) {
1388 return false;
1389 }
1390
1391 /* in case a rule changed it */
1392 dq.delayMsec = dr.delayMsec;
1393
1394 #ifdef HAVE_DNSCRYPT
1395 if (!cs.muted) {
1396 if (!encryptResponse(reinterpret_cast<char*>(dq.dh), &dq.len, dq.size, dq.tcp, dq.dnsCryptQuery, nullptr, nullptr)) {
1397 return false;
1398 }
1399 }
1400 #endif /* HAVE_DNSCRYPT */
1401
1402 if (cacheHit) {
1403 ++g_stats.cacheHits;
1404 }
1405
1406 switch (dr.dh->rcode) {
1407 case RCode::NXDomain:
1408 ++g_stats.frontendNXDomain;
1409 break;
1410 case RCode::ServFail:
1411 ++g_stats.frontendServFail;
1412 break;
1413 case RCode::NoError:
1414 ++g_stats.frontendNoError;
1415 break;
1416 }
1417
1418 doLatencyStats(0); // we're not going to measure this
1419 return true;
1420 }
1421
1422 ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr<DownstreamState>& selectedBackend)
1423 {
1424 const uint16_t queryId = ntohs(dq.dh->id);
1425
1426 try {
1427 /* we need an accurate ("real") value for the response and
1428 to store into the IDS, but not for insertion into the
1429 rings for example */
1430 struct timespec now;
1431 gettime(&now);
1432
1433 if (!applyRulesToQuery(holders, dq, now)) {
1434 return ProcessQueryResult::Drop;
1435 }
1436
1437 if(dq.dh->qr) { // something turned it into a response
1438 fixUpQueryTurnedResponse(dq, dq.origFlags);
1439
1440 if (!prepareOutgoingResponse(holders, cs, dq, false)) {
1441 return ProcessQueryResult::Drop;
1442 }
1443
1444 ++g_stats.selfAnswered;
1445 return ProcessQueryResult::SendAnswer;
1446 }
1447
1448 std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, dq.poolname);
1449 dq.packetCache = serverPool->packetCache;
1450 auto policy = *(holders.policy);
1451 if (serverPool->policy != nullptr) {
1452 policy = *(serverPool->policy);
1453 }
1454 auto servers = serverPool->getServers();
1455 if (policy.isLua) {
1456 std::lock_guard<std::mutex> lock(g_luamutex);
1457 selectedBackend = policy.policy(servers, &dq);
1458 }
1459 else {
1460 selectedBackend = policy.policy(servers, &dq);
1461 }
1462
1463 uint16_t cachedResponseSize = dq.size;
1464 uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL;
1465
1466 if (dq.packetCache && !dq.skipCache) {
1467 dq.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO);
1468 }
1469
1470 if (dq.useECS && ((selectedBackend && selectedBackend->useECS) || (!selectedBackend && serverPool->getECS()))) {
1471 // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope
1472 if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) {
1473 if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) {
1474 dq.len = cachedResponseSize;
1475
1476 if (!prepareOutgoingResponse(holders, cs, dq, true)) {
1477 return ProcessQueryResult::Drop;
1478 }
1479
1480 return ProcessQueryResult::SendAnswer;
1481 }
1482
1483 if (!dq.subnet) {
1484 /* there was no existing ECS on the query, enable the zero-scope feature */
1485 dq.useZeroScope = true;
1486 }
1487 }
1488
1489 if (!handleEDNSClientSubnet(dq, &(dq.ednsAdded), &(dq.ecsAdded), g_preserveTrailingData)) {
1490 vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort());
1491 return ProcessQueryResult::Drop;
1492 }
1493 }
1494
1495 if (dq.packetCache && !dq.skipCache) {
1496 if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast<char*>(dq.dh), &cachedResponseSize, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) {
1497 dq.len = cachedResponseSize;
1498
1499 if (!prepareOutgoingResponse(holders, cs, dq, true)) {
1500 return ProcessQueryResult::Drop;
1501 }
1502
1503 return ProcessQueryResult::SendAnswer;
1504 }
1505 ++g_stats.cacheMisses;
1506 }
1507
1508 if(!selectedBackend) {
1509 ++g_stats.noPolicy;
1510
1511 vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toLogString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort());
1512 if (g_servFailOnNoPolicy) {
1513 restoreFlags(dq.dh, dq.origFlags);
1514
1515 dq.dh->rcode = RCode::ServFail;
1516 dq.dh->qr = true;
1517
1518 if (!prepareOutgoingResponse(holders, cs, dq, false)) {
1519 return ProcessQueryResult::Drop;
1520 }
1521 // no response-only statistics counter to update.
1522 return ProcessQueryResult::SendAnswer;
1523 }
1524
1525 return ProcessQueryResult::Drop;
1526 }
1527
1528 if (dq.addXPF && selectedBackend->xpfRRCode != 0) {
1529 addXPF(dq, selectedBackend->xpfRRCode, g_preserveTrailingData);
1530 }
1531
1532 selectedBackend->queries++;
1533 return ProcessQueryResult::PassToBackend;
1534 }
1535 catch(const std::exception& e){
1536 vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.tcp ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what());
1537 }
1538 return ProcessQueryResult::Drop;
1539 }
1540
1541 static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, cmsgbuf_aligned* respCBuf)
1542 {
1543 assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr));
1544 uint16_t queryId = 0;
1545
1546 try {
1547 if (!isUDPQueryAcceptable(cs, holders, msgh, remote, dest)) {
1548 return;
1549 }
1550
1551 /* we need an accurate ("real") value for the response and
1552 to store into the IDS, but not for insertion into the
1553 rings for example */
1554 struct timespec queryRealTime;
1555 gettime(&queryRealTime, true);
1556
1557 std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
1558 auto dnsCryptResponse = checkDNSCryptQuery(cs, query, len, dnsCryptQuery, queryRealTime.tv_sec, false);
1559 if (dnsCryptResponse) {
1560 sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dnsCryptResponse->data()), static_cast<uint16_t>(dnsCryptResponse->size()), 0, dest, remote);
1561 return;
1562 }
1563
1564 struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query);
1565 queryId = ntohs(dh->id);
1566
1567 if (!checkQueryHeaders(dh)) {
1568 return;
1569 }
1570
1571 uint16_t qtype, qclass;
1572 unsigned int consumed = 0;
1573 DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
1574 DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime);
1575 dq.dnsCryptQuery = std::move(dnsCryptQuery);
1576 std::shared_ptr<DownstreamState> ss{nullptr};
1577 auto result = processQuery(dq, cs, holders, ss);
1578
1579 if (result == ProcessQueryResult::Drop) {
1580 return;
1581 }
1582
1583 if (result == ProcessQueryResult::SendAnswer) {
1584 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
1585 if (dq.delayMsec == 0 && responsesVect != nullptr) {
1586 queueResponse(cs, reinterpret_cast<char*>(dq.dh), dq.len, *dq.local, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf);
1587 (*queuedResponses)++;
1588 return;
1589 }
1590 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
1591 /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */
1592 sendUDPResponse(cs.udpFD, reinterpret_cast<char*>(dq.dh), dq.len, dq.delayMsec, dest, *dq.remote);
1593 return;
1594 }
1595
1596 if (result != ProcessQueryResult::PassToBackend || ss == nullptr) {
1597 return;
1598 }
1599
1600 unsigned int idOffset = (ss->idOffset++) % ss->idStates.size();
1601 IDState* ids = &ss->idStates[idOffset];
1602 ids->age = 0;
1603 DOHUnit* du = nullptr;
1604
1605 /* that means that the state was in use, possibly with an allocated
1606 DOHUnit that we will need to handle, but we can't touch it before
1607 confirming that we now own this state */
1608 if (ids->isInUse()) {
1609 du = ids->du;
1610 }
1611
1612 /* we atomically replace the value, we now own this state */
1613 if (!ids->markAsUsed()) {
1614 /* the state was not in use.
1615 we reset 'du' because it might have still been in use when we read it. */
1616 du = nullptr;
1617 ++ss->outstanding;
1618 }
1619 else {
1620 /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need
1621 to handle it because it's about to be overwritten. */
1622 ids->du = nullptr;
1623 ++ss->reuseds;
1624 ++g_stats.downstreamTimeouts;
1625 handleDOHTimeout(du);
1626 }
1627
1628 ids->cs = &cs;
1629 ids->origFD = cs.udpFD;
1630 ids->origID = dh->id;
1631 setIDStateFromDNSQuestion(*ids, dq, std::move(qname));
1632
1633 /* If we couldn't harvest the real dest addr, still
1634 write down the listening addr since it will be useful
1635 (especially if it's not an 'any' one).
1636 We need to keep track of which one it is since we may
1637 want to use the real but not the listening addr to reply.
1638 */
1639 if (dest.sin4.sin_family != 0) {
1640 ids->origDest = dest;
1641 ids->destHarvested = true;
1642 }
1643 else {
1644 ids->origDest = cs.local;
1645 ids->destHarvested = false;
1646 }
1647
1648 dh->id = idOffset;
1649
1650 int fd = pickBackendSocketForSending(ss);
1651 ssize_t ret = udpClientSendRequestToBackend(ss, fd, query, dq.len);
1652
1653 if(ret < 0) {
1654 ++ss->sendErrors;
1655 ++g_stats.downstreamSendErrors;
1656 }
1657
1658 vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toLogString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
1659 }
1660 catch(const std::exception& e){
1661 vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
1662 }
1663 }
1664
#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
/* UDP receiving loop of the frontend 'cs' used when g_udpVectorSize is greater than 1:
   up to g_udpVectorSize datagrams are received with a single recvmmsg() call,
   each one is handed to processUDPQuery(), and the immediate responses that
   were queued are sent with a single sendmmsg() call.
   Never returns. */
static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holders)
{
  struct MMReceiver
  {
    char packet[4096];
    ComboAddress remote;
    ComboAddress dest;
    struct iovec iov;
    /* used by HarvestDestinationAddress */
    cmsgbuf_aligned cbuf;
  };
  const size_t vectSize = g_udpVectorSize;
  /* the actual buffer is larger because:
     - we may have to add EDNS and/or ECS
     - we use it for self-generated responses (from rule or cache)
     but we only accept incoming payloads up to that size
  */
  static_assert(s_udpIncomingBufferSize <= sizeof(MMReceiver::packet), "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)");

  auto recvData = std::unique_ptr<MMReceiver[]>(new MMReceiver[vectSize]);
  auto msgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);
  auto outMsgVec = std::unique_ptr<struct mmsghdr[]>(new struct mmsghdr[vectSize]);

  /* initialize the structures needed to receive our messages */
  for (size_t idx = 0; idx < vectSize; idx++) {
    recvData[idx].remote.sin4.sin_family = cs->local.sin4.sin_family;
    fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, &recvData[idx].cbuf, sizeof(recvData[idx].cbuf), recvData[idx].packet, s_udpIncomingBufferSize, &recvData[idx].remote);
  }

  /* go now */
  for (;;) {

    /* reset the IO vector, since it's also used to send the vector of responses
       to avoid having to copy the data around.
       The length is the maximum size of an incoming payload, not the size of
       the whole buffer, so that the limit set during the initialization above
       is still enforced: larger messages get flagged with MSG_TRUNC. */
    for (size_t idx = 0; idx < vectSize; idx++) {
      recvData[idx].iov.iov_base = recvData[idx].packet;
      recvData[idx].iov.iov_len = s_udpIncomingBufferSize;
    }

    /* block until we have at least one message ready, but return
       as many as possible to save the syscall costs */
    int msgsGot = recvmmsg(cs->udpFD, msgVec.get(), vectSize, MSG_WAITFORONE | MSG_TRUNC, nullptr);

    if (msgsGot <= 0) {
      vinfolog("Getting UDP messages via recvmmsg() failed with: %s", strerror(errno));
      continue;
    }

    unsigned int msgsToSend = 0;

    /* process the received messages */
    for (int msgIdx = 0; msgIdx < msgsGot; msgIdx++) {
      const struct msghdr* msgh = &msgVec[msgIdx].msg_hdr;
      unsigned int got = msgVec[msgIdx].msg_len;
      const ComboAddress& remote = recvData[msgIdx].remote;

      if (static_cast<size_t>(got) < sizeof(struct dnsheader)) {
        /* too short to hold a DNS header */
        ++g_stats.nonCompliantQueries;
        continue;
      }

      /* the response, if any, is queued into outMsgVec, reusing the IO vector and control buffer of this slot */
      processUDPQuery(*cs, holders, msgh, remote, recvData[msgIdx].dest, recvData[msgIdx].packet, static_cast<uint16_t>(got), sizeof(recvData[msgIdx].packet), outMsgVec.get(), &msgsToSend, &recvData[msgIdx].iov, &recvData[msgIdx].cbuf);

    }

    /* immediate (not delayed or sent to a backend) responses (mostly from a rule, dynamic block
       or the cache) can be sent in batch too */

    if (msgsToSend > 0 && msgsToSend <= static_cast<unsigned int>(msgsGot)) {
      int sent = sendmmsg(cs->udpFD, outMsgVec.get(), msgsToSend, 0);

      if (sent < 0 || static_cast<unsigned int>(sent) != msgsToSend) {
        vinfolog("Error sending responses with sendmmsg() (%d on %u): %s", sent, msgsToSend, strerror(errno));
      }
    }

  }
}
#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
1745
1746 // listens to incoming queries, sends out to downstream servers, noting the intended return path
1747 static void udpClientThread(ClientState* cs)
1748 try
1749 {
1750 setThreadName("dnsdist/udpClie");
1751 LocalHolders holders;
1752
1753 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
1754 if (g_udpVectorSize > 1) {
1755 MultipleMessagesUDPClientThread(cs, holders);
1756
1757 }
1758 else
1759 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
1760 {
1761 char packet[4096];
1762 /* the actual buffer is larger because:
1763 - we may have to add EDNS and/or ECS
1764 - we use it for self-generated responses (from rule or cache)
1765 but we only accept incoming payloads up to that size
1766 */
1767 static_assert(s_udpIncomingBufferSize <= sizeof(packet), "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)");
1768 struct msghdr msgh;
1769 struct iovec iov;
1770 /* used by HarvestDestinationAddress */
1771 cmsgbuf_aligned cbuf;
1772
1773 ComboAddress remote;
1774 ComboAddress dest;
1775 remote.sin4.sin_family = cs->local.sin4.sin_family;
1776 fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), packet, sizeof(packet), &remote);
1777
1778 for(;;) {
1779 ssize_t got = recvmsg(cs->udpFD, &msgh, 0);
1780
1781 if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) {
1782 ++g_stats.nonCompliantQueries;
1783 continue;
1784 }
1785
1786 processUDPQuery(*cs, holders, &msgh, remote, dest, packet, static_cast<uint16_t>(got), s_udpIncomingBufferSize, nullptr, nullptr, nullptr, nullptr);
1787 }
1788 }
1789 }
1790 catch(const std::exception &e)
1791 {
1792 errlog("UDP client thread died because of exception: %s", e.what());
1793 }
1794 catch(const PDNSException &e)
1795 {
1796 errlog("UDP client thread died because of PowerDNS exception: %s", e.reason);
1797 }
1798 catch(...)
1799 {
1800 errlog("UDP client thread died because of an exception: %s", "unknown");
1801 }
1802
/* Returns a random 16-bit value to be used as a DNS query ID, taken from
   randombytes_random() when built with libsodium and from random() otherwise. */
uint16_t getRandomDNSID()
{
#ifdef HAVE_LIBSODIUM
  const auto source = randombytes_random();
#else
  const auto source = random();
#endif
  /* only keep the 16 bits of an ID */
  return (source % 65536);
}
1811
1812 static bool upCheck(const shared_ptr<DownstreamState>& ds)
1813 try
1814 {
1815 DNSName checkName = ds->checkName;
1816 uint16_t checkType = ds->checkType.getCode();
1817 uint16_t checkClass = ds->checkClass;
1818 dnsheader checkHeader;
1819 memset(&checkHeader, 0, sizeof(checkHeader));
1820
1821 checkHeader.qdcount = htons(1);
1822 checkHeader.id = getRandomDNSID();
1823
1824 checkHeader.rd = true;
1825 if (ds->setCD) {
1826 checkHeader.cd = true;
1827 }
1828
1829 if (ds->checkFunction) {
1830 std::lock_guard<std::mutex> lock(g_luamutex);
1831 auto ret = ds->checkFunction(checkName, checkType, checkClass, &checkHeader);
1832 checkName = std::get<0>(ret);
1833 checkType = std::get<1>(ret);
1834 checkClass = std::get<2>(ret);
1835 }
1836
1837 vector<uint8_t> packet;
1838 DNSPacketWriter dpw(packet, checkName, checkType, checkClass);
1839 dnsheader * requestHeader = dpw.getHeader();
1840 *requestHeader = checkHeader;
1841
1842 Socket sock(ds->remote.sin4.sin_family, SOCK_DGRAM);
1843 sock.setNonBlocking();
1844 if (!IsAnyAddress(ds->sourceAddr)) {
1845 sock.setReuseAddr();
1846 sock.bind(ds->sourceAddr);
1847 }
1848 sock.connect(ds->remote);
1849 ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), reinterpret_cast<char*>(&packet[0]), packet.size(), true);
1850 if (sent < 0) {
1851 int ret = errno;
1852 if (g_verboseHealthChecks)
1853 infolog("Error while sending a health check query to backend %s: %d", ds->getNameWithAddr(), ret);
1854 return false;
1855 }
1856
1857 int ret = waitForRWData(sock.getHandle(), true, /* ms to seconds */ ds->checkTimeout / 1000, /* remaining ms to us */ (ds->checkTimeout % 1000) * 1000);
1858 if(ret < 0 || !ret) { // error, timeout, both are down!
1859 if (ret < 0) {
1860 ret = errno;
1861 if (g_verboseHealthChecks)
1862 infolog("Error while waiting for the health check response from backend %s: %d", ds->getNameWithAddr(), ret);
1863 }
1864 else {
1865 if (g_verboseHealthChecks)
1866 infolog("Timeout while waiting for the health check response from backend %s", ds->getNameWithAddr());
1867 }
1868 return false;
1869 }
1870
1871 string reply;
1872 ComboAddress from;
1873 sock.recvFrom(reply, from);
1874
1875 /* we are using a connected socket but hey.. */
1876 if (from != ds->remote) {
1877 if (g_verboseHealthChecks)
1878 infolog("Invalid health check response received from %s, expecting one from %s", from.toStringWithPort(), ds->remote.toStringWithPort());
1879 return false;
1880 }
1881
1882 const dnsheader * responseHeader = reinterpret_cast<const dnsheader *>(reply.c_str());
1883
1884 if (reply.size() < sizeof(*responseHeader)) {
1885 if (g_verboseHealthChecks)
1886 infolog("Invalid health check response of size %d from backend %s, expecting at least %d", reply.size(), ds->getNameWithAddr(), sizeof(*responseHeader));
1887 return false;
1888 }
1889
1890 if (responseHeader->id != requestHeader->id) {
1891 if (g_verboseHealthChecks)
1892 infolog("Invalid health check response id %d from backend %s, expecting %d", responseHeader->id, ds->getNameWithAddr(), requestHeader->id);
1893 return false;
1894 }
1895
1896 if (!responseHeader->qr) {
1897 if (g_verboseHealthChecks)
1898 infolog("Invalid health check response from backend %s, expecting QR to be set", ds->getNameWithAddr());
1899 return false;
1900 }
1901
1902 if (responseHeader->rcode == RCode::ServFail) {
1903 if (g_verboseHealthChecks)
1904 infolog("Backend %s responded to health check with ServFail", ds->getNameWithAddr());
1905 return false;
1906 }
1907
1908 if (ds->mustResolve && (responseHeader->rcode == RCode::NXDomain || responseHeader->rcode == RCode::Refused)) {
1909 if (g_verboseHealthChecks)
1910 infolog("Backend %s responded to health check with %s while mustResolve is set", ds->getNameWithAddr(), responseHeader->rcode == RCode::NXDomain ? "NXDomain" : "Refused");
1911 return false;
1912 }
1913
1914 uint16_t receivedType;
1915 uint16_t receivedClass;
1916 DNSName receivedName(reply.c_str(), reply.size(), sizeof(dnsheader), false, &receivedType, &receivedClass);
1917
1918 if (receivedName != checkName || receivedType != checkType || receivedClass != checkClass) {
1919 if (g_verboseHealthChecks)
1920 infolog("Backend %s responded to health check with an invalid qname (%s vs %s), qtype (%s vs %s) or qclass (%d vs %d)", ds->getNameWithAddr(), receivedName.toLogString(), checkName.toLogString(), QType(receivedType).getName(), QType(checkType).getName(), receivedClass, checkClass);
1921 return false;
1922 }
1923
1924 return true;
1925 }
1926 catch(const std::exception& e)
1927 {
1928 if (g_verboseHealthChecks)
1929 infolog("Error checking the health of backend %s: %s", ds->getNameWithAddr(), e.what());
1930 return false;
1931 }
1932 catch(...)
1933 {
1934 if (g_verboseHealthChecks)
1935 infolog("Unknown exception while checking the health of backend %s", ds->getNameWithAddr());
1936 return false;
1937 }
1938
/* Size passed to the TCPClientCollection at startup; also used when
   estimating the number of file descriptors the process may need. */
uint64_t g_maxTCPClientThreads = 10;
/* Number of maintenance loop iterations (one per second) between two
   passes over the packet caches. */
std::atomic<uint16_t> g_cacheCleaningDelay(60);
/* Percentage used by the maintenance thread to compute how many entries
   a packet cache may retain when expired entries are purged. */
std::atomic<uint16_t> g_cacheCleaningPercentage(100);
1942
1943 void maintThread()
1944 {
1945 setThreadName("dnsdist/main");
1946 int interval = 1;
1947 size_t counter = 0;
1948 int32_t secondsToWaitLog = 0;
1949
1950 for(;;) {
1951 sleep(interval);
1952
1953 {
1954 std::lock_guard<std::mutex> lock(g_luamutex);
1955 auto f = g_lua.readVariable<boost::optional<std::function<void()> > >("maintenance");
1956 if(f) {
1957 try {
1958 (*f)();
1959 secondsToWaitLog = 0;
1960 }
1961 catch(std::exception &e) {
1962 if (secondsToWaitLog <= 0) {
1963 infolog("Error during execution of maintenance function: %s", e.what());
1964 secondsToWaitLog = 61;
1965 }
1966 secondsToWaitLog -= interval;
1967 }
1968 }
1969 }
1970
1971 counter++;
1972 if (counter >= g_cacheCleaningDelay) {
1973 /* keep track, for each cache, of whether we should keep
1974 expired entries */
1975 std::map<std::shared_ptr<DNSDistPacketCache>, bool> caches;
1976
1977 /* gather all caches actually used by at least one pool, and see
1978 if something prevents us from cleaning the expired entries */
1979 auto localPools = g_pools.getLocal();
1980 for (const auto& entry : *localPools) {
1981 auto& pool = entry.second;
1982
1983 auto packetCache = pool->packetCache;
1984 if (!packetCache) {
1985 continue;
1986 }
1987
1988 auto pair = caches.insert({packetCache, false});
1989 auto& iter = pair.first;
1990 /* if we need to keep stale data for this cache (ie, not clear
1991 expired entries when at least one pool using this cache
1992 has all its backends down) */
1993 if (packetCache->keepStaleData() && iter->second == false) {
1994 /* so far all pools had at least one backend up */
1995 if (pool->countServers(true) == 0) {
1996 iter->second = true;
1997 }
1998 }
1999 }
2000
2001 for (auto pair : caches) {
2002 /* shall we keep expired entries ? */
2003 if (pair.second == true) {
2004 continue;
2005 }
2006 auto& packetCache = pair.first;
2007 size_t upTo = (packetCache->getMaxEntries()* (100 - g_cacheCleaningPercentage)) / 100;
2008 packetCache->purgeExpired(upTo);
2009 }
2010 counter = 0;
2011 }
2012
2013 // ponder pruning g_dynblocks of expired entries here
2014 }
2015 }
2016
2017 static void secPollThread()
2018 {
2019 setThreadName("dnsdist/secpoll");
2020
2021 for (;;) {
2022 try {
2023 doSecPoll(g_secPollSuffix);
2024 }
2025 catch(...) {
2026 }
2027 sleep(g_secPollInterval);
2028 }
2029 }
2030
2031 static void healthChecksThread()
2032 {
2033 setThreadName("dnsdist/healthC");
2034
2035 int interval = 1;
2036
2037 for(;;) {
2038 sleep(interval);
2039
2040 if(g_tcpclientthreads->getQueuedCount() > 1 && !g_tcpclientthreads->hasReachedMaxThreads())
2041 g_tcpclientthreads->addTCPClientThread();
2042
2043 auto states = g_dstates.getLocal(); // this points to the actual shared_ptrs!
2044 for(auto& dss : *states) {
2045 if(++dss->lastCheck < dss->checkInterval)
2046 continue;
2047 dss->lastCheck = 0;
2048 if(dss->availability==DownstreamState::Availability::Auto) {
2049 bool newState=upCheck(dss);
2050 if (newState) {
2051 /* check succeeded */
2052 dss->currentCheckFailures = 0;
2053
2054 if (!dss->upStatus) {
2055 /* we were marked as down */
2056 dss->consecutiveSuccessfulChecks++;
2057 if (dss->consecutiveSuccessfulChecks < dss->minRiseSuccesses) {
2058 /* if we need more than one successful check to rise
2059 and we didn't reach the threshold yet,
2060 let's stay down */
2061 newState = false;
2062 }
2063 }
2064 }
2065 else {
2066 /* check failed */
2067 dss->consecutiveSuccessfulChecks = 0;
2068
2069 if (dss->upStatus) {
2070 /* we are currently up */
2071 dss->currentCheckFailures++;
2072 if (dss->currentCheckFailures < dss->maxCheckFailures) {
2073 /* we need more than one failure to be marked as down,
2074 and we did not reach the threshold yet, let's stay down */
2075 newState = true;
2076 }
2077 }
2078 }
2079
2080 if(newState != dss->upStatus) {
2081 warnlog("Marking downstream %s as '%s'", dss->getNameWithAddr(), newState ? "up" : "down");
2082
2083 if (newState && !dss->connected) {
2084 newState = dss->reconnect();
2085
2086 if (dss->connected && !dss->threadStarted.test_and_set()) {
2087 dss->tid = thread(responderThread, dss);
2088 }
2089 }
2090
2091 dss->upStatus = newState;
2092 dss->currentCheckFailures = 0;
2093 dss->consecutiveSuccessfulChecks = 0;
2094 if (g_snmpAgent && g_snmpTrapsEnabled) {
2095 g_snmpAgent->sendBackendStatusChangeTrap(dss);
2096 }
2097 }
2098 }
2099
2100 auto delta = dss->sw.udiffAndSet()/1000000.0;
2101 dss->queryLoad = 1.0*(dss->queries.load() - dss->prev.queries.load())/delta;
2102 dss->dropRate = 1.0*(dss->reuseds.load() - dss->prev.reuseds.load())/delta;
2103 dss->prev.queries.store(dss->queries.load());
2104 dss->prev.reuseds.store(dss->reuseds.load());
2105
2106 for(IDState& ids : dss->idStates) { // timeouts
2107 int64_t usageIndicator = ids.usageIndicator;
2108 if(IDState::isInUse(usageIndicator) && ids.age++ > g_udpTimeout) {
2109 /* We mark the state as unused as soon as possible
2110 to limit the risk of racing with the
2111 responder thread.
2112 */
2113 auto oldDU = ids.du;
2114
2115 if (!ids.tryMarkUnused(usageIndicator)) {
2116 /* this state has been altered in the meantime,
2117 don't go anywhere near it */
2118 continue;
2119 }
2120 ids.du = nullptr;
2121 handleDOHTimeout(oldDU);
2122 ids.age = 0;
2123 dss->reuseds++;
2124 --dss->outstanding;
2125 ++g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout
2126 vinfolog("Had a downstream timeout from %s (%s) for query for %s|%s from %s",
2127 dss->remote.toStringWithPort(), dss->name,
2128 ids.qname.toLogString(), QType(ids.qtype).getName(), ids.origRemote.toStringWithPort());
2129
2130 struct timespec ts;
2131 gettime(&ts);
2132
2133 struct dnsheader fake;
2134 memset(&fake, 0, sizeof(fake));
2135 fake.id = ids.origID;
2136
2137 g_rings.insertResponse(ts, ids.origRemote, ids.qname, ids.qtype, std::numeric_limits<unsigned int>::max(), 0, fake, dss->remote);
2138 }
2139 }
2140 }
2141 }
2142 }
2143
2144 static void bindAny(int af, int sock)
2145 {
2146 __attribute__((unused)) int one = 1;
2147
2148 #ifdef IP_FREEBIND
2149 if (setsockopt(sock, IPPROTO_IP, IP_FREEBIND, &one, sizeof(one)) < 0)
2150 warnlog("Warning: IP_FREEBIND setsockopt failed: %s", strerror(errno));
2151 #endif
2152
2153 #ifdef IP_BINDANY
2154 if (af == AF_INET)
2155 if (setsockopt(sock, IPPROTO_IP, IP_BINDANY, &one, sizeof(one)) < 0)
2156 warnlog("Warning: IP_BINDANY setsockopt failed: %s", strerror(errno));
2157 #endif
2158 #ifdef IPV6_BINDANY
2159 if (af == AF_INET6)
2160 if (setsockopt(sock, IPPROTO_IPV6, IPV6_BINDANY, &one, sizeof(one)) < 0)
2161 warnlog("Warning: IPV6_BINDANY setsockopt failed: %s", strerror(errno));
2162 #endif
2163 #ifdef SO_BINDANY
2164 if (setsockopt(sock, SOL_SOCKET, SO_BINDANY, &one, sizeof(one)) < 0)
2165 warnlog("Warning: SO_BINDANY setsockopt failed: %s", strerror(errno));
2166 #endif
2167 }
2168
2169 static void dropGroupPrivs(gid_t gid)
2170 {
2171 if (gid) {
2172 if (setgid(gid) == 0) {
2173 if (setgroups(0, NULL) < 0) {
2174 warnlog("Warning: Unable to drop supplementary gids: %s", strerror(errno));
2175 }
2176 }
2177 else {
2178 warnlog("Warning: Unable to set group ID to %d: %s", gid, strerror(errno));
2179 }
2180 }
2181 }
2182
2183 static void dropUserPrivs(uid_t uid)
2184 {
2185 if(uid) {
2186 if(setuid(uid) < 0) {
2187 warnlog("Warning: Unable to set user ID to %d: %s", uid, strerror(errno));
2188 }
2189 }
2190 }
2191
2192 static void checkFileDescriptorsLimits(size_t udpBindsCount, size_t tcpBindsCount)
2193 {
2194 /* stdin, stdout, stderr */
2195 size_t requiredFDsCount = 3;
2196 auto backends = g_dstates.getLocal();
2197 /* UDP sockets to backends */
2198 size_t backendUDPSocketsCount = 0;
2199 for (const auto& backend : *backends) {
2200 backendUDPSocketsCount += backend->sockets.size();
2201 }
2202 requiredFDsCount += backendUDPSocketsCount;
2203 /* TCP sockets to backends */
2204 requiredFDsCount += (backends->size() * g_maxTCPClientThreads);
2205 /* listening sockets */
2206 requiredFDsCount += udpBindsCount;
2207 requiredFDsCount += tcpBindsCount;
2208 /* max TCP connections currently served */
2209 requiredFDsCount += g_maxTCPClientThreads;
2210 /* max pipes for communicating between TCP acceptors and client threads */
2211 requiredFDsCount += (g_maxTCPClientThreads * 2);
2212 /* max TCP queued connections */
2213 requiredFDsCount += g_maxTCPQueuedConnections;
2214 /* DelayPipe pipe */
2215 requiredFDsCount += 2;
2216 /* syslog socket */
2217 requiredFDsCount++;
2218 /* webserver main socket */
2219 requiredFDsCount++;
2220 /* console main socket */
2221 requiredFDsCount++;
2222 /* carbon export */
2223 requiredFDsCount++;
2224 /* history file */
2225 requiredFDsCount++;
2226 struct rlimit rl;
2227 getrlimit(RLIMIT_NOFILE, &rl);
2228 if (rl.rlim_cur <= requiredFDsCount) {
2229 warnlog("Warning, this configuration can use more than %d file descriptors, web server and console connections not included, and the current limit is %d.", std::to_string(requiredFDsCount), std::to_string(rl.rlim_cur));
2230 #ifdef HAVE_SYSTEMD
2231 warnlog("You can increase this value by using LimitNOFILE= in the systemd unit file or ulimit.");
2232 #else
2233 warnlog("You can increase this value by using ulimit.");
2234 #endif
2235 }
2236 }
2237
2238 static void setUpLocalBind(std::unique_ptr<ClientState>& cs)
2239 {
2240 /* skip some warnings if there is an identical UDP context */
2241 bool warn = cs->tcp == false || cs->tlsFrontend != nullptr || cs->dohFrontend != nullptr;
2242 int& fd = cs->tcp == false ? cs->udpFD : cs->tcpFD;
2243 (void) warn;
2244
2245 fd = SSocket(cs->local.sin4.sin_family, cs->tcp == false ? SOCK_DGRAM : SOCK_STREAM, 0);
2246
2247 if (cs->tcp) {
2248 SSetsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 1);
2249 #ifdef TCP_DEFER_ACCEPT
2250 SSetsockopt(fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, 1);
2251 #endif
2252 if (cs->fastOpenQueueSize > 0) {
2253 #ifdef TCP_FASTOPEN
2254 SSetsockopt(fd, IPPROTO_TCP, TCP_FASTOPEN, cs->fastOpenQueueSize);
2255 #else
2256 if (warn) {
2257 warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
2258 }
2259 #endif
2260 }
2261 }
2262
2263 if(cs->local.sin4.sin_family == AF_INET6) {
2264 SSetsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, 1);
2265 }
2266
2267 bindAny(cs->local.sin4.sin_family, fd);
2268
2269 if(!cs->tcp && IsAnyAddress(cs->local)) {
2270 int one=1;
2271 setsockopt(fd, IPPROTO_IP, GEN_IP_PKTINFO, &one, sizeof(one)); // linux supports this, so why not - might fail on other systems
2272 #ifdef IPV6_RECVPKTINFO
2273 setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one));
2274 #endif
2275 }
2276
2277 if (cs->reuseport) {
2278 #ifdef SO_REUSEPORT
2279 SSetsockopt(fd, SOL_SOCKET, SO_REUSEPORT, 1);
2280 #else
2281 if (warn) {
2282 /* no need to warn again if configured but support is not available, we already did for UDP */
2283 warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", cs->local.toStringWithPort());
2284 }
2285 #endif
2286 }
2287
2288 if (!cs->tcp) {
2289 if (cs->local.isIPv4()) {
2290 try {
2291 setSocketIgnorePMTU(cs->udpFD);
2292 }
2293 catch(const std::exception& e) {
2294 warnlog("Failed to set IP_MTU_DISCOVER on UDP server socket for local address '%s': %s", cs->local.toStringWithPort(), e.what());
2295 }
2296 }
2297 }
2298
2299 const std::string& itf = cs->interface;
2300 if (!itf.empty()) {
2301 #ifdef SO_BINDTODEVICE
2302 int res = setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length());
2303 if (res != 0) {
2304 warnlog("Error setting up the interface on local address '%s': %s", cs->local.toStringWithPort(), strerror(errno));
2305 }
2306 #else
2307 if (warn) {
2308 warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", cs->local.toStringWithPort());
2309 }
2310 #endif
2311 }
2312
2313 #ifdef HAVE_EBPF
2314 if (g_defaultBPFFilter) {
2315 cs->attachFilter(g_defaultBPFFilter);
2316 vinfolog("Attaching default BPF Filter to %s frontend %s", (!cs->tcp ? "UDP" : "TCP"), cs->local.toStringWithPort());
2317 }
2318 #endif /* HAVE_EBPF */
2319
2320 if (cs->tlsFrontend != nullptr) {
2321 if (!cs->tlsFrontend->setupTLS()) {
2322 errlog("Error while setting up TLS on local address '%s', exiting", cs->local.toStringWithPort());
2323 _exit(EXIT_FAILURE);
2324 }
2325 }
2326
2327 if (cs->dohFrontend != nullptr) {
2328 cs->dohFrontend->setup();
2329 }
2330
2331 SBind(fd, cs->local);
2332
2333 if (cs->tcp) {
2334 SListen(cs->tcpFD, SOMAXCONN);
2335 if (cs->tlsFrontend != nullptr) {
2336 warnlog("Listening on %s for TLS", cs->local.toStringWithPort());
2337 }
2338 else if (cs->dohFrontend != nullptr) {
2339 warnlog("Listening on %s for DoH", cs->local.toStringWithPort());
2340 }
2341 else if (cs->dnscryptCtx != nullptr) {
2342 warnlog("Listening on %s for DNSCrypt", cs->local.toStringWithPort());
2343 }
2344 else {
2345 warnlog("Listening on %s", cs->local.toStringWithPort());
2346 }
2347 }
2348
2349 cs->ready = true;
2350 }
2351
/* Settings gathered from the command line by main() */
struct
{
  /* addresses passed via -l / --local */
  std::vector<std::string> locals;
  /* positional arguments, when not running as a client */
  std::vector<std::string> remotes;
  /* --check-config */
  bool checkConfig = false;
  /* -c / --client */
  bool beClient = false;
  /* --supervised */
  bool beSupervised = false;
  /* -e / --execute */
  std::string command;
  /* -C / --config */
  std::string config;
  /* -u / --uid */
  std::string uid;
  /* -g / --gid */
  std::string gid;
} g_cmdLine;
2364
/* Set to true by main() once the frontends have been determined,
   right before the listening sockets are set up. */
std::atomic<bool> g_configurationDone(false);
2366
/* Print the command-line help to the standard output.
   The -k / --setkey entry is only listed when built with libsodium. */
static void usage()
{
  std::cout << std::endl;
  std::cout << "Syntax: dnsdist [-C,--config file] [-c,--client [IP[:PORT]]]\n"
               "[-e,--execute cmd] [-h,--help] [-l,--local addr]\n"
               "[-v,--verbose] [--check-config] [--version]\n"
               "\n"
               "-a,--acl netmask Add this netmask to the ACL\n"
               "-C,--config file Load configuration from 'file'\n"
               "-c,--client Operate as a client, connect to dnsdist. This reads\n"
               " controlSocket from your configuration file, but also\n"
               " accepts an IP:PORT argument\n";
#ifdef HAVE_LIBSODIUM
  std::cout << "-k,--setkey KEY Use KEY for encrypted communication to dnsdist. This\n"
               " is similar to setting setKey in the configuration file.\n"
               " NOTE: this will leak this key in your shell's history\n"
               " and in the systems running process list.\n";
#endif
  std::cout << "--check-config Validate the configuration file and exit. The exit-code\n"
               " reflects the validation, 0 is OK, 1 means an error.\n"
               " Any errors are printed as well.\n"
               "-e,--execute cmd Connect to dnsdist and execute 'cmd'\n"
               "-g,--gid gid Change the process group ID after binding sockets\n"
               "-h,--help Display this helpful message\n"
               "-l,--local address Listen on this local address\n"
               "--supervised Don't open a console, I'm supervised\n"
               " (use with e.g. systemd and daemontools)\n"
               "--disable-syslog Don't log to syslog, only to stdout\n"
               " (use with e.g. systemd)\n"
               "-u,--uid uid Change the process user ID after binding sockets\n"
               "-v,--verbose Enable verbose mode\n"
               "-V,--version Show dnsdist version information and exit\n";
}
2400
2401 int main(int argc, char** argv)
2402 try
2403 {
2404 size_t udpBindsCount = 0;
2405 size_t tcpBindsCount = 0;
2406 rl_attempted_completion_function = my_completion;
2407 rl_completion_append_character = 0;
2408
2409 signal(SIGPIPE, SIG_IGN);
2410 signal(SIGCHLD, SIG_IGN);
2411 openlog("dnsdist", LOG_PID|LOG_NDELAY, LOG_DAEMON);
2412
2413 #ifdef HAVE_LIBSODIUM
2414 if (sodium_init() == -1) {
2415 cerr<<"Unable to initialize crypto library"<<endl;
2416 exit(EXIT_FAILURE);
2417 }
2418 g_hashperturb=randombytes_uniform(0xffffffff);
2419 srandom(randombytes_uniform(0xffffffff));
2420 #else
2421 {
2422 struct timeval tv;
2423 gettimeofday(&tv, 0);
2424 srandom(tv.tv_sec ^ tv.tv_usec ^ getpid());
2425 g_hashperturb=random();
2426 }
2427
2428 #endif
2429 ComboAddress clientAddress = ComboAddress();
2430 g_cmdLine.config=SYSCONFDIR "/dnsdist.conf";
2431 struct option longopts[]={
2432 {"acl", required_argument, 0, 'a'},
2433 {"check-config", no_argument, 0, 1},
2434 {"client", no_argument, 0, 'c'},
2435 {"config", required_argument, 0, 'C'},
2436 {"disable-syslog", no_argument, 0, 2},
2437 {"execute", required_argument, 0, 'e'},
2438 {"gid", required_argument, 0, 'g'},
2439 {"help", no_argument, 0, 'h'},
2440 {"local", required_argument, 0, 'l'},
2441 {"setkey", required_argument, 0, 'k'},
2442 {"supervised", no_argument, 0, 3},
2443 {"uid", required_argument, 0, 'u'},
2444 {"verbose", no_argument, 0, 'v'},
2445 {"version", no_argument, 0, 'V'},
2446 {0,0,0,0}
2447 };
2448 int longindex=0;
2449 string optstring;
2450 for(;;) {
2451 int c=getopt_long(argc, argv, "a:cC:e:g:hk:l:u:vV", longopts, &longindex);
2452 if(c==-1)
2453 break;
2454 switch(c) {
2455 case 1:
2456 g_cmdLine.checkConfig=true;
2457 break;
2458 case 2:
2459 g_syslog=false;
2460 break;
2461 case 3:
2462 g_cmdLine.beSupervised=true;
2463 break;
2464 case 'C':
2465 g_cmdLine.config=optarg;
2466 break;
2467 case 'c':
2468 g_cmdLine.beClient=true;
2469 break;
2470 case 'e':
2471 g_cmdLine.command=optarg;
2472 break;
2473 case 'g':
2474 g_cmdLine.gid=optarg;
2475 break;
2476 case 'h':
2477 cout<<"dnsdist "<<VERSION<<endl;
2478 usage();
2479 cout<<"\n";
2480 exit(EXIT_SUCCESS);
2481 break;
2482 case 'a':
2483 optstring=optarg;
2484 g_ACL.modify([optstring](NetmaskGroup& nmg) { nmg.addMask(optstring); });
2485 break;
2486 case 'k':
2487 #ifdef HAVE_LIBSODIUM
2488 if (B64Decode(string(optarg), g_consoleKey) < 0) {
2489 cerr<<"Unable to decode key '"<<optarg<<"'."<<endl;
2490 exit(EXIT_FAILURE);
2491 }
2492 #else
2493 cerr<<"dnsdist has been built without libsodium, -k/--setkey is unsupported."<<endl;
2494 exit(EXIT_FAILURE);
2495 #endif
2496 break;
2497 case 'l':
2498 g_cmdLine.locals.push_back(trim_copy(string(optarg)));
2499 break;
2500 case 'u':
2501 g_cmdLine.uid=optarg;
2502 break;
2503 case 'v':
2504 g_verbose=true;
2505 break;
2506 case 'V':
2507 #ifdef LUAJIT_VERSION
2508 cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<" ["<<LUAJIT_VERSION<<"])"<<endl;
2509 #else
2510 cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<")"<<endl;
2511 #endif
2512 cout<<"Enabled features: ";
2513 #ifdef HAVE_DNS_OVER_TLS
2514 cout<<"dns-over-tls(";
2515 #ifdef HAVE_GNUTLS
2516 cout<<"gnutls";
2517 #ifdef HAVE_LIBSSL
2518 cout<<" ";
2519 #endif
2520 #endif
2521 #ifdef HAVE_LIBSSL
2522 cout<<"openssl";
2523 #endif
2524 cout<<") ";
2525 #endif
2526 #ifdef HAVE_DNS_OVER_HTTPS
2527 cout<<"dns-over-https(DOH) ";
2528 #endif
2529 #ifdef HAVE_DNSCRYPT
2530 cout<<"dnscrypt ";
2531 #endif
2532 #ifdef HAVE_EBPF
2533 cout<<"ebpf ";
2534 #endif
2535 #ifdef HAVE_FSTRM
2536 cout<<"fstrm ";
2537 #endif
2538 #ifdef HAVE_LIBCRYPTO
2539 cout<<"ipcipher ";
2540 #endif
2541 #ifdef HAVE_LIBSODIUM
2542 cout<<"libsodium ";
2543 #endif
2544 #ifdef HAVE_PROTOBUF
2545 cout<<"protobuf ";
2546 #endif
2547 #ifdef HAVE_RE2
2548 cout<<"re2 ";
2549 #endif
2550 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
2551 cout<<"recvmmsg/sendmmsg ";
2552 #endif
2553 #ifdef HAVE_NET_SNMP
2554 cout<<"snmp ";
2555 #endif
2556 #ifdef HAVE_SYSTEMD
2557 cout<<"systemd";
2558 #endif
2559 cout<<endl;
2560 exit(EXIT_SUCCESS);
2561 break;
2562 case '?':
2563 //getopt_long printed an error message.
2564 usage();
2565 exit(EXIT_FAILURE);
2566 break;
2567 }
2568 }
2569
2570 argc-=optind;
2571 argv+=optind;
2572 for(auto p = argv; *p; ++p) {
2573 if(g_cmdLine.beClient) {
2574 clientAddress = ComboAddress(*p, 5199);
2575 } else {
2576 g_cmdLine.remotes.push_back(*p);
2577 }
2578 }
2579
2580 ServerPolicy leastOutstandingPol{"leastOutstanding", leastOutstanding, false};
2581
2582 g_policy.setState(leastOutstandingPol);
2583 if(g_cmdLine.beClient || !g_cmdLine.command.empty()) {
2584 setupLua(true, g_cmdLine.config);
2585 if (clientAddress != ComboAddress())
2586 g_serverControl = clientAddress;
2587 doClient(g_serverControl, g_cmdLine.command);
2588 _exit(EXIT_SUCCESS);
2589 }
2590
2591 auto acl = g_ACL.getCopy();
2592 if(acl.empty()) {
2593 for(auto& addr : {"127.0.0.0/8", "10.0.0.0/8", "100.64.0.0/10", "169.254.0.0/16", "192.168.0.0/16", "172.16.0.0/12", "::1/128", "fc00::/7", "fe80::/10"})
2594 acl.addMask(addr);
2595 g_ACL.setState(acl);
2596 }
2597
2598 auto consoleACL = g_consoleACL.getCopy();
2599 for (const auto& mask : { "127.0.0.1/8", "::1/128" }) {
2600 consoleACL.addMask(mask);
2601 }
2602 g_consoleACL.setState(consoleACL);
2603
2604 if (g_cmdLine.checkConfig) {
2605 setupLua(true, g_cmdLine.config);
2606 // No exception was thrown
2607 infolog("Configuration '%s' OK!", g_cmdLine.config);
2608 _exit(EXIT_SUCCESS);
2609 }
2610
2611 auto todo=setupLua(false, g_cmdLine.config);
2612
2613 auto localPools = g_pools.getCopy();
2614 {
2615 bool precompute = false;
2616 if (g_policy.getLocal()->name == "chashed") {
2617 precompute = true;
2618 } else {
2619 for (const auto& entry: localPools) {
2620 if (entry.second->policy != nullptr && entry.second->policy->name == "chashed") {
2621 precompute = true;
2622 break ;
2623 }
2624 }
2625 }
2626 if (precompute) {
2627 vinfolog("Pre-computing hashes for consistent hash load-balancing policy");
2628 // pre compute hashes
2629 auto backends = g_dstates.getLocal();
2630 for (auto& backend: *backends) {
2631 backend->hash();
2632 }
2633 }
2634 }
2635
2636 if (!g_cmdLine.locals.empty()) {
2637 for (auto it = g_frontends.begin(); it != g_frontends.end(); ) {
2638 /* DoH, DoT and DNSCrypt frontends are separate */
2639 if ((*it)->dohFrontend == nullptr && (*it)->tlsFrontend == nullptr && (*it)->dnscryptCtx == nullptr) {
2640 it = g_frontends.erase(it);
2641 }
2642 else {
2643 ++it;
2644 }
2645 }
2646
2647 for(const auto& loc : g_cmdLine.locals) {
2648 /* UDP */
2649 g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress(loc, 53), false, false, 0, "", {})));
2650 /* TCP */
2651 g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress(loc, 53), true, false, 0, "", {})));
2652 }
2653 }
2654
2655 if (g_frontends.empty()) {
2656 /* UDP */
2657 g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress("127.0.0.1", 53), false, false, 0, "", {})));
2658 /* TCP */
2659 g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress("127.0.0.1", 53), true, false, 0, "", {})));
2660 }
2661
2662 g_configurationDone = true;
2663
2664 for(auto& frontend : g_frontends) {
2665 setUpLocalBind(frontend);
2666
2667 if (frontend->tcp == false) {
2668 ++udpBindsCount;
2669 }
2670 else {
2671 ++tcpBindsCount;
2672 }
2673 }
2674
2675 warnlog("dnsdist %s comes with ABSOLUTELY NO WARRANTY. This is free software, and you are welcome to redistribute it according to the terms of the GPL version 2", VERSION);
2676
2677 vector<string> vec;
2678 std::string acls;
2679 g_ACL.getLocal()->toStringVector(&vec);
2680 for(const auto& s : vec) {
2681 if (!acls.empty())
2682 acls += ", ";
2683 acls += s;
2684 }
2685 infolog("ACL allowing queries from: %s", acls.c_str());
2686 vec.clear();
2687 acls.clear();
2688 g_consoleACL.getLocal()->toStringVector(&vec);
2689 for (const auto& entry : vec) {
2690 if (!acls.empty()) {
2691 acls += ", ";
2692 }
2693 acls += entry;
2694 }
2695 infolog("Console ACL allowing connections from: %s", acls.c_str());
2696
2697 #ifdef HAVE_LIBSODIUM
2698 if (g_consoleEnabled && g_consoleKey.empty()) {
2699 warnlog("Warning, the console has been enabled via 'controlSocket()' but no key has been set with 'setKey()' so all connections will fail until a key has been set");
2700 }
2701 #endif
2702
2703 uid_t newgid=0;
2704 gid_t newuid=0;
2705
2706 if(!g_cmdLine.gid.empty())
2707 newgid = strToGID(g_cmdLine.gid.c_str());
2708
2709 if(!g_cmdLine.uid.empty())
2710 newuid = strToUID(g_cmdLine.uid.c_str());
2711
2712 dropGroupPrivs(newgid);
2713 dropUserPrivs(newuid);
2714 try {
2715 /* we might still have capabilities remaining,
2716 for example if we have been started as root
2717 without --uid or --gid (please don't do that)
2718 or as an unprivileged user with ambient
2719 capabilities like CAP_NET_BIND_SERVICE.
2720 */
2721 dropCapabilities();
2722 }
2723 catch(const std::exception& e) {
2724 warnlog("%s", e.what());
2725 }
2726
2727 /* this need to be done _after_ dropping privileges */
2728 g_delay = new DelayPipe<DelayedPacket>();
2729
2730 if (g_snmpAgent) {
2731 g_snmpAgent->run();
2732 }
2733
2734 g_tcpclientthreads = std::unique_ptr<TCPClientCollection>(new TCPClientCollection(g_maxTCPClientThreads, g_useTCPSinglePipe));
2735
2736 for(auto& t : todo)
2737 t();
2738
2739 localPools = g_pools.getCopy();
2740 /* create the default pool no matter what */
2741 createPoolIfNotExists(localPools, "");
2742 if(g_cmdLine.remotes.size()) {
2743 for(const auto& address : g_cmdLine.remotes) {
2744 auto ret=std::make_shared<DownstreamState>(ComboAddress(address, 53));
2745 addServerToPool(localPools, "", ret);
2746 if (ret->connected && !ret->threadStarted.test_and_set()) {
2747 ret->tid = thread(responderThread, ret);
2748 }
2749 g_dstates.modify([ret](servers_t& servers) { servers.push_back(ret); });
2750 }
2751 }
2752 g_pools.setState(localPools);
2753
2754 if(g_dstates.getLocal()->empty()) {
2755 errlog("No downstream servers defined: all packets will get dropped");
2756 // you might define them later, but you need to know
2757 }
2758
2759 checkFileDescriptorsLimits(udpBindsCount, tcpBindsCount);
2760
2761 for(auto& dss : g_dstates.getCopy()) { // it is a copy, but the internal shared_ptrs are the real deal
2762 if(dss->availability==DownstreamState::Availability::Auto) {
2763 bool newState=upCheck(dss);
2764 warnlog("Marking downstream %s as '%s'", dss->getNameWithAddr(), newState ? "up" : "down");
2765 dss->upStatus = newState;
2766 }
2767 }
2768
2769 for(auto& cs : g_frontends) {
2770 if (cs->dohFrontend != nullptr) {
2771 #ifdef HAVE_DNS_OVER_HTTPS
2772 std::thread t1(dohThread, cs.get());
2773 if (!cs->cpus.empty()) {
2774 mapThreadToCPUList(t1.native_handle(), cs->cpus);
2775 }
2776 t1.detach();
2777 #endif /* HAVE_DNS_OVER_HTTPS */
2778 continue;
2779 }
2780 if (cs->udpFD >= 0) {
2781 thread t1(udpClientThread, cs.get());
2782 if (!cs->cpus.empty()) {
2783 mapThreadToCPUList(t1.native_handle(), cs->cpus);
2784 }
2785 t1.detach();
2786 }
2787 else if (cs->tcpFD >= 0) {
2788 thread t1(tcpAcceptorThread, cs.get());
2789 if (!cs->cpus.empty()) {
2790 mapThreadToCPUList(t1.native_handle(), cs->cpus);
2791 }
2792 t1.detach();
2793 }
2794 }
2795
2796 thread carbonthread(carbonDumpThread);
2797 carbonthread.detach();
2798
2799 thread stattid(maintThread);
2800 stattid.detach();
2801
2802 thread healththread(healthChecksThread);
2803
2804 if (!g_secPollSuffix.empty()) {
2805 thread secpollthread(secPollThread);
2806 secpollthread.detach();
2807 }
2808
2809 if(g_cmdLine.beSupervised) {
2810 #ifdef HAVE_SYSTEMD
2811 sd_notify(0, "READY=1");
2812 #endif
2813 healththread.join();
2814 }
2815 else {
2816 healththread.detach();
2817 doConsole();
2818 }
2819 _exit(EXIT_SUCCESS);
2820
2821 }
2822 catch(const LuaContext::ExecutionErrorException& e) {
2823 try {
2824 errlog("Fatal Lua error: %s", e.what());
2825 std::rethrow_if_nested(e);
2826 } catch(const std::exception& ne) {
2827 errlog("Details: %s", ne.what());
2828 }
2829 catch(PDNSException &ae)
2830 {
2831 errlog("Fatal pdns error: %s", ae.reason);
2832 }
2833 _exit(EXIT_FAILURE);
2834 }
2835 catch(std::exception &e)
2836 {
2837 errlog("Fatal error: %s", e.what());
2838 _exit(EXIT_FAILURE);
2839 }
2840 catch(PDNSException &ae)
2841 {
2842 errlog("Fatal pdns error: %s", ae.reason);
2843 _exit(EXIT_FAILURE);
2844 }
2845
2846 uint64_t getLatencyCount(const std::string&)
2847 {
2848 return g_stats.responses + g_stats.selfAnswered + g_stats.cacheHits;
2849 }