2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 #include "dnsdist-dnsparser.hh"
25 #include "dnsdist-ecs.hh"
26 #include "dnsparser.hh"
27 #include "dnswriter.hh"
28 #include "ednsoptions.hh"
29 #include "ednssubnet.hh"
31 /* when we add EDNS to a query, we don't want to advertise
32 a large buffer size */
33 size_t g_EdnsUDPPayloadSize
= 512;
34 static const uint16_t defaultPayloadSizeSelfGenAnswers
= 1232;
35 static_assert(defaultPayloadSizeSelfGenAnswers
< s_udpIncomingBufferSize
, "The UDP responder's payload size should be smaller or equal to our incoming buffer size");
36 uint16_t g_PayloadSizeSelfGenAnswers
{defaultPayloadSizeSelfGenAnswers
};
38 /* draft-ietf-dnsop-edns-client-subnet-04 "11.1. Privacy" */
39 uint16_t g_ECSSourcePrefixV4
= 24;
40 uint16_t g_ECSSourcePrefixV6
= 56;
42 bool g_ECSOverride
{false};
43 bool g_addEDNSToSelfGeneratedResponses
{true};
45 int rewriteResponseWithoutEDNS(const PacketBuffer
& initialPacket
, PacketBuffer
& newContent
)
47 assert(initialPacket
.size() >= sizeof(dnsheader
));
48 const dnsheader_aligned
dh(initialPacket
.data());
50 if (ntohs(dh
->arcount
) == 0) {
54 if (ntohs(dh
->qdcount
) == 0) {
58 PacketReader
pr(std::string_view(reinterpret_cast<const char*>(initialPacket
.data()), initialPacket
.size()));
62 uint16_t qdcount
= ntohs(dh
->qdcount
);
63 uint16_t ancount
= ntohs(dh
->ancount
);
64 uint16_t nscount
= ntohs(dh
->nscount
);
65 uint16_t arcount
= ntohs(dh
->arcount
);
69 struct dnsrecordheader ah
;
71 rrname
= pr
.getName();
72 rrtype
= pr
.get16BitInt();
73 rrclass
= pr
.get16BitInt();
75 GenericDNSPacketWriter
<PacketBuffer
> pw(newContent
, rrname
, rrtype
, rrclass
, dh
->opcode
);
76 pw
.getHeader()->id
= dh
->id
;
77 pw
.getHeader()->qr
= dh
->qr
;
78 pw
.getHeader()->aa
= dh
->aa
;
79 pw
.getHeader()->tc
= dh
->tc
;
80 pw
.getHeader()->rd
= dh
->rd
;
81 pw
.getHeader()->ra
= dh
->ra
;
82 pw
.getHeader()->ad
= dh
->ad
;
83 pw
.getHeader()->cd
= dh
->cd
;
84 pw
.getHeader()->rcode
= dh
->rcode
;
86 /* consume remaining qd if any */
88 for (idx
= 1; idx
< qdcount
; idx
++) {
89 rrname
= pr
.getName();
90 rrtype
= pr
.get16BitInt();
91 rrclass
= pr
.get16BitInt();
98 for (idx
= 0; idx
< ancount
; idx
++) {
99 rrname
= pr
.getName();
100 pr
.getDnsrecordheader(ah
);
102 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ANSWER
, true);
107 for (idx
= 0; idx
< nscount
; idx
++) {
108 rrname
= pr
.getName();
109 pr
.getDnsrecordheader(ah
);
111 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::AUTHORITY
, true);
115 /* consume AR, looking for OPT */
116 for (idx
= 0; idx
< arcount
; idx
++) {
117 rrname
= pr
.getName();
118 pr
.getDnsrecordheader(ah
);
120 if (ah
.d_type
!= QType::OPT
) {
121 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ADDITIONAL
, true);
135 static bool addOrReplaceEDNSOption(std::vector
<std::pair
<uint16_t, std::string
>>& options
, uint16_t optionCode
, bool& optionAdded
, bool overrideExisting
, const string
& newOptionContent
)
137 for (auto it
= options
.begin(); it
!= options
.end();) {
138 if (it
->first
== optionCode
) {
141 if (!overrideExisting
) {
145 it
= options
.erase(it
);
152 options
.emplace_back(optionCode
, std::string(&newOptionContent
.at(EDNS_OPTION_CODE_SIZE
+ EDNS_OPTION_LENGTH_SIZE
), newOptionContent
.size() - (EDNS_OPTION_CODE_SIZE
+ EDNS_OPTION_LENGTH_SIZE
)));
156 bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer
& initialPacket
, PacketBuffer
& newContent
, bool& ednsAdded
, uint16_t optionToReplace
, bool& optionAdded
, bool overrideExisting
, const string
& newOptionContent
)
158 assert(initialPacket
.size() >= sizeof(dnsheader
));
159 const dnsheader_aligned
dh(initialPacket
.data());
161 if (ntohs(dh
->qdcount
) == 0) {
165 if (ntohs(dh
->ancount
) == 0 && ntohs(dh
->nscount
) == 0 && ntohs(dh
->arcount
) == 0) {
166 throw std::runtime_error(std::string(__PRETTY_FUNCTION__
) + " should not be called for queries that have no records");
172 PacketReader
pr(std::string_view(reinterpret_cast<const char*>(initialPacket
.data()), initialPacket
.size()));
176 uint16_t qdcount
= ntohs(dh
->qdcount
);
177 uint16_t ancount
= ntohs(dh
->ancount
);
178 uint16_t nscount
= ntohs(dh
->nscount
);
179 uint16_t arcount
= ntohs(dh
->arcount
);
183 struct dnsrecordheader ah
;
185 rrname
= pr
.getName();
186 rrtype
= pr
.get16BitInt();
187 rrclass
= pr
.get16BitInt();
189 GenericDNSPacketWriter
<PacketBuffer
> pw(newContent
, rrname
, rrtype
, rrclass
, dh
->opcode
);
190 pw
.getHeader()->id
= dh
->id
;
191 pw
.getHeader()->qr
= dh
->qr
;
192 pw
.getHeader()->aa
= dh
->aa
;
193 pw
.getHeader()->tc
= dh
->tc
;
194 pw
.getHeader()->rd
= dh
->rd
;
195 pw
.getHeader()->ra
= dh
->ra
;
196 pw
.getHeader()->ad
= dh
->ad
;
197 pw
.getHeader()->cd
= dh
->cd
;
198 pw
.getHeader()->rcode
= dh
->rcode
;
200 /* consume remaining qd if any */
202 for (idx
= 1; idx
< qdcount
; idx
++) {
203 rrname
= pr
.getName();
204 rrtype
= pr
.get16BitInt();
205 rrclass
= pr
.get16BitInt();
212 for (idx
= 0; idx
< ancount
; idx
++) {
213 rrname
= pr
.getName();
214 pr
.getDnsrecordheader(ah
);
216 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ANSWER
, true);
221 for (idx
= 0; idx
< nscount
; idx
++) {
222 rrname
= pr
.getName();
223 pr
.getDnsrecordheader(ah
);
225 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::AUTHORITY
, true);
230 /* consume AR, looking for OPT */
231 for (idx
= 0; idx
< arcount
; idx
++) {
232 rrname
= pr
.getName();
233 pr
.getDnsrecordheader(ah
);
235 if (ah
.d_type
!= QType::OPT
) {
236 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ADDITIONAL
, true);
245 std::vector
<std::pair
<uint16_t, std::string
>> options
;
246 getEDNSOptionsFromContent(blob
, options
);
248 /* getDnsrecordheader() has helpfully converted the TTL for us, which we do not want in that case */
249 uint32_t ttl
= htonl(ah
.d_ttl
);
251 static_assert(sizeof(edns0
) == sizeof(ttl
), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
252 memcpy(&edns0
, &ttl
, sizeof(edns0
));
254 /* addOrReplaceEDNSOption will set it to false if there is already an existing option */
256 addOrReplaceEDNSOption(options
, optionToReplace
, optionAdded
, overrideExisting
, newOptionContent
);
257 pw
.addOpt(ah
.d_class
, edns0
.extRCode
, edns0
.extFlags
, options
, edns0
.version
);
262 pw
.addOpt(g_EdnsUDPPayloadSize
, 0, 0, {{optionToReplace
, std::string(&newOptionContent
.at(EDNS_OPTION_CODE_SIZE
+ EDNS_OPTION_LENGTH_SIZE
), newOptionContent
.size() - (EDNS_OPTION_CODE_SIZE
+ EDNS_OPTION_LENGTH_SIZE
))}}, 0);
271 static bool slowParseEDNSOptions(const PacketBuffer
& packet
, EDNSOptionViewMap
& options
)
273 if (packet
.size() < sizeof(dnsheader
)) {
277 const dnsheader_aligned
dh(packet
.data());
279 if (ntohs(dh
->qdcount
) == 0) {
283 if (ntohs(dh
->arcount
) == 0) {
284 throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
288 uint64_t numrecords
= ntohs(dh
->ancount
) + ntohs(dh
->nscount
) + ntohs(dh
->arcount
);
289 DNSPacketMangler
dpm(const_cast<char*>(reinterpret_cast<const char*>(&packet
.at(0))), packet
.size());
291 for (n
= 0; n
< ntohs(dh
->qdcount
); ++n
) {
292 dpm
.skipDomainName();
297 for (n
= 0; n
< numrecords
; ++n
) {
298 dpm
.skipDomainName();
300 uint8_t section
= n
< ntohs(dh
->ancount
) ? 1 : (n
< (ntohs(dh
->ancount
) + ntohs(dh
->nscount
)) ? 2 : 3);
301 uint16_t dnstype
= dpm
.get16BitInt();
303 dpm
.skipBytes(4); /* TTL */
305 if (section
== 3 && dnstype
== QType::OPT
) {
306 uint32_t offset
= dpm
.getOffset();
307 if (offset
>= packet
.size()) {
310 /* if we survive this call, we can parse it safely */
312 return getEDNSOptions(reinterpret_cast<const char*>(&packet
.at(offset
)), packet
.size() - offset
, options
) == 0;
326 int locateEDNSOptRR(const PacketBuffer
& packet
, uint16_t* optStart
, size_t* optLen
, bool* last
)
328 assert(optStart
!= NULL
);
329 assert(optLen
!= NULL
);
330 assert(last
!= NULL
);
331 const dnsheader_aligned
dh(packet
.data());
333 if (ntohs(dh
->arcount
) == 0) {
337 PacketReader
pr(std::string_view(reinterpret_cast<const char*>(packet
.data()), packet
.size()));
341 uint16_t qdcount
= ntohs(dh
->qdcount
);
342 uint16_t ancount
= ntohs(dh
->ancount
);
343 uint16_t nscount
= ntohs(dh
->nscount
);
344 uint16_t arcount
= ntohs(dh
->arcount
);
347 struct dnsrecordheader ah
;
350 for (idx
= 0; idx
< qdcount
; idx
++) {
351 rrname
= pr
.getName();
352 rrtype
= pr
.get16BitInt();
353 rrclass
= pr
.get16BitInt();
358 /* consume AN and NS */
359 for (idx
= 0; idx
< ancount
+ nscount
; idx
++) {
360 rrname
= pr
.getName();
361 pr
.getDnsrecordheader(ah
);
365 /* consume AR, looking for OPT */
366 for (idx
= 0; idx
< arcount
; idx
++) {
367 uint16_t start
= pr
.getPosition();
368 rrname
= pr
.getName();
369 pr
.getDnsrecordheader(ah
);
371 if (ah
.d_type
== QType::OPT
) {
373 *optLen
= (pr
.getPosition() - start
) + ah
.d_clen
;
375 if (packet
.size() < (*optStart
+ *optLen
)) {
376 throw std::range_error("Opt record overflow");
379 if (idx
== ((size_t)arcount
- 1)) {
393 /* extract the start of the OPT RR in a QUERY packet if any */
394 int getEDNSOptionsStart(const PacketBuffer
& packet
, const size_t offset
, uint16_t* optRDPosition
, size_t* remaining
)
396 assert(optRDPosition
!= nullptr);
397 assert(remaining
!= nullptr);
398 const dnsheader_aligned
dh(packet
.data());
400 if (offset
>= packet
.size()) {
404 if (ntohs(dh
->qdcount
) != 1 || ntohs(dh
->ancount
) != 0 || ntohs(dh
->arcount
) != 1 || ntohs(dh
->nscount
) != 0) {
408 size_t pos
= sizeof(dnsheader
) + offset
;
409 pos
+= DNS_TYPE_SIZE
+ DNS_CLASS_SIZE
;
411 if (pos
>= packet
.size())
414 if ((pos
+ /* root */ 1 + DNS_TYPE_SIZE
+ DNS_CLASS_SIZE
) >= packet
.size()) {
418 if (packet
[pos
] != 0) {
419 /* not the root so not an OPT record */
424 uint16_t qtype
= packet
.at(pos
) * 256 + packet
.at(pos
+ 1);
425 pos
+= DNS_TYPE_SIZE
;
426 pos
+= DNS_CLASS_SIZE
;
428 if (qtype
!= QType::OPT
|| (packet
.size() - pos
) < (DNS_TTL_SIZE
+ DNS_RDLENGTH_SIZE
)) {
433 *optRDPosition
= pos
;
434 *remaining
= packet
.size() - pos
;
439 void generateECSOption(const ComboAddress
& source
, string
& res
, uint16_t ECSPrefixLength
)
441 Netmask
sourceNetmask(source
, ECSPrefixLength
);
442 EDNSSubnetOpts ecsOpts
;
443 ecsOpts
.source
= sourceNetmask
;
444 string payload
= makeEDNSSubnetOptsString(ecsOpts
);
445 generateEDNSOption(EDNSOptionCode::ECS
, payload
, res
);
448 bool generateOptRR(const std::string
& optRData
, PacketBuffer
& res
, size_t maximumSize
, uint16_t udpPayloadSize
, uint8_t ednsrcode
, bool dnssecOK
)
450 const uint8_t name
= 0;
453 edns0
.extRCode
= ednsrcode
;
455 edns0
.extFlags
= dnssecOK
? htons(EDNS_HEADER_FLAG_DO
) : 0;
457 if ((maximumSize
- res
.size()) < (sizeof(name
) + sizeof(dh
) + optRData
.length())) {
461 dh
.d_type
= htons(QType::OPT
);
462 dh
.d_class
= htons(udpPayloadSize
);
463 static_assert(sizeof(EDNS0Record
) == sizeof(dh
.d_ttl
), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
464 memcpy(&dh
.d_ttl
, &edns0
, sizeof edns0
);
465 dh
.d_clen
= htons(static_cast<uint16_t>(optRData
.length()));
467 res
.reserve(res
.size() + sizeof(name
) + sizeof(dh
) + optRData
.length());
468 res
.insert(res
.end(), reinterpret_cast<const uint8_t*>(&name
), reinterpret_cast<const uint8_t*>(&name
) + sizeof(name
));
469 res
.insert(res
.end(), reinterpret_cast<const uint8_t*>(&dh
), reinterpret_cast<const uint8_t*>(&dh
) + sizeof(dh
));
470 res
.insert(res
.end(), reinterpret_cast<const uint8_t*>(optRData
.data()), reinterpret_cast<const uint8_t*>(optRData
.data()) + optRData
.length());
475 static bool replaceEDNSClientSubnetOption(PacketBuffer
& packet
, size_t maximumSize
, size_t const oldEcsOptionStartPosition
, size_t const oldEcsOptionSize
, size_t const optRDLenPosition
, const string
& newECSOption
)
477 assert(oldEcsOptionStartPosition
< packet
.size());
478 assert(optRDLenPosition
< packet
.size());
480 if (newECSOption
.size() == oldEcsOptionSize
) {
481 /* same size as the existing option */
482 memcpy(&packet
.at(oldEcsOptionStartPosition
), newECSOption
.c_str(), oldEcsOptionSize
);
485 /* different size than the existing option */
486 const unsigned int newPacketLen
= packet
.size() + (newECSOption
.length() - oldEcsOptionSize
);
487 const size_t beforeOptionLen
= oldEcsOptionStartPosition
;
488 const size_t dataBehindSize
= packet
.size() - beforeOptionLen
- oldEcsOptionSize
;
490 /* check that it fits in the existing buffer */
491 if (newPacketLen
> packet
.size()) {
492 if (newPacketLen
> maximumSize
) {
496 packet
.resize(newPacketLen
);
499 /* fix the size of ECS Option RDLen */
500 uint16_t newRDLen
= (packet
.at(optRDLenPosition
) * 256) + packet
.at(optRDLenPosition
+ 1);
501 newRDLen
+= (newECSOption
.size() - oldEcsOptionSize
);
502 packet
.at(optRDLenPosition
) = newRDLen
/ 256;
503 packet
.at(optRDLenPosition
+ 1) = newRDLen
% 256;
505 if (dataBehindSize
> 0) {
506 memmove(&packet
.at(oldEcsOptionStartPosition
), &packet
.at(oldEcsOptionStartPosition
+ oldEcsOptionSize
), dataBehindSize
);
508 memcpy(&packet
.at(oldEcsOptionStartPosition
+ dataBehindSize
), newECSOption
.c_str(), newECSOption
.size());
509 packet
.resize(newPacketLen
);
515 /* This function looks for an OPT RR, return true if a valid one was found (even if there was no options)
516 and false otherwise. */
517 bool parseEDNSOptions(const DNSQuestion
& dq
)
519 const auto dh
= dq
.getHeader();
520 if (dq
.ednsOptions
!= nullptr) {
524 // dq.ednsOptions is mutable
525 dq
.ednsOptions
= std::make_unique
<EDNSOptionViewMap
>();
527 if (ntohs(dh
->arcount
) == 0) {
528 /* nothing in additional so no EDNS */
532 if (ntohs(dh
->ancount
) != 0 || ntohs(dh
->nscount
) != 0 || ntohs(dh
->arcount
) > 1) {
533 return slowParseEDNSOptions(dq
.getData(), *dq
.ednsOptions
);
536 size_t remaining
= 0;
537 uint16_t optRDPosition
;
538 int res
= getEDNSOptionsStart(dq
.getData(), dq
.ids
.qname
.wirelength(), &optRDPosition
, &remaining
);
541 res
= getEDNSOptions(reinterpret_cast<const char*>(&dq
.getData().at(optRDPosition
)), remaining
, *dq
.ednsOptions
);
548 static bool addECSToExistingOPT(PacketBuffer
& packet
, size_t maximumSize
, const string
& newECSOption
, size_t optRDLenPosition
, bool& ecsAdded
)
550 /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
551 /* getEDNSOptionsStart has already checked that there is exactly one AR,
553 uint16_t oldRDLen
= (packet
.at(optRDLenPosition
) * 256) + packet
.at(optRDLenPosition
+ 1);
554 if (packet
.size() != (optRDLenPosition
+ sizeof(uint16_t) + oldRDLen
)) {
555 /* we are supposed to be the last record, do we have some trailing data to remove? */
556 uint32_t realPacketLen
= getDNSPacketLength(reinterpret_cast<const char*>(packet
.data()), packet
.size());
557 packet
.resize(realPacketLen
);
560 if ((maximumSize
- packet
.size()) < newECSOption
.size()) {
564 uint16_t newRDLen
= oldRDLen
+ newECSOption
.size();
565 packet
.at(optRDLenPosition
) = newRDLen
/ 256;
566 packet
.at(optRDLenPosition
+ 1) = newRDLen
% 256;
568 packet
.insert(packet
.end(), newECSOption
.begin(), newECSOption
.end());
574 static bool addEDNSWithECS(PacketBuffer
& packet
, size_t maximumSize
, const string
& newECSOption
, bool& ednsAdded
, bool& ecsAdded
)
576 if (!generateOptRR(newECSOption
, packet
, maximumSize
, g_EdnsUDPPayloadSize
, 0, false)) {
580 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet
, [](dnsheader
& header
) {
581 uint16_t arcount
= ntohs(header
.arcount
);
583 header
.arcount
= htons(arcount
);
592 bool handleEDNSClientSubnet(PacketBuffer
& packet
, const size_t maximumSize
, const size_t qnameWireLength
, bool& ednsAdded
, bool& ecsAdded
, bool overrideExisting
, const string
& newECSOption
)
594 assert(qnameWireLength
<= packet
.size());
596 const dnsheader_aligned
dh(packet
.data());
598 if (ntohs(dh
->ancount
) != 0 || ntohs(dh
->nscount
) != 0 || (ntohs(dh
->arcount
) != 0 && ntohs(dh
->arcount
) != 1)) {
599 PacketBuffer newContent
;
600 newContent
.reserve(packet
.size());
602 if (!slowRewriteEDNSOptionInQueryWithRecords(packet
, newContent
, ednsAdded
, EDNSOptionCode::ECS
, ecsAdded
, overrideExisting
, newECSOption
)) {
606 if (newContent
.size() > maximumSize
) {
612 packet
= std::move(newContent
);
616 uint16_t optRDPosition
= 0;
617 size_t remaining
= 0;
619 int res
= getEDNSOptionsStart(packet
, qnameWireLength
, &optRDPosition
, &remaining
);
622 /* no EDNS but there might be another record in additional (TSIG?) */
623 /* Careful, this code assumes that ANCOUNT == 0 && NSCOUNT == 0 */
624 size_t minimumPacketSize
= sizeof(dnsheader
) + qnameWireLength
+ sizeof(uint16_t) + sizeof(uint16_t);
625 if (packet
.size() > minimumPacketSize
) {
626 if (ntohs(dh
->arcount
) == 0) {
628 packet
.resize(minimumPacketSize
);
631 uint32_t realPacketLen
= getDNSPacketLength(reinterpret_cast<const char*>(packet
.data()), packet
.size());
632 packet
.resize(realPacketLen
);
636 return addEDNSWithECS(packet
, maximumSize
, newECSOption
, ednsAdded
, ecsAdded
);
639 size_t ecsOptionStartPosition
= 0;
640 size_t ecsOptionSize
= 0;
642 res
= getEDNSOption(reinterpret_cast<const char*>(&packet
.at(optRDPosition
)), remaining
, EDNSOptionCode::ECS
, &ecsOptionStartPosition
, &ecsOptionSize
);
645 /* there is already an ECS value */
646 if (!overrideExisting
) {
650 return replaceEDNSClientSubnetOption(packet
, maximumSize
, optRDPosition
+ ecsOptionStartPosition
, ecsOptionSize
, optRDPosition
, newECSOption
);
653 /* we have an EDNS OPT RR but no existing ECS option */
654 return addECSToExistingOPT(packet
, maximumSize
, newECSOption
, optRDPosition
, ecsAdded
);
660 bool handleEDNSClientSubnet(DNSQuestion
& dq
, bool& ednsAdded
, bool& ecsAdded
)
663 generateECSOption(dq
.ecs
? dq
.ecs
->getNetwork() : dq
.ids
.origRemote
, newECSOption
, dq
.ecs
? dq
.ecs
->getBits() : dq
.ecsPrefixLength
);
665 return handleEDNSClientSubnet(dq
.getMutableData(), dq
.getMaximumSize(), dq
.ids
.qname
.wirelength(), ednsAdded
, ecsAdded
, dq
.ecsOverride
, newECSOption
);
/* Remove the first option with code `optionCodeToRemove` from the
   `optionsLen` bytes of EDNS options starting at `optionsStart`, shifting any
   following options down over it. Each option is a big-endian 2-byte code,
   a big-endian 2-byte length, then the payload.
   Returns 0 and sets *newOptionsLen to the shortened length when the option
   was removed. Returns EINVAL if an option's length runs past `optionsLen`,
   and ENOENT if no such option exists. *newOptionsLen is only written on success. */
static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
{
  constexpr size_t optionHeaderSize = 2 * sizeof(uint16_t);
  size_t offset = 0;

  while (offset + optionHeaderSize <= optionsLen) {
    unsigned char* const optionBegin = optionsStart + offset;
    const uint16_t code = static_cast<uint16_t>((optionBegin[0] << 8) | optionBegin[1]);
    const uint16_t length = static_cast<uint16_t>((optionBegin[2] << 8) | optionBegin[3]);
    const size_t optionEnd = offset + optionHeaderSize + length;

    if (optionEnd > optionsLen) {
      return EINVAL;
    }

    if (code == optionCodeToRemove) {
      if (optionEnd < optionsLen) {
        /* move the remaining options over the removed one */
        memmove(optionBegin, optionsStart + optionEnd, optionsLen - optionEnd);
      }
      *newOptionsLen = static_cast<uint16_t>(optionsLen - (optionHeaderSize + length));
      return 0;
    }

    offset = optionEnd;
  }

  return ENOENT;
}
698 int removeEDNSOptionFromOPT(char* optStart
, size_t* optLen
, const uint16_t optionCodeToRemove
)
700 if (*optLen
< optRecordMinimumSize
) {
703 const unsigned char* end
= (const unsigned char*)optStart
+ *optLen
;
704 unsigned char* p
= (unsigned char*)optStart
+ 9;
705 unsigned char* rdLenPtr
= p
;
706 uint16_t rdLen
= (0x100 * p
[0] + p
[1]);
708 if (p
+ rdLen
!= end
) {
711 uint16_t newRdLen
= 0;
712 int res
= removeEDNSOptionFromOptions(p
, rdLen
, optionCodeToRemove
, &newRdLen
);
716 *optLen
-= (rdLen
- newRdLen
);
717 rdLenPtr
[0] = newRdLen
/ 0x100;
718 rdLenPtr
[1] = newRdLen
% 0x100;
722 bool isEDNSOptionInOpt(const PacketBuffer
& packet
, const size_t optStart
, const size_t optLen
, const uint16_t optionCodeToFind
, size_t* optContentStart
, uint16_t* optContentLen
)
724 if (optLen
< optRecordMinimumSize
) {
727 size_t p
= optStart
+ 9;
728 uint16_t rdLen
= (0x100 * static_cast<unsigned char>(packet
.at(p
)) + static_cast<unsigned char>(packet
.at(p
+ 1)));
730 if (rdLen
> (optLen
- optRecordMinimumSize
)) {
734 size_t rdEnd
= p
+ rdLen
;
735 while ((p
+ 4) <= rdEnd
) {
736 const uint16_t optionCode
= 0x100 * static_cast<unsigned char>(packet
.at(p
)) + static_cast<unsigned char>(packet
.at(p
+ 1));
737 p
+= sizeof(optionCode
);
738 const uint16_t optionLen
= 0x100 * static_cast<unsigned char>(packet
.at(p
)) + static_cast<unsigned char>(packet
.at(p
+ 1));
739 p
+= sizeof(optionLen
);
741 if ((p
+ optionLen
) > rdEnd
) {
745 if (optionCode
== optionCodeToFind
) {
746 if (optContentStart
!= nullptr) {
747 *optContentStart
= p
;
750 if (optContentLen
!= nullptr) {
751 *optContentLen
= optionLen
;
761 int rewriteResponseWithoutEDNSOption(const PacketBuffer
& initialPacket
, const uint16_t optionCodeToSkip
, PacketBuffer
& newContent
)
763 assert(initialPacket
.size() >= sizeof(dnsheader
));
764 const dnsheader_aligned
dh(initialPacket
.data());
766 if (ntohs(dh
->arcount
) == 0)
769 if (ntohs(dh
->qdcount
) == 0)
772 PacketReader
pr(std::string_view(reinterpret_cast<const char*>(initialPacket
.data()), initialPacket
.size()));
776 uint16_t qdcount
= ntohs(dh
->qdcount
);
777 uint16_t ancount
= ntohs(dh
->ancount
);
778 uint16_t nscount
= ntohs(dh
->nscount
);
779 uint16_t arcount
= ntohs(dh
->arcount
);
783 struct dnsrecordheader ah
;
785 rrname
= pr
.getName();
786 rrtype
= pr
.get16BitInt();
787 rrclass
= pr
.get16BitInt();
789 GenericDNSPacketWriter
<PacketBuffer
> pw(newContent
, rrname
, rrtype
, rrclass
, dh
->opcode
);
790 pw
.getHeader()->id
= dh
->id
;
791 pw
.getHeader()->qr
= dh
->qr
;
792 pw
.getHeader()->aa
= dh
->aa
;
793 pw
.getHeader()->tc
= dh
->tc
;
794 pw
.getHeader()->rd
= dh
->rd
;
795 pw
.getHeader()->ra
= dh
->ra
;
796 pw
.getHeader()->ad
= dh
->ad
;
797 pw
.getHeader()->cd
= dh
->cd
;
798 pw
.getHeader()->rcode
= dh
->rcode
;
800 /* consume remaining qd if any */
802 for (idx
= 1; idx
< qdcount
; idx
++) {
803 rrname
= pr
.getName();
804 rrtype
= pr
.get16BitInt();
805 rrclass
= pr
.get16BitInt();
812 for (idx
= 0; idx
< ancount
; idx
++) {
813 rrname
= pr
.getName();
814 pr
.getDnsrecordheader(ah
);
816 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ANSWER
, true);
821 for (idx
= 0; idx
< nscount
; idx
++) {
822 rrname
= pr
.getName();
823 pr
.getDnsrecordheader(ah
);
825 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::AUTHORITY
, true);
830 /* consume AR, looking for OPT */
831 for (idx
= 0; idx
< arcount
; idx
++) {
832 rrname
= pr
.getName();
833 pr
.getDnsrecordheader(ah
);
835 if (ah
.d_type
!= QType::OPT
) {
836 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ADDITIONAL
, true);
841 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ADDITIONAL
, false);
843 uint16_t rdLen
= blob
.length();
844 removeEDNSOptionFromOptions((unsigned char*)blob
.c_str(), rdLen
, optionCodeToSkip
, &rdLen
);
845 /* xfrBlob(string, size) completely ignores size.. */
847 blob
.resize((size_t)rdLen
);
860 bool addEDNS(PacketBuffer
& packet
, size_t maximumSize
, bool dnssecOK
, uint16_t payloadSize
, uint8_t ednsrcode
)
862 if (!generateOptRR(std::string(), packet
, maximumSize
, payloadSize
, ednsrcode
, dnssecOK
)) {
866 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet
, [](dnsheader
& header
) {
867 header
.arcount
= htons(ntohs(header
.arcount
) + 1);
875 This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
876 generating a NXD or NODATA answer with a SOA record in the additional section (or optionally the authority section for a full cacheable NXDOMAIN/NODATA).
878 bool setNegativeAndAdditionalSOA(DNSQuestion
& dq
, bool nxd
, const DNSName
& zone
, uint32_t ttl
, const DNSName
& mname
, const DNSName
& rname
, uint32_t serial
, uint32_t refresh
, uint32_t retry
, uint32_t expire
, uint32_t minimum
, bool soaInAuthoritySection
)
880 auto& packet
= dq
.getMutableData();
881 auto dh
= dq
.getHeader();
882 if (ntohs(dh
->qdcount
) != 1) {
886 size_t queryPartSize
= sizeof(dnsheader
) + dq
.ids
.qname
.wirelength() + DNS_TYPE_SIZE
+ DNS_CLASS_SIZE
;
887 if (packet
.size() < queryPartSize
) {
888 /* something is already wrong, don't build on flawed foundations */
892 uint16_t qtype
= htons(QType::SOA
);
893 uint16_t qclass
= htons(QClass::IN
);
894 uint16_t rdLength
= mname
.wirelength() + rname
.wirelength() + sizeof(serial
) + sizeof(refresh
) + sizeof(retry
) + sizeof(expire
) + sizeof(minimum
);
895 size_t soaSize
= zone
.wirelength() + sizeof(qtype
) + sizeof(qclass
) + sizeof(ttl
) + sizeof(rdLength
) + rdLength
;
896 bool hadEDNS
= false;
897 bool dnssecOK
= false;
899 if (g_addEDNSToSelfGeneratedResponses
) {
900 uint16_t payloadSize
= 0;
902 hadEDNS
= getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(packet
.data()), packet
.size(), &payloadSize
, &z
);
904 dnssecOK
= z
& EDNS_HEADER_FLAG_DO
;
908 /* chop off everything after the question */
909 packet
.resize(queryPartSize
);
910 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet
, [nxd
](dnsheader
& header
) {
912 header
.rcode
= RCode::NXDomain
;
915 header
.rcode
= RCode::NoError
;
924 rdLength
= htons(rdLength
);
926 serial
= htonl(serial
);
927 refresh
= htonl(refresh
);
928 retry
= htonl(retry
);
929 expire
= htonl(expire
);
930 minimum
= htonl(minimum
);
933 soa
.reserve(soaSize
);
934 soa
.append(zone
.toDNSString());
935 soa
.append(reinterpret_cast<const char*>(&qtype
), sizeof(qtype
));
936 soa
.append(reinterpret_cast<const char*>(&qclass
), sizeof(qclass
));
937 soa
.append(reinterpret_cast<const char*>(&ttl
), sizeof(ttl
));
938 soa
.append(reinterpret_cast<const char*>(&rdLength
), sizeof(rdLength
));
939 soa
.append(mname
.toDNSString());
940 soa
.append(rname
.toDNSString());
941 soa
.append(reinterpret_cast<const char*>(&serial
), sizeof(serial
));
942 soa
.append(reinterpret_cast<const char*>(&refresh
), sizeof(refresh
));
943 soa
.append(reinterpret_cast<const char*>(&retry
), sizeof(retry
));
944 soa
.append(reinterpret_cast<const char*>(&expire
), sizeof(expire
));
945 soa
.append(reinterpret_cast<const char*>(&minimum
), sizeof(minimum
));
947 if (soa
.size() != soaSize
) {
948 throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa
.size()) + " vs " + std::to_string(soaSize
));
951 packet
.insert(packet
.end(), soa
.begin(), soa
.end());
953 /* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
954 NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
955 and have EDNS added to AR afterwards */
956 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet
, [soaInAuthoritySection
](dnsheader
& header
) {
957 if (soaInAuthoritySection
) {
958 header
.nscount
= htons(1);
961 header
.arcount
= htons(1);
967 /* now we need to add a new OPT record */
968 return addEDNS(packet
, dq
.getMaximumSize(), dnssecOK
, g_PayloadSizeSelfGenAnswers
, dq
.ednsRCode
);
974 bool addEDNSToQueryTurnedResponse(DNSQuestion
& dq
)
976 uint16_t optRDPosition
;
977 /* remaining is at least the size of the rdlen + the options if any + the following records if any */
978 size_t remaining
= 0;
980 auto& packet
= dq
.getMutableData();
981 int res
= getEDNSOptionsStart(packet
, dq
.ids
.qname
.wirelength(), &optRDPosition
, &remaining
);
984 /* if the initial query did not have EDNS0, we are done */
988 const size_t existingOptLen
= /* root */ 1 + DNS_TYPE_SIZE
+ DNS_CLASS_SIZE
+ EDNS_EXTENDED_RCODE_SIZE
+ EDNS_VERSION_SIZE
+ /* Z */ 2 + remaining
;
989 if (existingOptLen
>= packet
.size()) {
990 /* something is wrong, bail out */
994 uint8_t* optRDLen
= &packet
.at(optRDPosition
);
995 uint8_t* optPtr
= (optRDLen
- (/* root */ 1 + DNS_TYPE_SIZE
+ DNS_CLASS_SIZE
+ EDNS_EXTENDED_RCODE_SIZE
+ EDNS_VERSION_SIZE
+ /* Z */ 2));
997 const uint8_t* zPtr
= optPtr
+ /* root */ 1 + DNS_TYPE_SIZE
+ DNS_CLASS_SIZE
+ EDNS_EXTENDED_RCODE_SIZE
+ EDNS_VERSION_SIZE
;
998 uint16_t z
= 0x100 * (*zPtr
) + *(zPtr
+ 1);
999 bool dnssecOK
= z
& EDNS_HEADER_FLAG_DO
;
1001 /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
1002 packet
.resize(packet
.size() - existingOptLen
);
1003 dnsdist::PacketMangling::editDNSHeaderFromPacket(packet
, [](dnsheader
& header
) {
1008 if (g_addEDNSToSelfGeneratedResponses
) {
1009 /* now we need to add a new OPT record */
1010 return addEDNS(packet
, dq
.getMaximumSize(), dnssecOK
, g_PayloadSizeSelfGenAnswers
, dq
.ednsRCode
);
1013 /* otherwise we are just fine */
1017 // goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
1018 int getEDNSZ(const DNSQuestion
& dq
)
1021 const auto& dh
= dq
.getHeader();
1022 if (ntohs(dh
->qdcount
) != 1 || dh
->ancount
!= 0 || ntohs(dh
->arcount
) != 1 || dh
->nscount
!= 0) {
1026 if (dq
.getData().size() <= sizeof(dnsheader
)) {
1030 size_t pos
= sizeof(dnsheader
) + dq
.ids
.qname
.wirelength() + DNS_TYPE_SIZE
+ DNS_CLASS_SIZE
;
1032 if (dq
.getData().size() <= (pos
+ /* root */ 1 + DNS_TYPE_SIZE
+ DNS_CLASS_SIZE
)) {
1036 auto& packet
= dq
.getData();
1038 if (packet
.at(pos
) != 0) {
1039 /* not root, so not a valid OPT record */
1045 uint16_t qtype
= packet
.at(pos
) * 256 + packet
.at(pos
+ 1);
1046 pos
+= DNS_TYPE_SIZE
;
1047 pos
+= DNS_CLASS_SIZE
;
1049 if (qtype
!= QType::OPT
|| (pos
+ EDNS_EXTENDED_RCODE_SIZE
+ EDNS_VERSION_SIZE
+ 1) >= packet
.size()) {
1053 const uint8_t* z
= &packet
.at(pos
+ EDNS_EXTENDED_RCODE_SIZE
+ EDNS_VERSION_SIZE
);
1054 return 0x100 * (*z
) + *(z
+ 1);
1061 bool queryHasEDNS(const DNSQuestion
& dq
)
1063 uint16_t optRDPosition
;
1064 size_t ecsRemaining
= 0;
1066 int res
= getEDNSOptionsStart(dq
.getData(), dq
.ids
.qname
.wirelength(), &optRDPosition
, &ecsRemaining
);
1074 bool getEDNS0Record(const PacketBuffer
& packet
, EDNS0Record
& edns0
)
1079 int res
= locateEDNSOptRR(packet
, &optStart
, &optLen
, &last
);
1085 if (optLen
< optRecordMinimumSize
) {
1089 if (optStart
< packet
.size() && packet
.at(optStart
) != 0) {
1090 // OPT RR Name != '.'
1094 static_assert(sizeof(EDNS0Record
) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
1095 // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
1096 memcpy(&edns0
, &packet
.at(optStart
+ 5), sizeof edns0
);
1100 bool setEDNSOption(DNSQuestion
& dq
, uint16_t ednsCode
, const std::string
& ednsData
)
1102 std::string optRData
;
1103 generateEDNSOption(ednsCode
, ednsData
, optRData
);
1105 if (dq
.getHeader()->arcount
) {
1106 bool ednsAdded
= false;
1107 bool optionAdded
= false;
1108 PacketBuffer newContent
;
1109 newContent
.reserve(dq
.getData().size());
1111 if (!slowRewriteEDNSOptionInQueryWithRecords(dq
.getData(), newContent
, ednsAdded
, ednsCode
, optionAdded
, true, optRData
)) {
1115 if (newContent
.size() > dq
.getMaximumSize()) {
1119 dq
.getMutableData() = std::move(newContent
);
1120 if (!dq
.ids
.ednsAdded
&& ednsAdded
) {
1121 dq
.ids
.ednsAdded
= true;
1127 auto& data
= dq
.getMutableData();
1128 if (generateOptRR(optRData
, data
, dq
.getMaximumSize(), g_EdnsUDPPayloadSize
, 0, false)) {
1129 dnsdist::PacketMangling::editDNSHeaderFromPacket(dq
.getMutableData(), [](dnsheader
& header
) {
1130 header
.arcount
= htons(1);
1133 // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
1134 dq
.ids
.ednsAdded
= true;
1142 bool setInternalQueryRCode(InternalQueryState
& state
, PacketBuffer
& buffer
, uint8_t rcode
, bool clearAnswers
)
1144 const auto qnameLength
= state
.qname
.wirelength();
1145 if (buffer
.size() < sizeof(dnsheader
) + qnameLength
+ sizeof(uint16_t) + sizeof(uint16_t)) {
1150 bool hadEDNS
= false;
1152 hadEDNS
= getEDNS0Record(buffer
, edns0
);
1155 dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer
, [rcode
, clearAnswers
](dnsheader
& header
) {
1156 header
.rcode
= rcode
;
1159 header
.ra
= header
.rd
;
1171 buffer
.resize(sizeof(dnsheader
) + qnameLength
+ sizeof(uint16_t) + sizeof(uint16_t));
1173 DNSQuestion
dq(state
, buffer
);
1174 if (!addEDNS(buffer
, dq
.getMaximumSize(), edns0
.extFlags
& htons(EDNS_HEADER_FLAG_DO
), g_PayloadSizeSelfGenAnswers
, 0)) {