]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-ecs.cc
dnsdist: Add SetNegativeAndSOAAction() and its Lua binding
[thirdparty/pdns.git] / pdns / dnsdist-ecs.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dolog.hh"
23 #include "dnsdist.hh"
24 #include "dnsdist-ecs.hh"
25 #include "dnsparser.hh"
26 #include "dnswriter.hh"
27 #include "ednsoptions.hh"
28 #include "ednssubnet.hh"
29
30 /* when we add EDNS to a query, we don't want to advertise
31 a large buffer size */
32 size_t g_EdnsUDPPayloadSize = 512;
33 uint16_t g_PayloadSizeSelfGenAnswers{s_udpIncomingBufferSize};
34
35 /* draft-ietf-dnsop-edns-client-subnet-04 "11.1. Privacy" */
36 uint16_t g_ECSSourcePrefixV4 = 24;
37 uint16_t g_ECSSourcePrefixV6 = 56;
38
39 bool g_ECSOverride{false};
40 bool g_addEDNSToSelfGeneratedResponses{true};
41
42 int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>& newContent)
43 {
44 assert(initialPacket.size() >= sizeof(dnsheader));
45 const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
46
47 if (ntohs(dh->arcount) == 0)
48 return ENOENT;
49
50 if (ntohs(dh->qdcount) == 0)
51 return ENOENT;
52
53 PacketReader pr(initialPacket);
54
55 size_t idx = 0;
56 DNSName rrname;
57 uint16_t qdcount = ntohs(dh->qdcount);
58 uint16_t ancount = ntohs(dh->ancount);
59 uint16_t nscount = ntohs(dh->nscount);
60 uint16_t arcount = ntohs(dh->arcount);
61 uint16_t rrtype;
62 uint16_t rrclass;
63 string blob;
64 struct dnsrecordheader ah;
65
66 rrname = pr.getName();
67 rrtype = pr.get16BitInt();
68 rrclass = pr.get16BitInt();
69
70 DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
71 pw.getHeader()->id=dh->id;
72 pw.getHeader()->qr=dh->qr;
73 pw.getHeader()->aa=dh->aa;
74 pw.getHeader()->tc=dh->tc;
75 pw.getHeader()->rd=dh->rd;
76 pw.getHeader()->ra=dh->ra;
77 pw.getHeader()->ad=dh->ad;
78 pw.getHeader()->cd=dh->cd;
79 pw.getHeader()->rcode=dh->rcode;
80
81 /* consume remaining qd if any */
82 if (qdcount > 1) {
83 for(idx = 1; idx < qdcount; idx++) {
84 rrname = pr.getName();
85 rrtype = pr.get16BitInt();
86 rrclass = pr.get16BitInt();
87 (void) rrtype;
88 (void) rrclass;
89 }
90 }
91
92 /* copy AN and NS */
93 for (idx = 0; idx < ancount; idx++) {
94 rrname = pr.getName();
95 pr.getDnsrecordheader(ah);
96
97 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
98 pr.xfrBlob(blob);
99 pw.xfrBlob(blob);
100 }
101
102 for (idx = 0; idx < nscount; idx++) {
103 rrname = pr.getName();
104 pr.getDnsrecordheader(ah);
105
106 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
107 pr.xfrBlob(blob);
108 pw.xfrBlob(blob);
109 }
110 /* consume AR, looking for OPT */
111 for (idx = 0; idx < arcount; idx++) {
112 rrname = pr.getName();
113 pr.getDnsrecordheader(ah);
114
115 if (ah.d_type != QType::OPT) {
116 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
117 pr.xfrBlob(blob);
118 pw.xfrBlob(blob);
119 } else {
120
121 pr.skip(ah.d_clen);
122 }
123 }
124 pw.commit();
125
126 return 0;
127 }
128
129 static bool addOrReplaceECSOption(std::vector<std::pair<uint16_t, std::string>>& options, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
130 {
131 for (auto it = options.begin(); it != options.end(); ) {
132 if (it->first == EDNSOptionCode::ECS) {
133 ecsAdded = false;
134
135 if (!overrideExisting) {
136 return false;
137 }
138
139 it = options.erase(it);
140 }
141 else {
142 ++it;
143 }
144 }
145
146 options.emplace_back(EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
147 return true;
148 }
149
150 static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
151 {
152 assert(initialPacket.size() >= sizeof(dnsheader));
153 const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
154
155 ecsAdded = false;
156 ednsAdded = true;
157
158 if (ntohs(dh->qdcount) == 0) {
159 return false;
160 }
161
162 if (ntohs(dh->arcount) == 0) {
163 throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS");
164 }
165
166 PacketReader pr(initialPacket);
167
168 size_t idx = 0;
169 DNSName rrname;
170 uint16_t qdcount = ntohs(dh->qdcount);
171 uint16_t ancount = ntohs(dh->ancount);
172 uint16_t nscount = ntohs(dh->nscount);
173 uint16_t arcount = ntohs(dh->arcount);
174 uint16_t rrtype;
175 uint16_t rrclass;
176 string blob;
177 struct dnsrecordheader ah;
178
179 rrname = pr.getName();
180 rrtype = pr.get16BitInt();
181 rrclass = pr.get16BitInt();
182
183 DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
184 pw.getHeader()->id=dh->id;
185 pw.getHeader()->qr=dh->qr;
186 pw.getHeader()->aa=dh->aa;
187 pw.getHeader()->tc=dh->tc;
188 pw.getHeader()->rd=dh->rd;
189 pw.getHeader()->ra=dh->ra;
190 pw.getHeader()->ad=dh->ad;
191 pw.getHeader()->cd=dh->cd;
192 pw.getHeader()->rcode=dh->rcode;
193
194 /* consume remaining qd if any */
195 if (qdcount > 1) {
196 for(idx = 1; idx < qdcount; idx++) {
197 rrname = pr.getName();
198 rrtype = pr.get16BitInt();
199 rrclass = pr.get16BitInt();
200 (void) rrtype;
201 (void) rrclass;
202 }
203 }
204
205 /* copy AN and NS */
206 for (idx = 0; idx < ancount; idx++) {
207 rrname = pr.getName();
208 pr.getDnsrecordheader(ah);
209
210 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
211 pr.xfrBlob(blob);
212 pw.xfrBlob(blob);
213 }
214
215 for (idx = 0; idx < nscount; idx++) {
216 rrname = pr.getName();
217 pr.getDnsrecordheader(ah);
218
219 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
220 pr.xfrBlob(blob);
221 pw.xfrBlob(blob);
222 }
223
224 /* consume AR, looking for OPT */
225 for (idx = 0; idx < arcount; idx++) {
226 rrname = pr.getName();
227 pr.getDnsrecordheader(ah);
228
229 if (ah.d_type != QType::OPT) {
230 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
231 pr.xfrBlob(blob);
232 pw.xfrBlob(blob);
233 } else {
234
235 ednsAdded = false;
236 pr.xfrBlob(blob);
237
238 std::vector<std::pair<uint16_t, std::string>> options;
239 getEDNSOptionsFromContent(blob, options);
240
241 EDNS0Record edns0;
242 static_assert(sizeof(edns0) == sizeof(ah.d_ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
243 memcpy(&edns0, &ah.d_ttl, sizeof(edns0));
244
245 /* addOrReplaceECSOption will set it to false if there is already an existing option */
246 ecsAdded = true;
247 addOrReplaceECSOption(options, ecsAdded, overrideExisting, newECSOption);
248 pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version);
249 }
250 }
251
252 if (ednsAdded) {
253 pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{EDNSOptionCode::ECS, std::string(&newECSOption.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newECSOption.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
254 ecsAdded = true;
255 }
256
257 pw.commit();
258
259 return true;
260 }
261
262 static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::shared_ptr<std::map<uint16_t, EDNSOptionView> >& options)
263 {
264 const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
265
266 if (len < sizeof(dnsheader) || ntohs(dh->qdcount) == 0)
267 {
268 return false;
269 }
270
271 if (ntohs(dh->arcount) == 0) {
272 throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
273 }
274
275 try {
276 uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
277 DNSPacketMangler dpm(const_cast<char*>(packet), len);
278 uint64_t n;
279 for(n=0; n < ntohs(dh->qdcount) ; ++n) {
280 dpm.skipDomainName();
281 /* type and class */
282 dpm.skipBytes(4);
283 }
284
285 for(n=0; n < numrecords; ++n) {
286 dpm.skipDomainName();
287
288 uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
289 uint16_t dnstype = dpm.get16BitInt();
290 dpm.get16BitInt();
291 dpm.skipBytes(4); /* TTL */
292
293 if(section == 3 && dnstype == QType::OPT) {
294 uint32_t offset = dpm.getOffset();
295 if (offset >= len) {
296 return false;
297 }
298 /* if we survive this call, we can parse it safely */
299 dpm.skipRData();
300 return getEDNSOptions(packet + offset, len - offset, *options) == 0;
301 }
302 else {
303 dpm.skipRData();
304 }
305 }
306 }
307 catch(...)
308 {
309 return false;
310 }
311
312 return true;
313 }
314
315 int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last)
316 {
317 assert(optStart != NULL);
318 assert(optLen != NULL);
319 assert(last != NULL);
320 const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
321
322 if (ntohs(dh->arcount) == 0)
323 return ENOENT;
324
325 PacketReader pr(packet);
326 size_t idx = 0;
327 DNSName rrname;
328 uint16_t qdcount = ntohs(dh->qdcount);
329 uint16_t ancount = ntohs(dh->ancount);
330 uint16_t nscount = ntohs(dh->nscount);
331 uint16_t arcount = ntohs(dh->arcount);
332 uint16_t rrtype;
333 uint16_t rrclass;
334 struct dnsrecordheader ah;
335
336 /* consume qd */
337 for(idx = 0; idx < qdcount; idx++) {
338 rrname = pr.getName();
339 rrtype = pr.get16BitInt();
340 rrclass = pr.get16BitInt();
341 (void) rrtype;
342 (void) rrclass;
343 }
344
345 /* consume AN and NS */
346 for (idx = 0; idx < ancount + nscount; idx++) {
347 rrname = pr.getName();
348 pr.getDnsrecordheader(ah);
349 pr.skip(ah.d_clen);
350 }
351
352 /* consume AR, looking for OPT */
353 for (idx = 0; idx < arcount; idx++) {
354 uint16_t start = pr.getPosition();
355 rrname = pr.getName();
356 pr.getDnsrecordheader(ah);
357
358 if (ah.d_type == QType::OPT) {
359 *optStart = start;
360 *optLen = (pr.getPosition() - start) + ah.d_clen;
361
362 if (packet.size() < (*optStart + *optLen)) {
363 throw std::range_error("Opt record overflow");
364 }
365
366 if (idx == ((size_t) arcount - 1)) {
367 *last = true;
368 }
369 else {
370 *last = false;
371 }
372 return 0;
373 }
374 pr.skip(ah.d_clen);
375 }
376
377 return ENOENT;
378 }
379
380 /* extract the start of the OPT RR in a QUERY packet if any */
381 int getEDNSOptionsStart(const char* packet, const size_t offset, const size_t len, uint16_t* optRDPosition, size_t * remaining)
382 {
383 assert(packet != nullptr);
384 assert(optRDPosition != nullptr);
385 assert(remaining != nullptr);
386 const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
387
388 if (offset >= len) {
389 return ENOENT;
390 }
391
392 if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0)
393 return ENOENT;
394
395 size_t pos = sizeof(dnsheader) + offset;
396 pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
397
398 if (pos >= len)
399 return ENOENT;
400
401 if ((pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE) >= len) {
402 return ENOENT;
403 }
404
405 if (packet[pos] != 0) {
406 /* not the root so not an OPT record */
407 return ENOENT;
408 }
409 pos += 1;
410
411 uint16_t qtype = (reinterpret_cast<const unsigned char*>(packet)[pos])*256 + reinterpret_cast<const unsigned char*>(packet)[pos+1];
412 pos += DNS_TYPE_SIZE;
413 pos += DNS_CLASS_SIZE;
414
415 if(qtype != QType::OPT || (len - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE))
416 return ENOENT;
417
418 pos += DNS_TTL_SIZE;
419 *optRDPosition = pos;
420 *remaining = len - pos;
421
422 return 0;
423 }
424
425 void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
426 {
427 Netmask sourceNetmask(source, ECSPrefixLength);
428 EDNSSubnetOpts ecsOpts;
429 ecsOpts.source = sourceNetmask;
430 string payload = makeEDNSSubnetOptsString(ecsOpts);
431 generateEDNSOption(EDNSOptionCode::ECS, payload, res);
432 }
433
434 void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
435 {
436 const uint8_t name = 0;
437 dnsrecordheader dh;
438 EDNS0Record edns0;
439 edns0.extRCode = ednsrcode;
440 edns0.version = 0;
441 edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;
442
443 dh.d_type = htons(QType::OPT);
444 dh.d_class = htons(udpPayloadSize);
445 static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
446 memcpy(&dh.d_ttl, &edns0, sizeof edns0);
447 dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
448 res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
449 res.assign(reinterpret_cast<const char *>(&name), sizeof name);
450 res.append(reinterpret_cast<const char *>(&dh), sizeof(dh));
451 res.append(optRData.c_str(), optRData.length());
452 }
453
/* Replace the ECS option starting at oldEcsOptionStart (oldEcsOptionSize bytes, header
   included) with newECSOption. The packet is *len bytes long, in a buffer of packetSize
   bytes. optRDLen points to the RDLENGTH field of the OPT RR holding the option.

   - Same size: the new option simply overwrites the old one.
   - Different size: the old option is removed, the bytes that followed it are moved up,
     the new option is written at the very end of the packet, and both RDLENGTH and *len
     are adjusted. NOTE: this is only correct if the OPT RR is the last record of the packet.

   Returns false, leaving the packet untouched, if the result does not fit in packetSize. */
static bool replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen, const string& newECSOption)
{
  assert(packet != NULL);
  assert(len != NULL);
  assert(oldEcsOptionStart != NULL);
  assert(optRDLen != NULL);

  const size_t newSize = newECSOption.size();

  if (newSize == oldEcsOptionSize) {
    /* same size as the existing option: overwrite it in place */
    memcpy(oldEcsOptionStart, newECSOption.c_str(), oldEcsOptionSize);
    return true;
  }

  /* when the new option is smaller, the size differences below rely on unsigned wrap-around */
  const unsigned int newPacketLen = *len + (newSize - oldEcsOptionSize);
  if (newPacketLen > packetSize) {
    return false;
  }

  const size_t prefixLen = oldEcsOptionStart - packet;
  const size_t tailLen = *len - prefixLen - oldEcsOptionSize;

  const uint16_t oldRDLen = static_cast<uint16_t>((optRDLen[0] << 8) | optRDLen[1]);
  const uint16_t updatedRDLen = static_cast<uint16_t>(oldRDLen + (newSize - oldEcsOptionSize));
  optRDLen[0] = static_cast<unsigned char>(updatedRDLen >> 8);
  optRDLen[1] = static_cast<unsigned char>(updatedRDLen & 0xFF);

  if (tailLen > 0) {
    memmove(oldEcsOptionStart, oldEcsOptionStart + oldEcsOptionSize, tailLen);
  }
  memcpy(oldEcsOptionStart + tailLen, newECSOption.c_str(), newSize);
  *len = newPacketLen;

  return true;
}
491
492 /* This function looks for an OPT RR, return true if a valid one was found (even if there was no options)
493 and false otherwise. */
494 bool parseEDNSOptions(DNSQuestion& dq)
495 {
496 assert(dq.dh != nullptr);
497 assert(dq.consumed <= dq.len);
498 assert(dq.len <= dq.size);
499
500 if (dq.ednsOptions != nullptr) {
501 return true;
502 }
503
504 dq.ednsOptions = std::make_shared<std::map<uint16_t, EDNSOptionView> >();
505
506 if (ntohs(dq.dh->ancount) != 0 || ntohs(dq.dh->nscount) != 0 || (ntohs(dq.dh->arcount) != 0 && ntohs(dq.dh->arcount) != 1)) {
507 return slowParseEDNSOptions(reinterpret_cast<const char*>(dq.dh), dq.len, dq.ednsOptions);
508 }
509
510 const char* packet = reinterpret_cast<const char*>(dq.dh);
511
512 size_t remaining = 0;
513 uint16_t optRDPosition;
514 int res = getEDNSOptionsStart(packet, dq.consumed, dq.len, &optRDPosition, &remaining);
515
516 if (res == 0) {
517 res = getEDNSOptions(packet + optRDPosition, remaining, *dq.ednsOptions);
518 return (res == 0);
519 }
520
521 return false;
522 }
523
/* Append the ECS option newECSOption (complete with its code and length) at the end of the
   packet, i.e. at the end of the existing OPT RR, and fix the OPT RDLENGTH field pointed to
   by optRDLen.
   getEDNSOptionsStart() has already checked that there is exactly one AR, no NS and no AN.

   Returns false, leaving the packet untouched, if:
   - the option does not fit in the packetSize-byte buffer (an exact fit is accepted), or
   - the packet length or RDLENGTH would overflow their 16-bit fields.

   On success, sets ecsAdded to true and updates *len. */
static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool& ecsAdded)
{
  const size_t newECSOptionSize = newECSOption.size();

  /* check that the existing buffer is large enough to hold the option right after *len;
     the first test prevents the subtraction from wrapping around */
  if (*len > packetSize || packetSize - *len < newECSOptionSize) {
    return false;
  }

  /* both the packet length and the OPT RDLENGTH are 16-bit values: don't let them wrap */
  const size_t maxUInt16 = 0xFFFF;
  const size_t newLen = static_cast<size_t>(*len) + newECSOptionSize;
  const size_t newRDLen = static_cast<size_t>(optRDLen[0]) * 256 + optRDLen[1] + newECSOptionSize;
  if (newLen > maxUInt16 || newRDLen > maxUInt16) {
    return false;
  }

  optRDLen[0] = static_cast<unsigned char>(newRDLen / 256);
  optRDLen[1] = static_cast<unsigned char>(newRDLen % 256);

  memcpy(packet + *len, newECSOption.c_str(), newECSOptionSize);
  *len = static_cast<uint16_t>(newLen);
  ecsAdded = true;

  return true;
}
547
548 static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
549 {
550 /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
551 string EDNSRR;
552 struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet);
553 generateOptRR(newECSOption, EDNSRR, g_EdnsUDPPayloadSize, 0, false);
554
555 /* does it fit in the existing buffer? */
556 if (packetSize - *len <= EDNSRR.size()) {
557 return false;
558 }
559
560 uint32_t realPacketLen = getDNSPacketLength(packet, *len);
561 if (realPacketLen < *len && preserveTrailingData) {
562 size_t toMove = *len - realPacketLen;
563 memmove(packet + realPacketLen + EDNSRR.size(), packet + realPacketLen, toMove);
564 *len += EDNSRR.size();
565 }
566 else {
567 *len = realPacketLen + EDNSRR.size();
568 }
569
570 uint16_t arcount = ntohs(dh->arcount);
571 arcount++;
572 dh->arcount = htons(arcount);
573 ednsAdded = true;
574 ecsAdded = true;
575
576 memcpy(packet + realPacketLen, EDNSRR.c_str(), EDNSRR.size());
577
578 return true;
579 }
580
581 bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData)
582 {
583 assert(packet != nullptr);
584 assert(len != nullptr);
585 assert(consumed <= (size_t) *len);
586
587 const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
588
589 if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) {
590 vector<uint8_t> newContent;
591 newContent.reserve(packetSize);
592
593 if (!slowRewriteQueryWithExistingEDNS(std::string(packet, *len), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
594 ednsAdded = false;
595 ecsAdded = false;
596 return false;
597 }
598
599 if (newContent.size() > packetSize) {
600 ednsAdded = false;
601 ecsAdded = false;
602 return false;
603 }
604
605 memcpy(packet, &newContent.at(0), newContent.size());
606 *len = newContent.size();
607 return true;
608 }
609
610 uint16_t optRDPosition = 0;
611 size_t remaining = 0;
612
613 int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining);
614
615 if (res != 0) {
616 return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, ecsAdded, preserveTrailingData);
617 }
618
619 unsigned char* optRDLen = reinterpret_cast<unsigned char*>(packet) + optRDPosition;
620 char * ecsOptionStart = nullptr;
621 size_t ecsOptionSize = 0;
622
623 res = getEDNSOption(reinterpret_cast<char*>(optRDLen), remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize);
624
625 if (res == 0) {
626 /* there is already an ECS value */
627 if (!overrideExisting) {
628 return true;
629 }
630
631 return replaceEDNSClientSubnetOption(packet, packetSize, len, ecsOptionStart, ecsOptionSize, optRDLen, newECSOption);
632 } else {
633 /* we have an EDNS OPT RR but no existing ECS option */
634 return addECSToExistingOPT(packet, packetSize, len, newECSOption, optRDLen, ecsAdded);
635 }
636
637 return true;
638 }
639
640 bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData)
641 {
642 assert(dq.remote != nullptr);
643 string newECSOption;
644 generateECSOption(dq.ecsSet ? dq.ecs.getNetwork() : *dq.remote, newECSOption, dq.ecsSet ? dq.ecs.getBits() : dq.ecsPrefixLength);
645 char* packet = reinterpret_cast<char*>(dq.dh);
646
647 return handleEDNSClientSubnet(packet, dq.size, dq.consumed, &dq.len, ednsAdded, ecsAdded, dq.ecsOverride, newECSOption, preserveTrailingData);
648 }
649
/* Remove the first EDNS option whose code is optionCodeToRemove from the optionsLen bytes
   of options starting at optionsStart. Each option is made of a 2-byte code, a 2-byte
   length, then 'length' bytes of data.

   The options following the removed one are moved up in place.

   Returns:
   - 0 on success, with the new total length stored in *newOptionsLen;
   - EINVAL if an option runs past optionsLen;
   - ENOENT if no such option exists.
   *newOptionsLen is only written on success. */
static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
{
  const size_t optionHeaderSize = 4;
  size_t offset = 0;

  while (offset + optionHeaderSize <= optionsLen) {
    unsigned char* const option = optionsStart + offset;
    const uint16_t code = static_cast<uint16_t>((option[0] << 8) | option[1]);
    const uint16_t dataLen = static_cast<uint16_t>((option[2] << 8) | option[3]);
    const size_t optionEnd = offset + optionHeaderSize + dataLen;

    if (optionEnd > optionsLen) {
      return EINVAL;
    }

    if (code != optionCodeToRemove) {
      offset = optionEnd;
      continue;
    }

    /* move the remaining options, if any, over the removed one */
    const size_t trailing = optionsLen - optionEnd;
    if (trailing > 0) {
      memmove(option, optionsStart + optionEnd, trailing);
    }
    *newOptionsLen = static_cast<uint16_t>(optionsLen - (optionHeaderSize + dataLen));
    return 0;
  }

  return ENOENT;
}
679
680 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
681 {
682 if (*optLen < optRecordMinimumSize) {
683 return EINVAL;
684 }
685 const unsigned char* end = (const unsigned char*) optStart + *optLen;
686 unsigned char* p = (unsigned char*) optStart + 9;
687 unsigned char* rdLenPtr = p;
688 uint16_t rdLen = (0x100*p[0] + p[1]);
689 p += sizeof(rdLen);
690 if (p + rdLen != end) {
691 return EINVAL;
692 }
693 uint16_t newRdLen = 0;
694 int res = removeEDNSOptionFromOptions(p, rdLen, optionCodeToRemove, &newRdLen);
695 if (res != 0) {
696 return res;
697 }
698 *optLen -= (rdLen - newRdLen);
699 rdLenPtr[0] = newRdLen / 0x100;
700 rdLenPtr[1] = newRdLen % 0x100;
701 return 0;
702 }
703
704 bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
705 {
706 if (optLen < optRecordMinimumSize) {
707 return false;
708 }
709 size_t p = optStart + 9;
710 uint16_t rdLen = (0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1)));
711 p += sizeof(rdLen);
712 if (rdLen > (optLen - optRecordMinimumSize)) {
713 return false;
714 }
715
716 size_t rdEnd = p + rdLen;
717 while ((p + 4) <= rdEnd) {
718 const uint16_t optionCode = 0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1));
719 p += sizeof(optionCode);
720 const uint16_t optionLen = 0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1));
721 p += sizeof(optionLen);
722
723 if ((p + optionLen) > rdEnd) {
724 return false;
725 }
726
727 if (optionCode == optionCodeToFind) {
728 if (optContentStart != nullptr) {
729 *optContentStart = p;
730 }
731
732 if (optContentLen != nullptr) {
733 *optContentLen = optionLen;
734 }
735
736 return true;
737 }
738 p += optionLen;
739 }
740 return false;
741 }
742
743 int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent)
744 {
745 assert(initialPacket.size() >= sizeof(dnsheader));
746 const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
747
748 if (ntohs(dh->arcount) == 0)
749 return ENOENT;
750
751 if (ntohs(dh->qdcount) == 0)
752 return ENOENT;
753
754 PacketReader pr(initialPacket);
755
756 size_t idx = 0;
757 DNSName rrname;
758 uint16_t qdcount = ntohs(dh->qdcount);
759 uint16_t ancount = ntohs(dh->ancount);
760 uint16_t nscount = ntohs(dh->nscount);
761 uint16_t arcount = ntohs(dh->arcount);
762 uint16_t rrtype;
763 uint16_t rrclass;
764 string blob;
765 struct dnsrecordheader ah;
766
767 rrname = pr.getName();
768 rrtype = pr.get16BitInt();
769 rrclass = pr.get16BitInt();
770
771 DNSPacketWriter pw(newContent, rrname, rrtype, rrclass, dh->opcode);
772 pw.getHeader()->id=dh->id;
773 pw.getHeader()->qr=dh->qr;
774 pw.getHeader()->aa=dh->aa;
775 pw.getHeader()->tc=dh->tc;
776 pw.getHeader()->rd=dh->rd;
777 pw.getHeader()->ra=dh->ra;
778 pw.getHeader()->ad=dh->ad;
779 pw.getHeader()->cd=dh->cd;
780 pw.getHeader()->rcode=dh->rcode;
781
782 /* consume remaining qd if any */
783 if (qdcount > 1) {
784 for(idx = 1; idx < qdcount; idx++) {
785 rrname = pr.getName();
786 rrtype = pr.get16BitInt();
787 rrclass = pr.get16BitInt();
788 (void) rrtype;
789 (void) rrclass;
790 }
791 }
792
793 /* copy AN and NS */
794 for (idx = 0; idx < ancount; idx++) {
795 rrname = pr.getName();
796 pr.getDnsrecordheader(ah);
797
798 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
799 pr.xfrBlob(blob);
800 pw.xfrBlob(blob);
801 }
802
803 for (idx = 0; idx < nscount; idx++) {
804 rrname = pr.getName();
805 pr.getDnsrecordheader(ah);
806
807 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
808 pr.xfrBlob(blob);
809 pw.xfrBlob(blob);
810 }
811
812 /* consume AR, looking for OPT */
813 for (idx = 0; idx < arcount; idx++) {
814 rrname = pr.getName();
815 pr.getDnsrecordheader(ah);
816
817 if (ah.d_type != QType::OPT) {
818 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
819 pr.xfrBlob(blob);
820 pw.xfrBlob(blob);
821 } else {
822 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, false);
823 pr.xfrBlob(blob);
824 uint16_t rdLen = blob.length();
825 removeEDNSOptionFromOptions((unsigned char*)blob.c_str(), rdLen, optionCodeToSkip, &rdLen);
826 /* xfrBlob(string, size) completely ignores size.. */
827 if (rdLen > 0) {
828 blob.resize((size_t)rdLen);
829 pw.xfrBlob(blob);
830 } else {
831 pw.commit();
832 }
833 }
834 }
835 pw.commit();
836
837 return 0;
838 }
839
840 bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
841 {
842 std::string optRecord;
843 generateOptRR(std::string(), optRecord, payloadSize, ednsrcode, dnssecOK);
844
845 if (optRecord.size() >= size || (size - optRecord.size()) < len) {
846 return false;
847 }
848
849 char * optPtr = reinterpret_cast<char*>(dh) + len;
850 memcpy(optPtr, optRecord.data(), optRecord.size());
851 len += optRecord.size();
852 dh->arcount = htons(ntohs(dh->arcount) + 1);
853
854 return true;
855 }
856
857 /*
858 This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
859 generating a NXD or NODATA answer with a SOA record in the additional section.
860 */
861 bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum)
862 {
863 if (ntohs(dq.dh->qdcount) != 1) {
864 return false;
865 }
866
867 assert(dq.consumed == dq.qname->wirelength());
868 size_t queryPartSize = sizeof(dnsheader) + dq.consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
869 if (dq.len < queryPartSize) {
870 /* something is already wrong, don't build on flawed foundations */
871 return false;
872 }
873
874 size_t available = dq.size - queryPartSize;
875 uint16_t qtype = htons(QType::SOA);
876 uint16_t qclass = htons(QClass::IN);
877 uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum);
878 size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength;
879
880 if (soaSize > available) {
881 /* not enough space left to add the SOA, sorry! */
882 return false;
883 }
884
885 bool hadEDNS = false;
886 bool dnssecOK = false;
887
888 if (g_addEDNSToSelfGeneratedResponses) {
889 uint16_t payloadSize = 0;
890 uint16_t z = 0;
891 hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(dq.dh), dq.len, &payloadSize, &z);
892 if (hadEDNS) {
893 dnssecOK = z & EDNS_HEADER_FLAG_DO;
894 }
895 }
896
897 /* chop off everything after the question */
898 dq.len = queryPartSize;
899 if (nxd) {
900 dq.dh->rcode = RCode::NXDomain;
901 }
902 else {
903 dq.dh->rcode = RCode::NoError;
904 }
905 dq.dh->qr = true;
906 dq.dh->ancount = 0;
907 dq.dh->nscount = 0;
908 dq.dh->arcount = 0;
909
910 rdLength = htons(rdLength);
911 ttl = htonl(ttl);
912 serial = htonl(serial);
913 refresh = htonl(refresh);
914 retry = htonl(retry);
915 expire = htonl(expire);
916 minimum = htonl(minimum);
917
918 std::string soa;
919 soa.reserve(soaSize);
920 soa.append(zone.toDNSString());
921 soa.append(reinterpret_cast<const char*>(&qtype), sizeof(qtype));
922 soa.append(reinterpret_cast<const char*>(&qclass), sizeof(qclass));
923 soa.append(reinterpret_cast<const char*>(&ttl), sizeof(ttl));
924 soa.append(reinterpret_cast<const char*>(&rdLength), sizeof(rdLength));
925 soa.append(mname.toDNSString());
926 soa.append(rname.toDNSString());
927 soa.append(reinterpret_cast<const char*>(&serial), sizeof(serial));
928 soa.append(reinterpret_cast<const char*>(&refresh), sizeof(refresh));
929 soa.append(reinterpret_cast<const char*>(&retry), sizeof(retry));
930 soa.append(reinterpret_cast<const char*>(&expire), sizeof(expire));
931 soa.append(reinterpret_cast<const char*>(&minimum), sizeof(minimum));
932
933 if (soa.size() != soaSize) {
934 throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize));
935 }
936
937 memcpy(reinterpret_cast<char*>(dq.dh) + queryPartSize, soa.c_str(), soa.size());
938
939 dq.len += soa.size();
940
941 dq.dh->arcount = htons(1);
942
943 if (g_addEDNSToSelfGeneratedResponses) {
944 /* now we need to add a new OPT record */
945 return addEDNS(dq.dh, dq.len, dq.size, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
946 }
947
948 return true;
949 }
950
951 bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
952 {
953 uint16_t optRDPosition;
954 /* remaining is at least the size of the rdlen + the options if any + the following records if any */
955 size_t remaining = 0;
956
957 int res = getEDNSOptionsStart(reinterpret_cast<char*>(dq.dh), dq.consumed, dq.len, &optRDPosition, &remaining);
958
959 if (res != 0) {
960 /* if the initial query did not have EDNS0, we are done */
961 return true;
962 }
963
964 const size_t existingOptLen = /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2 + remaining;
965 if (existingOptLen >= dq.len) {
966 /* something is wrong, bail out */
967 return false;
968 }
969
970 char* optRDLen = reinterpret_cast<char*>(dq.dh) + optRDPosition;
971 char * optPtr = (optRDLen - (/* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2));
972
973 const uint8_t* zPtr = reinterpret_cast<const uint8_t*>(optPtr) + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE;
974 uint16_t z = 0x100 * (*zPtr) + *(zPtr + 1);
975 bool dnssecOK = z & EDNS_HEADER_FLAG_DO;
976
977 /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
978 dq.len -= existingOptLen;
979 dq.dh->arcount = 0;
980
981 if (g_addEDNSToSelfGeneratedResponses) {
982 /* now we need to add a new OPT record */
983 return addEDNS(dq.dh, dq.len, dq.size, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
984 }
985
986 /* otherwise we are just fine */
987 return true;
988 }
989
990 // goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
991 int getEDNSZ(const DNSQuestion& dq)
992 try
993 {
994 if (ntohs(dq.dh->qdcount) != 1 || dq.dh->ancount != 0 || ntohs(dq.dh->arcount) != 1 || dq.dh->nscount != 0) {
995 return 0;
996 }
997
998 if (dq.len <= sizeof(dnsheader)) {
999 return 0;
1000 }
1001
1002 size_t pos = sizeof(dnsheader) + dq.consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
1003
1004 if (dq.len <= (pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE)) {
1005 return 0;
1006 }
1007
1008 const char* packet = reinterpret_cast<const char*>(dq.dh);
1009
1010 if (packet[pos] != 0) {
1011 /* not root, so not a valid OPT record */
1012 return 0;
1013 }
1014
1015 pos++;
1016
1017 uint16_t qtype = (reinterpret_cast<const unsigned char*>(packet)[pos])*256 + reinterpret_cast<const unsigned char*>(packet)[pos+1];
1018 pos += DNS_TYPE_SIZE;
1019 pos += DNS_CLASS_SIZE;
1020
1021 if (qtype != QType::OPT || (pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= dq.len) {
1022 return 0;
1023 }
1024
1025 const uint8_t* z = reinterpret_cast<const uint8_t*>(packet) + pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE;
1026 return 0x100 * (*z) + *(z+1);
1027 }
1028 catch(...)
1029 {
1030 return 0;
1031 }
1032
1033 bool queryHasEDNS(const DNSQuestion& dq)
1034 {
1035 uint16_t optRDPosition;
1036 size_t ecsRemaining = 0;
1037
1038 int res = getEDNSOptionsStart(reinterpret_cast<char*>(dq.dh), dq.consumed, dq.len, &optRDPosition, &ecsRemaining);
1039 if (res == 0) {
1040 return true;
1041 }
1042
1043 return false;
1044 }
1045
1046 bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
1047 {
1048 uint16_t optStart;
1049 size_t optLen = 0;
1050 bool last = false;
1051 const char * packet = reinterpret_cast<const char*>(dq.dh);
1052 std::string packetStr(packet, dq.len);
1053 int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
1054 if (res != 0) {
1055 // no EDNS OPT RR
1056 return false;
1057 }
1058
1059 if (optLen < optRecordMinimumSize) {
1060 return false;
1061 }
1062
1063 if (optStart < dq.len && packetStr.at(optStart) != 0) {
1064 // OPT RR Name != '.'
1065 return false;
1066 }
1067
1068 static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
1069 // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
1070 memcpy(&edns0, packet + optStart + 5, sizeof edns0);
1071 return true;
1072 }