]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdistdist/dnsdist-ecs.cc
5fa61c5b19927fd80dee124ff6b7857c4962e233
[thirdparty/pdns.git] / pdns / dnsdistdist / dnsdist-ecs.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dolog.hh"
23 #include "dnsdist.hh"
24 #include "dnsdist-dnsparser.hh"
25 #include "dnsdist-ecs.hh"
26 #include "dnsparser.hh"
27 #include "dnswriter.hh"
28 #include "ednsoptions.hh"
29 #include "ednssubnet.hh"
30
31 /* when we add EDNS to a query, we don't want to advertise
32 a large buffer size */
33 size_t g_EdnsUDPPayloadSize = 512;
34 static const uint16_t defaultPayloadSizeSelfGenAnswers = 1232;
35 static_assert(defaultPayloadSizeSelfGenAnswers < s_udpIncomingBufferSize, "The UDP responder's payload size should be smaller or equal to our incoming buffer size");
36 uint16_t g_PayloadSizeSelfGenAnswers{defaultPayloadSizeSelfGenAnswers};
37
38 /* draft-ietf-dnsop-edns-client-subnet-04 "11.1. Privacy" */
39 uint16_t g_ECSSourcePrefixV4 = 24;
40 uint16_t g_ECSSourcePrefixV6 = 56;
41
42 bool g_ECSOverride{false};
43 bool g_addEDNSToSelfGeneratedResponses{true};
44
45 int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent)
46 {
47 assert(initialPacket.size() >= sizeof(dnsheader));
48 const dnsheader_aligned dh(initialPacket.data());
49
50 if (ntohs(dh->arcount) == 0) {
51 return ENOENT;
52 }
53
54 if (ntohs(dh->qdcount) == 0) {
55 return ENOENT;
56 }
57
58 PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
59
60 size_t idx = 0;
61 DNSName rrname;
62 uint16_t qdcount = ntohs(dh->qdcount);
63 uint16_t ancount = ntohs(dh->ancount);
64 uint16_t nscount = ntohs(dh->nscount);
65 uint16_t arcount = ntohs(dh->arcount);
66 uint16_t rrtype;
67 uint16_t rrclass;
68 string blob;
69 struct dnsrecordheader ah;
70
71 rrname = pr.getName();
72 rrtype = pr.get16BitInt();
73 rrclass = pr.get16BitInt();
74
75 GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
76 pw.getHeader()->id = dh->id;
77 pw.getHeader()->qr = dh->qr;
78 pw.getHeader()->aa = dh->aa;
79 pw.getHeader()->tc = dh->tc;
80 pw.getHeader()->rd = dh->rd;
81 pw.getHeader()->ra = dh->ra;
82 pw.getHeader()->ad = dh->ad;
83 pw.getHeader()->cd = dh->cd;
84 pw.getHeader()->rcode = dh->rcode;
85
86 /* consume remaining qd if any */
87 if (qdcount > 1) {
88 for (idx = 1; idx < qdcount; idx++) {
89 rrname = pr.getName();
90 rrtype = pr.get16BitInt();
91 rrclass = pr.get16BitInt();
92 (void)rrtype;
93 (void)rrclass;
94 }
95 }
96
97 /* copy AN and NS */
98 for (idx = 0; idx < ancount; idx++) {
99 rrname = pr.getName();
100 pr.getDnsrecordheader(ah);
101
102 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
103 pr.xfrBlob(blob);
104 pw.xfrBlob(blob);
105 }
106
107 for (idx = 0; idx < nscount; idx++) {
108 rrname = pr.getName();
109 pr.getDnsrecordheader(ah);
110
111 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
112 pr.xfrBlob(blob);
113 pw.xfrBlob(blob);
114 }
115 /* consume AR, looking for OPT */
116 for (idx = 0; idx < arcount; idx++) {
117 rrname = pr.getName();
118 pr.getDnsrecordheader(ah);
119
120 if (ah.d_type != QType::OPT) {
121 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
122 pr.xfrBlob(blob);
123 pw.xfrBlob(blob);
124 }
125 else {
126
127 pr.skip(ah.d_clen);
128 }
129 }
130 pw.commit();
131
132 return 0;
133 }
134
135 static bool addOrReplaceEDNSOption(std::vector<std::pair<uint16_t, std::string>>& options, uint16_t optionCode, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
136 {
137 for (auto it = options.begin(); it != options.end();) {
138 if (it->first == optionCode) {
139 optionAdded = false;
140
141 if (!overrideExisting) {
142 return false;
143 }
144
145 it = options.erase(it);
146 }
147 else {
148 ++it;
149 }
150 }
151
152 options.emplace_back(optionCode, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
153 return true;
154 }
155
156 bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
157 {
158 assert(initialPacket.size() >= sizeof(dnsheader));
159 const dnsheader_aligned dh(initialPacket.data());
160
161 if (ntohs(dh->qdcount) == 0) {
162 return false;
163 }
164
165 if (ntohs(dh->ancount) == 0 && ntohs(dh->nscount) == 0 && ntohs(dh->arcount) == 0) {
166 throw std::runtime_error(std::string(__PRETTY_FUNCTION__) + " should not be called for queries that have no records");
167 }
168
169 optionAdded = false;
170 ednsAdded = true;
171
172 PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
173
174 size_t idx = 0;
175 DNSName rrname;
176 uint16_t qdcount = ntohs(dh->qdcount);
177 uint16_t ancount = ntohs(dh->ancount);
178 uint16_t nscount = ntohs(dh->nscount);
179 uint16_t arcount = ntohs(dh->arcount);
180 uint16_t rrtype;
181 uint16_t rrclass;
182 string blob;
183 struct dnsrecordheader ah;
184
185 rrname = pr.getName();
186 rrtype = pr.get16BitInt();
187 rrclass = pr.get16BitInt();
188
189 GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
190 pw.getHeader()->id = dh->id;
191 pw.getHeader()->qr = dh->qr;
192 pw.getHeader()->aa = dh->aa;
193 pw.getHeader()->tc = dh->tc;
194 pw.getHeader()->rd = dh->rd;
195 pw.getHeader()->ra = dh->ra;
196 pw.getHeader()->ad = dh->ad;
197 pw.getHeader()->cd = dh->cd;
198 pw.getHeader()->rcode = dh->rcode;
199
200 /* consume remaining qd if any */
201 if (qdcount > 1) {
202 for (idx = 1; idx < qdcount; idx++) {
203 rrname = pr.getName();
204 rrtype = pr.get16BitInt();
205 rrclass = pr.get16BitInt();
206 (void)rrtype;
207 (void)rrclass;
208 }
209 }
210
211 /* copy AN and NS */
212 for (idx = 0; idx < ancount; idx++) {
213 rrname = pr.getName();
214 pr.getDnsrecordheader(ah);
215
216 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
217 pr.xfrBlob(blob);
218 pw.xfrBlob(blob);
219 }
220
221 for (idx = 0; idx < nscount; idx++) {
222 rrname = pr.getName();
223 pr.getDnsrecordheader(ah);
224
225 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
226 pr.xfrBlob(blob);
227 pw.xfrBlob(blob);
228 }
229
230 /* consume AR, looking for OPT */
231 for (idx = 0; idx < arcount; idx++) {
232 rrname = pr.getName();
233 pr.getDnsrecordheader(ah);
234
235 if (ah.d_type != QType::OPT) {
236 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
237 pr.xfrBlob(blob);
238 pw.xfrBlob(blob);
239 }
240 else {
241
242 ednsAdded = false;
243 pr.xfrBlob(blob);
244
245 std::vector<std::pair<uint16_t, std::string>> options;
246 getEDNSOptionsFromContent(blob, options);
247
248 /* getDnsrecordheader() has helpfully converted the TTL for us, which we do not want in that case */
249 uint32_t ttl = htonl(ah.d_ttl);
250 EDNS0Record edns0;
251 static_assert(sizeof(edns0) == sizeof(ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
252 memcpy(&edns0, &ttl, sizeof(edns0));
253
254 /* addOrReplaceEDNSOption will set it to false if there is already an existing option */
255 optionAdded = true;
256 addOrReplaceEDNSOption(options, optionToReplace, optionAdded, overrideExisting, newOptionContent);
257 pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version);
258 }
259 }
260
261 if (ednsAdded) {
262 pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{optionToReplace, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
263 optionAdded = true;
264 }
265
266 pw.commit();
267
268 return true;
269 }
270
271 static bool slowParseEDNSOptions(const PacketBuffer& packet, EDNSOptionViewMap& options)
272 {
273 if (packet.size() < sizeof(dnsheader)) {
274 return false;
275 }
276
277 const dnsheader_aligned dh(packet.data());
278
279 if (ntohs(dh->qdcount) == 0) {
280 return false;
281 }
282
283 if (ntohs(dh->arcount) == 0) {
284 throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
285 }
286
287 try {
288 uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
289 DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(&packet.at(0))), packet.size());
290 uint64_t n;
291 for (n = 0; n < ntohs(dh->qdcount); ++n) {
292 dpm.skipDomainName();
293 /* type and class */
294 dpm.skipBytes(4);
295 }
296
297 for (n = 0; n < numrecords; ++n) {
298 dpm.skipDomainName();
299
300 uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
301 uint16_t dnstype = dpm.get16BitInt();
302 dpm.get16BitInt();
303 dpm.skipBytes(4); /* TTL */
304
305 if (section == 3 && dnstype == QType::OPT) {
306 uint32_t offset = dpm.getOffset();
307 if (offset >= packet.size()) {
308 return false;
309 }
310 /* if we survive this call, we can parse it safely */
311 dpm.skipRData();
312 return getEDNSOptions(reinterpret_cast<const char*>(&packet.at(offset)), packet.size() - offset, options) == 0;
313 }
314 else {
315 dpm.skipRData();
316 }
317 }
318 }
319 catch (...) {
320 return false;
321 }
322
323 return true;
324 }
325
326 int locateEDNSOptRR(const PacketBuffer& packet, uint16_t* optStart, size_t* optLen, bool* last)
327 {
328 assert(optStart != NULL);
329 assert(optLen != NULL);
330 assert(last != NULL);
331 const dnsheader_aligned dh(packet.data());
332
333 if (ntohs(dh->arcount) == 0) {
334 return ENOENT;
335 }
336
337 PacketReader pr(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()));
338
339 size_t idx = 0;
340 DNSName rrname;
341 uint16_t qdcount = ntohs(dh->qdcount);
342 uint16_t ancount = ntohs(dh->ancount);
343 uint16_t nscount = ntohs(dh->nscount);
344 uint16_t arcount = ntohs(dh->arcount);
345 uint16_t rrtype;
346 uint16_t rrclass;
347 struct dnsrecordheader ah;
348
349 /* consume qd */
350 for (idx = 0; idx < qdcount; idx++) {
351 rrname = pr.getName();
352 rrtype = pr.get16BitInt();
353 rrclass = pr.get16BitInt();
354 (void)rrtype;
355 (void)rrclass;
356 }
357
358 /* consume AN and NS */
359 for (idx = 0; idx < ancount + nscount; idx++) {
360 rrname = pr.getName();
361 pr.getDnsrecordheader(ah);
362 pr.skip(ah.d_clen);
363 }
364
365 /* consume AR, looking for OPT */
366 for (idx = 0; idx < arcount; idx++) {
367 uint16_t start = pr.getPosition();
368 rrname = pr.getName();
369 pr.getDnsrecordheader(ah);
370
371 if (ah.d_type == QType::OPT) {
372 *optStart = start;
373 *optLen = (pr.getPosition() - start) + ah.d_clen;
374
375 if (packet.size() < (*optStart + *optLen)) {
376 throw std::range_error("Opt record overflow");
377 }
378
379 if (idx == ((size_t)arcount - 1)) {
380 *last = true;
381 }
382 else {
383 *last = false;
384 }
385 return 0;
386 }
387 pr.skip(ah.d_clen);
388 }
389
390 return ENOENT;
391 }
392
393 /* extract the start of the OPT RR in a QUERY packet if any */
394 int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_t* optRDPosition, size_t* remaining)
395 {
396 assert(optRDPosition != nullptr);
397 assert(remaining != nullptr);
398 const dnsheader_aligned dh(packet.data());
399
400 if (offset >= packet.size()) {
401 return ENOENT;
402 }
403
404 if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0) {
405 return ENOENT;
406 }
407
408 size_t pos = sizeof(dnsheader) + offset;
409 pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
410
411 if (pos >= packet.size())
412 return ENOENT;
413
414 if ((pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE) >= packet.size()) {
415 return ENOENT;
416 }
417
418 if (packet[pos] != 0) {
419 /* not the root so not an OPT record */
420 return ENOENT;
421 }
422 pos += 1;
423
424 uint16_t qtype = packet.at(pos) * 256 + packet.at(pos + 1);
425 pos += DNS_TYPE_SIZE;
426 pos += DNS_CLASS_SIZE;
427
428 if (qtype != QType::OPT || (packet.size() - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE)) {
429 return ENOENT;
430 }
431
432 pos += DNS_TTL_SIZE;
433 *optRDPosition = pos;
434 *remaining = packet.size() - pos;
435
436 return 0;
437 }
438
439 void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
440 {
441 Netmask sourceNetmask(source, ECSPrefixLength);
442 EDNSSubnetOpts ecsOpts;
443 ecsOpts.source = sourceNetmask;
444 string payload = makeEDNSSubnetOptsString(ecsOpts);
445 generateEDNSOption(EDNSOptionCode::ECS, payload, res);
446 }
447
448 bool generateOptRR(const std::string& optRData, PacketBuffer& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
449 {
450 const uint8_t name = 0;
451 dnsrecordheader dh;
452 EDNS0Record edns0;
453 edns0.extRCode = ednsrcode;
454 edns0.version = 0;
455 edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;
456
457 if ((maximumSize - res.size()) < (sizeof(name) + sizeof(dh) + optRData.length())) {
458 return false;
459 }
460
461 dh.d_type = htons(QType::OPT);
462 dh.d_class = htons(udpPayloadSize);
463 static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
464 memcpy(&dh.d_ttl, &edns0, sizeof edns0);
465 dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
466
467 res.reserve(res.size() + sizeof(name) + sizeof(dh) + optRData.length());
468 res.insert(res.end(), reinterpret_cast<const uint8_t*>(&name), reinterpret_cast<const uint8_t*>(&name) + sizeof(name));
469 res.insert(res.end(), reinterpret_cast<const uint8_t*>(&dh), reinterpret_cast<const uint8_t*>(&dh) + sizeof(dh));
470 res.insert(res.end(), reinterpret_cast<const uint8_t*>(optRData.data()), reinterpret_cast<const uint8_t*>(optRData.data()) + optRData.length());
471
472 return true;
473 }
474
475 static bool replaceEDNSClientSubnetOption(PacketBuffer& packet, size_t maximumSize, size_t const oldEcsOptionStartPosition, size_t const oldEcsOptionSize, size_t const optRDLenPosition, const string& newECSOption)
476 {
477 assert(oldEcsOptionStartPosition < packet.size());
478 assert(optRDLenPosition < packet.size());
479
480 if (newECSOption.size() == oldEcsOptionSize) {
481 /* same size as the existing option */
482 memcpy(&packet.at(oldEcsOptionStartPosition), newECSOption.c_str(), oldEcsOptionSize);
483 }
484 else {
485 /* different size than the existing option */
486 const unsigned int newPacketLen = packet.size() + (newECSOption.length() - oldEcsOptionSize);
487 const size_t beforeOptionLen = oldEcsOptionStartPosition;
488 const size_t dataBehindSize = packet.size() - beforeOptionLen - oldEcsOptionSize;
489
490 /* check that it fits in the existing buffer */
491 if (newPacketLen > packet.size()) {
492 if (newPacketLen > maximumSize) {
493 return false;
494 }
495
496 packet.resize(newPacketLen);
497 }
498
499 /* fix the size of ECS Option RDLen */
500 uint16_t newRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
501 newRDLen += (newECSOption.size() - oldEcsOptionSize);
502 packet.at(optRDLenPosition) = newRDLen / 256;
503 packet.at(optRDLenPosition + 1) = newRDLen % 256;
504
505 if (dataBehindSize > 0) {
506 memmove(&packet.at(oldEcsOptionStartPosition), &packet.at(oldEcsOptionStartPosition + oldEcsOptionSize), dataBehindSize);
507 }
508 memcpy(&packet.at(oldEcsOptionStartPosition + dataBehindSize), newECSOption.c_str(), newECSOption.size());
509 packet.resize(newPacketLen);
510 }
511
512 return true;
513 }
514
515 /* This function looks for an OPT RR, return true if a valid one was found (even if there was no options)
516 and false otherwise. */
517 bool parseEDNSOptions(const DNSQuestion& dq)
518 {
519 const auto dh = dq.getHeader();
520 if (dq.ednsOptions != nullptr) {
521 return true;
522 }
523
524 // dq.ednsOptions is mutable
525 dq.ednsOptions = std::make_unique<EDNSOptionViewMap>();
526
527 if (ntohs(dh->arcount) == 0) {
528 /* nothing in additional so no EDNS */
529 return false;
530 }
531
532 if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || ntohs(dh->arcount) > 1) {
533 return slowParseEDNSOptions(dq.getData(), *dq.ednsOptions);
534 }
535
536 size_t remaining = 0;
537 uint16_t optRDPosition;
538 int res = getEDNSOptionsStart(dq.getData(), dq.ids.qname.wirelength(), &optRDPosition, &remaining);
539
540 if (res == 0) {
541 res = getEDNSOptions(reinterpret_cast<const char*>(&dq.getData().at(optRDPosition)), remaining, *dq.ednsOptions);
542 return (res == 0);
543 }
544
545 return false;
546 }
547
548 static bool addECSToExistingOPT(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, size_t optRDLenPosition, bool& ecsAdded)
549 {
550 /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
551 /* getEDNSOptionsStart has already checked that there is exactly one AR,
552 no NS and no AN */
553 uint16_t oldRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
554 if (packet.size() != (optRDLenPosition + sizeof(uint16_t) + oldRDLen)) {
555 /* we are supposed to be the last record, do we have some trailing data to remove? */
556 uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
557 packet.resize(realPacketLen);
558 }
559
560 if ((maximumSize - packet.size()) < newECSOption.size()) {
561 return false;
562 }
563
564 uint16_t newRDLen = oldRDLen + newECSOption.size();
565 packet.at(optRDLenPosition) = newRDLen / 256;
566 packet.at(optRDLenPosition + 1) = newRDLen % 256;
567
568 packet.insert(packet.end(), newECSOption.begin(), newECSOption.end());
569 ecsAdded = true;
570
571 return true;
572 }
573
574 static bool addEDNSWithECS(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, bool& ednsAdded, bool& ecsAdded)
575 {
576 if (!generateOptRR(newECSOption, packet, maximumSize, g_EdnsUDPPayloadSize, 0, false)) {
577 return false;
578 }
579
580 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
581 uint16_t arcount = ntohs(header.arcount);
582 arcount++;
583 header.arcount = htons(arcount);
584 return true;
585 });
586 ednsAdded = true;
587 ecsAdded = true;
588
589 return true;
590 }
591
592 bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, const size_t qnameWireLength, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
593 {
594 assert(qnameWireLength <= packet.size());
595
596 const dnsheader_aligned dh(packet.data());
597
598 if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) {
599 PacketBuffer newContent;
600 newContent.reserve(packet.size());
601
602 if (!slowRewriteEDNSOptionInQueryWithRecords(packet, newContent, ednsAdded, EDNSOptionCode::ECS, ecsAdded, overrideExisting, newECSOption)) {
603 return false;
604 }
605
606 if (newContent.size() > maximumSize) {
607 ednsAdded = false;
608 ecsAdded = false;
609 return false;
610 }
611
612 packet = std::move(newContent);
613 return true;
614 }
615
616 uint16_t optRDPosition = 0;
617 size_t remaining = 0;
618
619 int res = getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining);
620
621 if (res != 0) {
622 /* no EDNS but there might be another record in additional (TSIG?) */
623 /* Careful, this code assumes that ANCOUNT == 0 && NSCOUNT == 0 */
624 size_t minimumPacketSize = sizeof(dnsheader) + qnameWireLength + sizeof(uint16_t) + sizeof(uint16_t);
625 if (packet.size() > minimumPacketSize) {
626 if (ntohs(dh->arcount) == 0) {
627 /* well now.. */
628 packet.resize(minimumPacketSize);
629 }
630 else {
631 uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
632 packet.resize(realPacketLen);
633 }
634 }
635
636 return addEDNSWithECS(packet, maximumSize, newECSOption, ednsAdded, ecsAdded);
637 }
638
639 size_t ecsOptionStartPosition = 0;
640 size_t ecsOptionSize = 0;
641
642 res = getEDNSOption(reinterpret_cast<const char*>(&packet.at(optRDPosition)), remaining, EDNSOptionCode::ECS, &ecsOptionStartPosition, &ecsOptionSize);
643
644 if (res == 0) {
645 /* there is already an ECS value */
646 if (!overrideExisting) {
647 return true;
648 }
649
650 return replaceEDNSClientSubnetOption(packet, maximumSize, optRDPosition + ecsOptionStartPosition, ecsOptionSize, optRDPosition, newECSOption);
651 }
652 else {
653 /* we have an EDNS OPT RR but no existing ECS option */
654 return addECSToExistingOPT(packet, maximumSize, newECSOption, optRDPosition, ecsAdded);
655 }
656
657 return true;
658 }
659
660 bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded)
661 {
662 string newECSOption;
663 generateECSOption(dq.ecs ? dq.ecs->getNetwork() : dq.ids.origRemote, newECSOption, dq.ecs ? dq.ecs->getBits() : dq.ecsPrefixLength);
664
665 return handleEDNSClientSubnet(dq.getMutableData(), dq.getMaximumSize(), dq.ids.qname.wirelength(), ednsAdded, ecsAdded, dq.ecsOverride, newECSOption);
666 }
667
/* Remove the first option whose code is `optionCodeToRemove` from the
   `optionsLen` bytes of wire-format EDNS options starting at `optionsStart`.
   The options following the removed one are moved down to fill the gap.
   Returns 0 and stores the new total length in `*newOptionsLen` when an
   option was removed, ENOENT if the code was not found and EINVAL if an
   option claims to extend past `optionsLen`. `*newOptionsLen` is only
   written on success. */
static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
{
  /* 16-bit option code followed by a 16-bit option length */
  constexpr size_t optionHeaderSize = sizeof(uint16_t) + sizeof(uint16_t);
  size_t cursor = 0;

  while (cursor + optionHeaderSize <= optionsLen) {
    unsigned char* const option = optionsStart + cursor;
    const uint16_t code = 0x100 * option[0] + option[1];
    const uint16_t payloadLen = 0x100 * option[2] + option[3];
    const size_t nextOption = cursor + optionHeaderSize + payloadLen;

    if (nextOption > optionsLen) {
      return EINVAL;
    }

    if (code != optionCodeToRemove) {
      cursor = nextOption;
      continue;
    }

    if (nextOption < optionsLen) {
      /* move remaining options over the removed one */
      memmove(option, optionsStart + nextOption, optionsLen - nextOption);
    }
    *newOptionsLen = optionsLen - (optionHeaderSize + payloadLen);
    return 0;
  }

  return ENOENT;
}
697
698 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
699 {
700 if (*optLen < optRecordMinimumSize) {
701 return EINVAL;
702 }
703 const unsigned char* end = (const unsigned char*)optStart + *optLen;
704 unsigned char* p = (unsigned char*)optStart + 9;
705 unsigned char* rdLenPtr = p;
706 uint16_t rdLen = (0x100 * p[0] + p[1]);
707 p += sizeof(rdLen);
708 if (p + rdLen != end) {
709 return EINVAL;
710 }
711 uint16_t newRdLen = 0;
712 int res = removeEDNSOptionFromOptions(p, rdLen, optionCodeToRemove, &newRdLen);
713 if (res != 0) {
714 return res;
715 }
716 *optLen -= (rdLen - newRdLen);
717 rdLenPtr[0] = newRdLen / 0x100;
718 rdLenPtr[1] = newRdLen % 0x100;
719 return 0;
720 }
721
722 bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
723 {
724 if (optLen < optRecordMinimumSize) {
725 return false;
726 }
727 size_t p = optStart + 9;
728 uint16_t rdLen = (0x100 * static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p + 1)));
729 p += sizeof(rdLen);
730 if (rdLen > (optLen - optRecordMinimumSize)) {
731 return false;
732 }
733
734 size_t rdEnd = p + rdLen;
735 while ((p + 4) <= rdEnd) {
736 const uint16_t optionCode = 0x100 * static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p + 1));
737 p += sizeof(optionCode);
738 const uint16_t optionLen = 0x100 * static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p + 1));
739 p += sizeof(optionLen);
740
741 if ((p + optionLen) > rdEnd) {
742 return false;
743 }
744
745 if (optionCode == optionCodeToFind) {
746 if (optContentStart != nullptr) {
747 *optContentStart = p;
748 }
749
750 if (optContentLen != nullptr) {
751 *optContentLen = optionLen;
752 }
753
754 return true;
755 }
756 p += optionLen;
757 }
758 return false;
759 }
760
761 int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent)
762 {
763 assert(initialPacket.size() >= sizeof(dnsheader));
764 const dnsheader_aligned dh(initialPacket.data());
765
766 if (ntohs(dh->arcount) == 0)
767 return ENOENT;
768
769 if (ntohs(dh->qdcount) == 0)
770 return ENOENT;
771
772 PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
773
774 size_t idx = 0;
775 DNSName rrname;
776 uint16_t qdcount = ntohs(dh->qdcount);
777 uint16_t ancount = ntohs(dh->ancount);
778 uint16_t nscount = ntohs(dh->nscount);
779 uint16_t arcount = ntohs(dh->arcount);
780 uint16_t rrtype;
781 uint16_t rrclass;
782 string blob;
783 struct dnsrecordheader ah;
784
785 rrname = pr.getName();
786 rrtype = pr.get16BitInt();
787 rrclass = pr.get16BitInt();
788
789 GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
790 pw.getHeader()->id = dh->id;
791 pw.getHeader()->qr = dh->qr;
792 pw.getHeader()->aa = dh->aa;
793 pw.getHeader()->tc = dh->tc;
794 pw.getHeader()->rd = dh->rd;
795 pw.getHeader()->ra = dh->ra;
796 pw.getHeader()->ad = dh->ad;
797 pw.getHeader()->cd = dh->cd;
798 pw.getHeader()->rcode = dh->rcode;
799
800 /* consume remaining qd if any */
801 if (qdcount > 1) {
802 for (idx = 1; idx < qdcount; idx++) {
803 rrname = pr.getName();
804 rrtype = pr.get16BitInt();
805 rrclass = pr.get16BitInt();
806 (void)rrtype;
807 (void)rrclass;
808 }
809 }
810
811 /* copy AN and NS */
812 for (idx = 0; idx < ancount; idx++) {
813 rrname = pr.getName();
814 pr.getDnsrecordheader(ah);
815
816 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
817 pr.xfrBlob(blob);
818 pw.xfrBlob(blob);
819 }
820
821 for (idx = 0; idx < nscount; idx++) {
822 rrname = pr.getName();
823 pr.getDnsrecordheader(ah);
824
825 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
826 pr.xfrBlob(blob);
827 pw.xfrBlob(blob);
828 }
829
830 /* consume AR, looking for OPT */
831 for (idx = 0; idx < arcount; idx++) {
832 rrname = pr.getName();
833 pr.getDnsrecordheader(ah);
834
835 if (ah.d_type != QType::OPT) {
836 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
837 pr.xfrBlob(blob);
838 pw.xfrBlob(blob);
839 }
840 else {
841 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, false);
842 pr.xfrBlob(blob);
843 uint16_t rdLen = blob.length();
844 removeEDNSOptionFromOptions((unsigned char*)blob.c_str(), rdLen, optionCodeToSkip, &rdLen);
845 /* xfrBlob(string, size) completely ignores size.. */
846 if (rdLen > 0) {
847 blob.resize((size_t)rdLen);
848 pw.xfrBlob(blob);
849 }
850 else {
851 pw.commit();
852 }
853 }
854 }
855 pw.commit();
856
857 return 0;
858 }
859
860 bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
861 {
862 if (!generateOptRR(std::string(), packet, maximumSize, payloadSize, ednsrcode, dnssecOK)) {
863 return false;
864 }
865
866 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
867 header.arcount = htons(ntohs(header.arcount) + 1);
868 return true;
869 });
870
871 return true;
872 }
873
874 /*
875 This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
876 generating a NXD or NODATA answer with a SOA record in the additional section (or optionally the authority section for a full cacheable NXDOMAIN/NODATA).
877 */
878 bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, bool soaInAuthoritySection)
879 {
880 auto& packet = dq.getMutableData();
881 auto dh = dq.getHeader();
882 if (ntohs(dh->qdcount) != 1) {
883 return false;
884 }
885
886 size_t queryPartSize = sizeof(dnsheader) + dq.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
887 if (packet.size() < queryPartSize) {
888 /* something is already wrong, don't build on flawed foundations */
889 return false;
890 }
891
892 uint16_t qtype = htons(QType::SOA);
893 uint16_t qclass = htons(QClass::IN);
894 uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum);
895 size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength;
896 bool hadEDNS = false;
897 bool dnssecOK = false;
898
899 if (g_addEDNSToSelfGeneratedResponses) {
900 uint16_t payloadSize = 0;
901 uint16_t z = 0;
902 hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(packet.data()), packet.size(), &payloadSize, &z);
903 if (hadEDNS) {
904 dnssecOK = z & EDNS_HEADER_FLAG_DO;
905 }
906 }
907
908 /* chop off everything after the question */
909 packet.resize(queryPartSize);
910 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [nxd](dnsheader& header) {
911 if (nxd) {
912 header.rcode = RCode::NXDomain;
913 }
914 else {
915 header.rcode = RCode::NoError;
916 }
917 header.qr = true;
918 header.ancount = 0;
919 header.nscount = 0;
920 header.arcount = 0;
921 return true;
922 });
923
924 rdLength = htons(rdLength);
925 ttl = htonl(ttl);
926 serial = htonl(serial);
927 refresh = htonl(refresh);
928 retry = htonl(retry);
929 expire = htonl(expire);
930 minimum = htonl(minimum);
931
932 std::string soa;
933 soa.reserve(soaSize);
934 soa.append(zone.toDNSString());
935 soa.append(reinterpret_cast<const char*>(&qtype), sizeof(qtype));
936 soa.append(reinterpret_cast<const char*>(&qclass), sizeof(qclass));
937 soa.append(reinterpret_cast<const char*>(&ttl), sizeof(ttl));
938 soa.append(reinterpret_cast<const char*>(&rdLength), sizeof(rdLength));
939 soa.append(mname.toDNSString());
940 soa.append(rname.toDNSString());
941 soa.append(reinterpret_cast<const char*>(&serial), sizeof(serial));
942 soa.append(reinterpret_cast<const char*>(&refresh), sizeof(refresh));
943 soa.append(reinterpret_cast<const char*>(&retry), sizeof(retry));
944 soa.append(reinterpret_cast<const char*>(&expire), sizeof(expire));
945 soa.append(reinterpret_cast<const char*>(&minimum), sizeof(minimum));
946
947 if (soa.size() != soaSize) {
948 throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize));
949 }
950
951 packet.insert(packet.end(), soa.begin(), soa.end());
952
953 /* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
954 NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
955 and have EDNS added to AR afterwards */
956 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [soaInAuthoritySection](dnsheader& header) {
957 if (soaInAuthoritySection) {
958 header.nscount = htons(1);
959 }
960 else {
961 header.arcount = htons(1);
962 }
963 return true;
964 });
965
966 if (hadEDNS) {
967 /* now we need to add a new OPT record */
968 return addEDNS(packet, dq.getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
969 }
970
971 return true;
972 }
973
974 bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
975 {
976 uint16_t optRDPosition;
977 /* remaining is at least the size of the rdlen + the options if any + the following records if any */
978 size_t remaining = 0;
979
980 auto& packet = dq.getMutableData();
981 int res = getEDNSOptionsStart(packet, dq.ids.qname.wirelength(), &optRDPosition, &remaining);
982
983 if (res != 0) {
984 /* if the initial query did not have EDNS0, we are done */
985 return true;
986 }
987
988 const size_t existingOptLen = /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2 + remaining;
989 if (existingOptLen >= packet.size()) {
990 /* something is wrong, bail out */
991 return false;
992 }
993
994 uint8_t* optRDLen = &packet.at(optRDPosition);
995 uint8_t* optPtr = (optRDLen - (/* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2));
996
997 const uint8_t* zPtr = optPtr + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE;
998 uint16_t z = 0x100 * (*zPtr) + *(zPtr + 1);
999 bool dnssecOK = z & EDNS_HEADER_FLAG_DO;
1000
1001 /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
1002 packet.resize(packet.size() - existingOptLen);
1003 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
1004 header.arcount = 0;
1005 return true;
1006 });
1007
1008 if (g_addEDNSToSelfGeneratedResponses) {
1009 /* now we need to add a new OPT record */
1010 return addEDNS(packet, dq.getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
1011 }
1012
1013 /* otherwise we are just fine */
1014 return true;
1015 }
1016
1017 // goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
1018 int getEDNSZ(const DNSQuestion& dq)
1019 {
1020 try {
1021 const auto& dh = dq.getHeader();
1022 if (ntohs(dh->qdcount) != 1 || dh->ancount != 0 || ntohs(dh->arcount) != 1 || dh->nscount != 0) {
1023 return 0;
1024 }
1025
1026 if (dq.getData().size() <= sizeof(dnsheader)) {
1027 return 0;
1028 }
1029
1030 size_t pos = sizeof(dnsheader) + dq.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
1031
1032 if (dq.getData().size() <= (pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE)) {
1033 return 0;
1034 }
1035
1036 auto& packet = dq.getData();
1037
1038 if (packet.at(pos) != 0) {
1039 /* not root, so not a valid OPT record */
1040 return 0;
1041 }
1042
1043 pos++;
1044
1045 uint16_t qtype = packet.at(pos) * 256 + packet.at(pos + 1);
1046 pos += DNS_TYPE_SIZE;
1047 pos += DNS_CLASS_SIZE;
1048
1049 if (qtype != QType::OPT || (pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= packet.size()) {
1050 return 0;
1051 }
1052
1053 const uint8_t* z = &packet.at(pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE);
1054 return 0x100 * (*z) + *(z + 1);
1055 }
1056 catch (...) {
1057 return 0;
1058 }
1059 }
1060
1061 bool queryHasEDNS(const DNSQuestion& dq)
1062 {
1063 uint16_t optRDPosition;
1064 size_t ecsRemaining = 0;
1065
1066 int res = getEDNSOptionsStart(dq.getData(), dq.ids.qname.wirelength(), &optRDPosition, &ecsRemaining);
1067 if (res == 0) {
1068 return true;
1069 }
1070
1071 return false;
1072 }
1073
1074 bool getEDNS0Record(const PacketBuffer& packet, EDNS0Record& edns0)
1075 {
1076 uint16_t optStart;
1077 size_t optLen = 0;
1078 bool last = false;
1079 int res = locateEDNSOptRR(packet, &optStart, &optLen, &last);
1080 if (res != 0) {
1081 // no EDNS OPT RR
1082 return false;
1083 }
1084
1085 if (optLen < optRecordMinimumSize) {
1086 return false;
1087 }
1088
1089 if (optStart < packet.size() && packet.at(optStart) != 0) {
1090 // OPT RR Name != '.'
1091 return false;
1092 }
1093
1094 static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
1095 // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
1096 memcpy(&edns0, &packet.at(optStart + 5), sizeof edns0);
1097 return true;
1098 }
1099
1100 bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsData)
1101 {
1102 std::string optRData;
1103 generateEDNSOption(ednsCode, ednsData, optRData);
1104
1105 if (dq.getHeader()->arcount) {
1106 bool ednsAdded = false;
1107 bool optionAdded = false;
1108 PacketBuffer newContent;
1109 newContent.reserve(dq.getData().size());
1110
1111 if (!slowRewriteEDNSOptionInQueryWithRecords(dq.getData(), newContent, ednsAdded, ednsCode, optionAdded, true, optRData)) {
1112 return false;
1113 }
1114
1115 if (newContent.size() > dq.getMaximumSize()) {
1116 return false;
1117 }
1118
1119 dq.getMutableData() = std::move(newContent);
1120 if (!dq.ids.ednsAdded && ednsAdded) {
1121 dq.ids.ednsAdded = true;
1122 }
1123
1124 return true;
1125 }
1126
1127 auto& data = dq.getMutableData();
1128 if (generateOptRR(optRData, data, dq.getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
1129 dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
1130 header.arcount = htons(1);
1131 return true;
1132 });
1133 // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
1134 dq.ids.ednsAdded = true;
1135 }
1136
1137 return true;
1138 }
1139
1140 namespace dnsdist
1141 {
1142 bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uint8_t rcode, bool clearAnswers)
1143 {
1144 const auto qnameLength = state.qname.wirelength();
1145 if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) {
1146 return false;
1147 }
1148
1149 EDNS0Record edns0;
1150 bool hadEDNS = false;
1151 if (clearAnswers) {
1152 hadEDNS = getEDNS0Record(buffer, edns0);
1153 }
1154
1155 dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [rcode, clearAnswers](dnsheader& header) {
1156 header.rcode = rcode;
1157 header.ad = false;
1158 header.aa = false;
1159 header.ra = header.rd;
1160 header.qr = true;
1161
1162 if (clearAnswers) {
1163 header.ancount = 0;
1164 header.nscount = 0;
1165 header.arcount = 0;
1166 }
1167 return true;
1168 });
1169
1170 if (clearAnswers) {
1171 buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
1172 if (hadEDNS) {
1173 DNSQuestion dq(state, buffer);
1174 if (!addEDNS(buffer, dq.getMaximumSize(), edns0.extFlags & htons(EDNS_HEADER_FLAG_DO), g_PayloadSizeSelfGenAnswers, 0)) {
1175 return false;
1176 }
1177 }
1178 }
1179
1180 return true;
1181 }
1182 }