2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
22 #include "dnsparser.hh"
23 #include "dnswriter.hh"
24 #include <boost/algorithm/string.hpp>
25 #include <boost/format.hpp>
27 #include "namespaces.hh"
28 #include "noinitvector.hh"
30 UnknownRecordContent::UnknownRecordContent(const string
& zone
)
34 stringtok(parts
, zone
);
35 // we need exactly 3 parts, except if the length field is set to 0 then we only need 2
36 if (parts
.size() != 3 && !(parts
.size() == 2 && boost::equals(parts
.at(1), "0"))) {
37 throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got " + std::to_string(parts
.size()) + ": " + zone
);
40 if (parts
.at(0) != "\\#") {
41 throw MOADNSException("Unknown record was stored incorrectly, first part should be '\\#', got '" + parts
.at(0) + "'");
44 const string
& relevant
= (parts
.size() > 2) ? parts
.at(2) : "";
45 auto total
= pdns::checked_stoi
<unsigned int>(parts
.at(1));
46 if (relevant
.size() % 2 || (relevant
.size() / 2) != total
) {
47 throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant
.size() % total
).str());
51 out
.reserve(total
+ 1);
53 for (unsigned int n
= 0; n
< total
; ++n
) {
55 if (sscanf(&relevant
.at(2*n
), "%02x", &c
) != 1) {
56 throw MOADNSException("unable to read data at position " + std::to_string(2 * n
) + " from unknown record of size " + std::to_string(relevant
.size()));
58 out
.append(1, (char)c
);
61 d_record
.insert(d_record
.end(), out
.begin(), out
.end());
64 string
UnknownRecordContent::getZoneRepresentation(bool /* noDot */) const
67 str
<<"\\# "<<(unsigned int)d_record
.size()<<" ";
69 for (unsigned char n
: d_record
) {
70 snprintf(hex
, sizeof(hex
), "%02x", n
);
76 void UnknownRecordContent::toPacket(DNSPacketWriter
& pw
) const
78 pw
.xfrBlob(string(d_record
.begin(),d_record
.end()));
81 shared_ptr
<DNSRecordContent
> DNSRecordContent::deserialize(const DNSName
& qname
, uint16_t qtype
, const string
& serialized
)
84 memset(&dnsheader
, 0, sizeof(dnsheader
));
85 dnsheader
.qdcount
=htons(1);
86 dnsheader
.ancount
=htons(1);
88 PacketBuffer packet
; // build pseudo packet
89 /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */
90 const auto& encoded
= qname
.getStorage();
91 packet
.resize(sizeof(dnsheader
) + 5 + encoded
.size() + sizeof(struct dnsrecordheader
) + serialized
.size());
94 memcpy(&packet
[0], &dnsheader
, sizeof(dnsheader
)); pos
+=sizeof(dnsheader
);
96 constexpr std::array
<uint8_t, 5> tmp
= {'\x0', '\x0', '\x1', '\x0', '\x1' }; // root question for ns_t_a
97 memcpy(&packet
[pos
], tmp
.data(), tmp
.size()); pos
+= tmp
.size();
99 memcpy(&packet
[pos
], encoded
.c_str(), encoded
.size()); pos
+=(uint16_t)encoded
.size();
101 struct dnsrecordheader drh
;
102 drh
.d_type
=htons(qtype
);
103 drh
.d_class
=htons(QClass::IN
);
105 drh
.d_clen
=htons(serialized
.size());
107 memcpy(&packet
[pos
], &drh
, sizeof(drh
)); pos
+=sizeof(drh
);
108 if (!serialized
.empty()) {
109 memcpy(&packet
[pos
], serialized
.c_str(), serialized
.size());
110 pos
+= (uint16_t) serialized
.size();
115 dr
.d_class
= QClass::IN
;
118 dr
.d_clen
= serialized
.size();
119 PacketReader
pr(std::string_view(reinterpret_cast<const char*>(packet
.data()), packet
.size()), packet
.size() - serialized
.size() - sizeof(dnsrecordheader
));
120 /* needed to get the record boundaries right */
121 pr
.getDnsrecordheader(drh
);
122 auto content
= DNSRecordContent::make(dr
, pr
, Opcode::Query
);
126 std::shared_ptr
<DNSRecordContent
> DNSRecordContent::make(const DNSRecord
& dr
,
129 uint16_t searchclass
= (dr
.d_type
== QType::OPT
) ? 1 : dr
.d_class
; // class is invalid for OPT
131 auto i
= getTypemap().find(pair(searchclass
, dr
.d_type
));
132 if(i
==getTypemap().end() || !i
->second
) {
133 return std::make_shared
<UnknownRecordContent
>(dr
, pr
);
136 return i
->second(dr
, pr
);
139 std::shared_ptr
<DNSRecordContent
> DNSRecordContent::make(uint16_t qtype
, uint16_t qclass
,
140 const string
& content
)
142 auto i
= getZmakermap().find(pair(qclass
, qtype
));
143 if(i
==getZmakermap().end()) {
144 return std::make_shared
<UnknownRecordContent
>(content
);
147 return i
->second(content
);
150 std::shared_ptr
<DNSRecordContent
> DNSRecordContent::make(const DNSRecord
& dr
, PacketReader
& pr
, uint16_t oc
)
152 // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is
153 // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent.
154 // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses.
155 if (oc
== Opcode::Update
&& dr
.d_place
== DNSResourceRecord::ANSWER
&& dr
.d_class
!= 1)
156 return std::make_shared
<UnknownRecordContent
>(dr
, pr
);
158 uint16_t searchclass
= (dr
.d_type
== QType::OPT
) ? 1 : dr
.d_class
; // class is invalid for OPT
160 auto i
= getTypemap().find(pair(searchclass
, dr
.d_type
));
161 if(i
==getTypemap().end() || !i
->second
) {
162 return std::make_shared
<UnknownRecordContent
>(dr
, pr
);
165 return i
->second(dr
, pr
);
168 string
DNSRecordContent::upgradeContent(const DNSName
& qname
, const QType
& qtype
, const string
& content
) {
169 // seamless upgrade for previously unsupported but now implemented types.
170 UnknownRecordContent
unknown_content(content
);
171 shared_ptr
<DNSRecordContent
> rc
= DNSRecordContent::deserialize(qname
, qtype
.getCode(), unknown_content
.serialize(qname
));
172 return rc
->getZoneRepresentation();
175 DNSRecordContent::typemap_t
& DNSRecordContent::getTypemap()
177 static DNSRecordContent::typemap_t typemap
;
181 DNSRecordContent::n2typemap_t
& DNSRecordContent::getN2Typemap()
183 static DNSRecordContent::n2typemap_t n2typemap
;
187 DNSRecordContent::t2namemap_t
& DNSRecordContent::getT2Namemap()
189 static DNSRecordContent::t2namemap_t t2namemap
;
193 DNSRecordContent::zmakermap_t
& DNSRecordContent::getZmakermap()
195 static DNSRecordContent::zmakermap_t zmakermap
;
199 bool DNSRecordContent::isRegisteredType(uint16_t rtype
, uint16_t rclass
)
201 return getTypemap().count(pair(rclass
, rtype
)) != 0;
204 DNSRecord::DNSRecord(const DNSResourceRecord
& rr
): d_name(rr
.qname
)
206 d_type
= rr
.qtype
.getCode();
209 d_place
= DNSResourceRecord::ANSWER
;
211 d_content
= DNSRecordContent::make(d_type
, rr
.qclass
, rr
.content
);
214 // If you call this and you are not parsing a packet coming from a socket, you are doing it wrong.
215 DNSResourceRecord
DNSResourceRecord::fromWire(const DNSRecord
& wire
)
217 DNSResourceRecord resourceRecord
;
218 resourceRecord
.qname
= wire
.d_name
;
219 resourceRecord
.qtype
= QType(wire
.d_type
);
220 resourceRecord
.ttl
= wire
.d_ttl
;
221 resourceRecord
.content
= wire
.getContent()->getZoneRepresentation(true);
222 resourceRecord
.auth
= false;
223 resourceRecord
.qclass
= wire
.d_class
;
224 return resourceRecord
;
227 void MOADNSParser::init(bool query
, const std::string_view
& packet
)
229 if (packet
.size() < sizeof(dnsheader
))
230 throw MOADNSException("Packet shorter than minimal header");
232 memcpy(&d_header
, packet
.data(), sizeof(dnsheader
));
234 if(d_header
.opcode
!= Opcode::Query
&& d_header
.opcode
!= Opcode::Notify
&& d_header
.opcode
!= Opcode::Update
)
235 throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header
.opcode
));
237 d_header
.qdcount
=ntohs(d_header
.qdcount
);
238 d_header
.ancount
=ntohs(d_header
.ancount
);
239 d_header
.nscount
=ntohs(d_header
.nscount
);
240 d_header
.arcount
=ntohs(d_header
.arcount
);
242 if (query
&& (d_header
.qdcount
> 1))
243 throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header
.qdcount
)+")");
247 PacketReader
pr(packet
);
248 bool validPacket
=false;
250 d_qtype
= d_qclass
= 0; // sometimes replies come in with no question, don't present garbage then
252 for(n
=0;n
< d_header
.qdcount
; ++n
) {
253 d_qname
=pr
.getName();
254 d_qtype
=pr
.get16BitInt();
255 d_qclass
=pr
.get16BitInt();
258 struct dnsrecordheader ah
;
259 vector
<unsigned char> record
;
260 bool seenTSIG
= false;
262 d_answers
.reserve((unsigned int)(d_header
.ancount
+ d_header
.nscount
+ d_header
.arcount
));
263 for(n
=0;n
< (unsigned int)(d_header
.ancount
+ d_header
.nscount
+ d_header
.arcount
); ++n
) {
266 if(n
< d_header
.ancount
)
267 dr
.d_place
=DNSResourceRecord::ANSWER
;
268 else if(n
< d_header
.ancount
+ d_header
.nscount
)
269 dr
.d_place
=DNSResourceRecord::AUTHORITY
;
271 dr
.d_place
=DNSResourceRecord::ADDITIONAL
;
273 unsigned int recordStartPos
=pr
.getPosition();
275 DNSName name
=pr
.getName();
277 pr
.getDnsrecordheader(ah
);
280 dr
.d_class
=ah
.d_class
;
282 dr
.d_name
= std::move(name
);
283 dr
.d_clen
= ah
.d_clen
;
286 !(d_qtype
== QType::IXFR
&& dr
.d_place
== DNSResourceRecord::AUTHORITY
&& dr
.d_type
== QType::SOA
) && // IXFR queries have a SOA in their AUTHORITY section
287 (dr
.d_place
== DNSResourceRecord::ANSWER
|| dr
.d_place
== DNSResourceRecord::AUTHORITY
|| (dr
.d_type
!= QType::OPT
&& dr
.d_type
!= QType::TSIG
&& dr
.d_type
!= QType::SIG
&& dr
.d_type
!= QType::TKEY
) || ((dr
.d_type
== QType::TSIG
|| dr
.d_type
== QType::SIG
|| dr
.d_type
== QType::TKEY
) && dr
.d_class
!= QClass::ANY
))) {
288 // cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
289 dr
.setContent(std::make_shared
<UnknownRecordContent
>(dr
, pr
));
292 // cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
293 dr
.setContent(DNSRecordContent::make(dr
, pr
, d_header
.opcode
));
296 /* XXX: XPF records should be allowed after TSIG as soon as the actual XPF option code has been assigned:
297 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG && dr.d_type != QType::XPF)
299 if (dr
.d_place
== DNSResourceRecord::ADDITIONAL
&& seenTSIG
) {
300 /* only XPF records are allowed after a TSIG */
301 throw MOADNSException("Packet ("+d_qname
.toString()+"|#"+std::to_string(d_qtype
)+") has an unexpected record ("+std::to_string(dr
.d_type
)+") after a TSIG one.");
304 if(dr
.d_type
== QType::TSIG
&& dr
.d_class
== QClass::ANY
) {
305 if(seenTSIG
|| dr
.d_place
!= DNSResourceRecord::ADDITIONAL
) {
306 throw MOADNSException("Packet ("+d_qname
.toLogString()+"|#"+std::to_string(d_qtype
)+") has a TSIG record in an invalid position.");
309 d_tsigPos
= recordStartPos
;
312 d_answers
.emplace_back(std::move(dr
), pr
.getPosition() - sizeof(dnsheader
));
316 if(pr
.getPosition()!=packet
.size()) {
317 throw MOADNSException("Packet ("+d_qname
+"|#"+std::to_string(d_qtype
)+") has trailing garbage ("+ std::to_string(pr
.getPosition()) + " < " +
318 std::to_string(packet
.size()) + ")");
322 catch(const std::out_of_range
&re
) {
323 if(validPacket
&& d_header
.tc
) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
324 if(n
< d_header
.ancount
) {
325 d_header
.ancount
=n
; d_header
.nscount
= d_header
.arcount
= 0;
327 else if(n
< d_header
.ancount
+ d_header
.nscount
) {
328 d_header
.nscount
= n
- d_header
.ancount
; d_header
.arcount
=0;
331 d_header
.arcount
= n
- d_header
.ancount
- d_header
.nscount
;
335 throw MOADNSException("Error parsing packet of "+std::to_string(packet
.size())+" bytes (rd="+
336 std::to_string(d_header
.rd
)+
337 "), out of bounds: "+string(re
.what()));
342 bool MOADNSParser::hasEDNS() const
344 if (d_header
.arcount
== 0 || d_answers
.empty()) {
348 for (const auto& record
: d_answers
) {
349 if (record
.first
.d_place
== DNSResourceRecord::ADDITIONAL
&& record
.first
.d_type
== QType::OPT
) {
357 void PacketReader::getDnsrecordheader(struct dnsrecordheader
&ah
)
359 unsigned char *p
= reinterpret_cast<unsigned char*>(&ah
);
361 for(unsigned int n
= 0; n
< sizeof(dnsrecordheader
); ++n
) {
362 p
[n
] = d_content
.at(d_pos
++);
365 ah
.d_type
= ntohs(ah
.d_type
);
366 ah
.d_class
= ntohs(ah
.d_class
);
367 ah
.d_clen
= ntohs(ah
.d_clen
);
368 ah
.d_ttl
= ntohl(ah
.d_ttl
);
370 d_startrecordpos
= d_pos
; // needed for getBlob later on
371 d_recordlen
= ah
.d_clen
;
375 void PacketReader::copyRecord(vector
<unsigned char>& dest
, uint16_t len
)
380 if ((d_pos
+ len
) > d_content
.size()) {
381 throw std::out_of_range("Attempt to copy outside of packet");
386 for (uint16_t n
= 0; n
< len
; ++n
) {
387 dest
.at(n
) = d_content
.at(d_pos
++);
391 void PacketReader::copyRecord(unsigned char* dest
, uint16_t len
)
393 if (d_pos
+ len
> d_content
.size()) {
394 throw std::out_of_range("Attempt to copy outside of packet");
397 memcpy(dest
, &d_content
.at(d_pos
), len
);
401 void PacketReader::xfrNodeOrLocatorID(NodeOrLocatorID
& ret
)
403 if (d_pos
+ sizeof(ret
) > d_content
.size()) {
404 throw std::out_of_range("Attempt to read 64 bit value outside of packet");
406 memcpy(&ret
.content
, &d_content
.at(d_pos
), sizeof(ret
.content
));
407 d_pos
+= sizeof(ret
);
410 void PacketReader::xfr48BitInt(uint64_t& ret
)
413 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
415 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
417 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
419 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
421 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
423 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
426 uint32_t PacketReader::get32BitInt()
429 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
431 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
433 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
435 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
441 uint16_t PacketReader::get16BitInt()
444 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
446 ret
+=static_cast<uint8_t>(d_content
.at(d_pos
++));
451 uint8_t PacketReader::get8BitInt()
453 return d_content
.at(d_pos
++);
456 DNSName
PacketReader::getName()
458 unsigned int consumed
;
460 DNSName
dn((const char*) d_content
.data(), d_content
.size(), d_pos
, true /* uncompress */, nullptr /* qtype */, nullptr /* qclass */, &consumed
, sizeof(dnsheader
));
465 catch(const std::range_error
& re
) {
466 throw std::out_of_range(string("dnsname issue: ")+re
.what());
469 throw std::out_of_range("dnsname issue");
471 throw PDNSException("PacketReader::getName(): name is empty");
474 static string
txtEscape(const string
&name
)
480 if((unsigned char) i
>= 127 || (unsigned char) i
< 32) {
481 snprintf(ebuf
, sizeof(ebuf
), "\\%03u", (unsigned char)i
);
484 else if(i
=='"' || i
=='\\'){
494 // exceptions thrown here do not result in logging in the main pdns auth server - just so you know!
495 string
PacketReader::getText(bool multi
, bool lenField
)
499 while(d_pos
< d_startrecordpos
+ d_recordlen
) {
505 labellen
=static_cast<uint8_t>(d_content
.at(d_pos
++));
507 labellen
=d_recordlen
- (d_pos
- d_startrecordpos
);
510 if(labellen
) { // no need to do anything for an empty string
511 string
val(&d_content
.at(d_pos
), &d_content
.at(d_pos
+labellen
-1)+1);
512 ret
.append(txtEscape(val
)); // the end is one beyond the packet
520 if (ret
.empty() && !lenField
) {
521 // all lenField == false cases (CAA and URI at the time of this writing) want that emptiness to be explicit
527 string
PacketReader::getUnquotedText(bool lenField
)
531 stop_at
= static_cast<uint8_t>(d_content
.at(d_pos
)) + d_pos
+ 1;
533 stop_at
= d_recordlen
;
535 /* think unsigned overflow */
536 if (stop_at
< d_pos
) {
537 throw std::out_of_range("getUnquotedText out of record range");
544 string
ret(d_content
.substr(d_pos
, stop_at
-d_pos
));
549 void PacketReader::xfrBlob(string
& blob
)
552 if(d_recordlen
&& !(d_pos
== (d_startrecordpos
+ d_recordlen
))) {
553 if (d_pos
> (d_startrecordpos
+ d_recordlen
)) {
554 throw std::out_of_range("xfrBlob out of record range");
556 blob
.assign(&d_content
.at(d_pos
), &d_content
.at(d_startrecordpos
+ d_recordlen
- 1 ) + 1);
562 d_pos
= d_startrecordpos
+ d_recordlen
;
566 throw std::out_of_range("xfrBlob out of range");
570 void PacketReader::xfrBlobNoSpaces(string
& blob
, int length
) {
571 xfrBlob(blob
, length
);
574 void PacketReader::xfrBlob(string
& blob
, int length
)
578 throw std::out_of_range("xfrBlob out of range (negative length)");
581 blob
.assign(&d_content
.at(d_pos
), &d_content
.at(d_pos
+ length
- 1 ) + 1 );
590 void PacketReader::xfrSvcParamKeyVals(set
<SvcParam
> &kvs
) {
591 while (d_pos
< (d_startrecordpos
+ d_recordlen
)) {
592 if (d_pos
+ 2 > (d_startrecordpos
+ d_recordlen
)) {
593 throw std::out_of_range("incomplete key");
597 auto key
= static_cast<SvcParam::SvcParamKey
>(keyInt
);
601 if (d_pos
+ len
> (d_startrecordpos
+ d_recordlen
)) {
602 throw std::out_of_range("record is shorter than SVCB lengthfield implies");
607 case SvcParam::mandatory
: {
609 throw std::out_of_range("mandatory SvcParam has invalid length");
612 throw std::out_of_range("empty 'mandatory' values");
614 std::set
<SvcParam::SvcParamKey
> paramKeys
;
615 size_t stop
= d_pos
+ len
;
616 while (d_pos
< stop
) {
619 paramKeys
.insert(static_cast<SvcParam::SvcParamKey
>(keyval
));
621 kvs
.insert(SvcParam(key
, std::move(paramKeys
)));
624 case SvcParam::alpn
: {
625 size_t stop
= d_pos
+ len
;
626 std::vector
<string
> alpns
;
627 while (d_pos
< stop
) {
632 throw std::out_of_range("alpn length of 0");
634 xfrBlob(alpn
, alpnLen
);
635 alpns
.push_back(alpn
);
637 kvs
.insert(SvcParam(key
, std::move(alpns
)));
640 case SvcParam::no_default_alpn
: {
642 throw std::out_of_range("invalid length for no-default-alpn");
644 kvs
.insert(SvcParam(key
));
647 case SvcParam::port
: {
649 throw std::out_of_range("invalid length for port");
653 kvs
.insert(SvcParam(key
, port
));
656 case SvcParam::ipv4hint
: /* fall-through */
657 case SvcParam::ipv6hint
: {
658 size_t addrLen
= (key
== SvcParam::ipv4hint
? 4 : 16);
659 if (len
% addrLen
!= 0) {
660 throw std::out_of_range("invalid length for " + SvcParam::keyToString(key
));
662 vector
<ComboAddress
> addresses
;
663 auto stop
= d_pos
+ len
;
667 xfrCAWithoutPort(key
, addr
);
668 addresses
.push_back(addr
);
670 kvs
.insert(SvcParam(key
, std::move(addresses
)));
673 case SvcParam::ech
: {
676 xfrBlobNoSpaces(blob
, len
);
677 kvs
.insert(SvcParam(key
, blob
));
684 kvs
.insert(SvcParam(key
, blob
));
692 void PacketReader::xfrHexBlob(string
& blob
, bool /* keepReading */)
697 //FIXME400 remove this method completely
698 string
simpleCompress(const string
& elabel
, const string
& root
)
701 // FIXME400: this relies on the semi-canonical escaped output from getName
702 if(strchr(label
.c_str(), '\\')) {
703 boost::replace_all(label
, "\\.", ".");
704 boost::replace_all(label
, "\\032", " ");
705 boost::replace_all(label
, "\\\\", "\\");
707 typedef vector
<pair
<unsigned int, unsigned int> > parts_t
;
709 vstringtok(parts
, label
, ".");
711 ret
.reserve(label
.size()+4);
712 for(const auto & part
: parts
) {
713 if(!root
.empty() && !strncasecmp(root
.c_str(), label
.c_str() + part
.first
, 1 + label
.length() - part
.first
)) { // also match trailing 0, hence '1 +'
714 const unsigned char rootptr
[2]={0xc0,0x11};
715 ret
.append((const char *) rootptr
, 2);
718 ret
.append(1, (char)(part
.second
- part
.first
));
719 ret
.append(label
.c_str() + part
.first
, part
.second
- part
.first
);
721 ret
.append(1, (char)0);
725 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
726 void editDNSPacketTTL(char* packet
, size_t length
, const std::function
<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor
)
728 if(length
< sizeof(dnsheader
))
733 memcpy((void*)&dh
, (const dnsheader
*)packet
, sizeof(dh
));
734 uint64_t numrecords
= ntohs(dh
.ancount
) + ntohs(dh
.nscount
) + ntohs(dh
.arcount
);
735 DNSPacketMangler
dpm(packet
, length
);
738 for(n
=0; n
< ntohs(dh
.qdcount
) ; ++n
) {
739 dpm
.skipDomainName();
744 for(n
=0; n
< numrecords
; ++n
) {
745 dpm
.skipDomainName();
747 uint8_t section
= n
< ntohs(dh
.ancount
) ? 1 : (n
< (ntohs(dh
.ancount
) + ntohs(dh
.nscount
)) ? 2 : 3);
748 uint16_t dnstype
= dpm
.get16BitInt();
749 uint16_t dnsclass
= dpm
.get16BitInt();
751 if(dnstype
== QType::OPT
) // not getting near that one with a stick
754 uint32_t dnsttl
= dpm
.get32BitInt();
755 uint32_t newttl
= visitor(section
, dnsclass
, dnstype
, dnsttl
);
757 dpm
.rewindBytes(sizeof(newttl
));
758 dpm
.setAndSkip32BitInt(newttl
);
769 static bool checkIfPacketContainsRecords(const PacketBuffer
& packet
, const std::unordered_set
<QType
>& qtypes
)
771 auto length
= packet
.size();
772 if (length
< sizeof(dnsheader
)) {
777 const dnsheader_aligned
dh(packet
.data());
778 DNSPacketMangler
dpm(const_cast<char*>(reinterpret_cast<const char*>(packet
.data())), length
);
780 const uint16_t qdcount
= ntohs(dh
->qdcount
);
781 for (size_t n
= 0; n
< qdcount
; ++n
) {
782 dpm
.skipDomainName();
786 const size_t recordsCount
= static_cast<size_t>(ntohs(dh
->ancount
)) + ntohs(dh
->nscount
) + ntohs(dh
->arcount
);
787 for (size_t n
= 0; n
< recordsCount
; ++n
) {
788 dpm
.skipDomainName();
789 uint16_t dnstype
= dpm
.get16BitInt();
790 uint16_t dnsclass
= dpm
.get16BitInt();
791 if (dnsclass
== QClass::IN
&& qtypes
.count(dnstype
) > 0) {
805 static int rewritePacketWithoutRecordTypes(const PacketBuffer
& initialPacket
, PacketBuffer
& newContent
, const std::unordered_set
<QType
>& qtypes
)
807 static const std::unordered_set
<QType
>& safeTypes
{QType::A
, QType::AAAA
, QType::DHCID
, QType::TXT
, QType::OPT
, QType::HINFO
, QType::DNSKEY
, QType::CDNSKEY
, QType::DS
, QType::CDS
, QType::DLV
, QType::SSHFP
, QType::KEY
, QType::CERT
, QType::TLSA
, QType::SMIMEA
, QType::OPENPGPKEY
, QType::SVCB
, QType::HTTPS
, QType::NSEC3
, QType::CSYNC
, QType::NSEC3PARAM
, QType::LOC
, QType::NID
, QType::L32
, QType::L64
, QType::EUI48
, QType::EUI64
, QType::URI
, QType::CAA
};
809 if (initialPacket
.size() < sizeof(dnsheader
)) {
813 const dnsheader_aligned
dh(initialPacket
.data());
815 if (ntohs(dh
->qdcount
) == 0)
817 auto packetView
= std::string_view(reinterpret_cast<const char*>(initialPacket
.data()), initialPacket
.size());
819 PacketReader
pr(packetView
);
823 uint16_t qdcount
= ntohs(dh
->qdcount
);
824 uint16_t ancount
= ntohs(dh
->ancount
);
825 uint16_t nscount
= ntohs(dh
->nscount
);
826 uint16_t arcount
= ntohs(dh
->arcount
);
830 struct dnsrecordheader ah
;
832 rrname
= pr
.getName();
833 rrtype
= pr
.get16BitInt();
834 rrclass
= pr
.get16BitInt();
836 GenericDNSPacketWriter
<PacketBuffer
> pw(newContent
, rrname
, rrtype
, rrclass
, dh
->opcode
);
837 pw
.getHeader()->id
=dh
->id
;
838 pw
.getHeader()->qr
=dh
->qr
;
839 pw
.getHeader()->aa
=dh
->aa
;
840 pw
.getHeader()->tc
=dh
->tc
;
841 pw
.getHeader()->rd
=dh
->rd
;
842 pw
.getHeader()->ra
=dh
->ra
;
843 pw
.getHeader()->ad
=dh
->ad
;
844 pw
.getHeader()->cd
=dh
->cd
;
845 pw
.getHeader()->rcode
=dh
->rcode
;
847 /* consume remaining qd if any */
849 for(idx
= 1; idx
< qdcount
; idx
++) {
850 rrname
= pr
.getName();
851 rrtype
= pr
.get16BitInt();
852 rrclass
= pr
.get16BitInt();
859 for (idx
= 0; idx
< ancount
; idx
++) {
860 rrname
= pr
.getName();
861 pr
.getDnsrecordheader(ah
);
864 if (qtypes
.find(ah
.d_type
) == qtypes
.end()) {
865 // if this is not a safe type
866 if (safeTypes
.find(ah
.d_type
) == safeTypes
.end()) {
867 // "unsafe" types might contain compressed data, so cancel rewrite
871 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ANSWER
, true);
877 for (idx
= 0; idx
< nscount
; idx
++) {
878 rrname
= pr
.getName();
879 pr
.getDnsrecordheader(ah
);
882 if (qtypes
.find(ah
.d_type
) == qtypes
.end()) {
883 if (safeTypes
.find(ah
.d_type
) == safeTypes
.end()) {
884 // "unsafe" types might contain compressed data, so cancel rewrite
888 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::AUTHORITY
, true);
893 for (idx
= 0; idx
< arcount
; idx
++) {
894 rrname
= pr
.getName();
895 pr
.getDnsrecordheader(ah
);
898 if (qtypes
.find(ah
.d_type
) == qtypes
.end()) {
899 if (safeTypes
.find(ah
.d_type
) == safeTypes
.end()) {
900 // "unsafe" types might contain compressed data, so cancel rewrite
904 pw
.startRecord(rrname
, ah
.d_type
, ah
.d_ttl
, ah
.d_class
, DNSResourceRecord::ADDITIONAL
, true);
919 void clearDNSPacketRecordTypes(vector
<uint8_t>& packet
, const std::unordered_set
<QType
>& qtypes
)
921 return clearDNSPacketRecordTypes(reinterpret_cast<PacketBuffer
&>(packet
), qtypes
);
924 void clearDNSPacketRecordTypes(PacketBuffer
& packet
, const std::unordered_set
<QType
>& qtypes
)
926 if (!checkIfPacketContainsRecords(packet
, qtypes
)) {
930 PacketBuffer newContent
;
932 auto result
= rewritePacketWithoutRecordTypes(packet
, newContent
, qtypes
);
934 packet
= std::move(newContent
);
938 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
939 void ageDNSPacket(char* packet
, size_t length
, uint32_t seconds
, const dnsheader_aligned
& aligned_dh
)
941 if (length
< sizeof(dnsheader
)) {
945 const dnsheader
* dhp
= aligned_dh
.get();
946 const uint64_t dqcount
= ntohs(dhp
->qdcount
);
947 const uint64_t numrecords
= ntohs(dhp
->ancount
) + ntohs(dhp
->nscount
) + ntohs(dhp
->arcount
);
948 DNSPacketMangler
dpm(packet
, length
);
950 for (uint64_t rec
= 0; rec
< dqcount
; ++rec
) {
951 dpm
.skipDomainName();
956 for(uint64_t rec
= 0; rec
< numrecords
; ++rec
) {
957 dpm
.skipDomainName();
959 uint16_t dnstype
= dpm
.get16BitInt();
963 if (dnstype
!= QType::OPT
) { // not aging that one with a stick
964 dpm
.decreaseAndSkip32BitInt(seconds
);
975 void ageDNSPacket(std::string
& packet
, uint32_t seconds
, const dnsheader_aligned
& aligned_dh
)
977 ageDNSPacket(packet
.data(), packet
.length(), seconds
, aligned_dh
);
980 uint32_t getDNSPacketMinTTL(const char* packet
, size_t length
, bool* seenAuthSOA
)
982 uint32_t result
= std::numeric_limits
<uint32_t>::max();
983 if(length
< sizeof(dnsheader
)) {
988 const dnsheader_aligned
dh(packet
);
989 DNSPacketMangler
dpm(const_cast<char*>(packet
), length
);
991 const uint16_t qdcount
= ntohs(dh
->qdcount
);
992 for(size_t n
= 0; n
< qdcount
; ++n
) {
993 dpm
.skipDomainName();
997 const size_t numrecords
= ntohs(dh
->ancount
) + ntohs(dh
->nscount
) + ntohs(dh
->arcount
);
998 for(size_t n
= 0; n
< numrecords
; ++n
) {
999 dpm
.skipDomainName();
1000 const uint16_t dnstype
= dpm
.get16BitInt();
1002 const uint16_t dnsclass
= dpm
.get16BitInt();
1004 if(dnstype
== QType::OPT
) {
1008 /* report it if we see a SOA record in the AUTHORITY section */
1009 if(dnstype
== QType::SOA
&& dnsclass
== QClass::IN
&& seenAuthSOA
!= nullptr && n
>= ntohs(dh
->ancount
) && n
< (ntohs(dh
->ancount
) + ntohs(dh
->nscount
))) {
1010 *seenAuthSOA
= true;
1013 const uint32_t ttl
= dpm
.get32BitInt();
1027 uint32_t getDNSPacketLength(const char* packet
, size_t length
)
1029 uint32_t result
= length
;
1030 if(length
< sizeof(dnsheader
)) {
1035 const dnsheader_aligned
dh(packet
);
1036 DNSPacketMangler
dpm(const_cast<char*>(packet
), length
);
1038 const uint16_t qdcount
= ntohs(dh
->qdcount
);
1039 for(size_t n
= 0; n
< qdcount
; ++n
) {
1040 dpm
.skipDomainName();
1041 /* type and class */
1044 const size_t numrecords
= ntohs(dh
->ancount
) + ntohs(dh
->nscount
) + ntohs(dh
->arcount
);
1045 for(size_t n
= 0; n
< numrecords
; ++n
) {
1046 dpm
.skipDomainName();
1047 /* type (2), class (2) and ttl (4) */
1051 result
= dpm
.getOffset();
1059 uint16_t getRecordsOfTypeCount(const char* packet
, size_t length
, uint8_t section
, uint16_t type
)
1061 uint16_t result
= 0;
1062 if(length
< sizeof(dnsheader
)) {
1067 const dnsheader_aligned
dh(packet
);
1068 DNSPacketMangler
dpm(const_cast<char*>(packet
), length
);
1070 const uint16_t qdcount
= ntohs(dh
->qdcount
);
1071 for(size_t n
= 0; n
< qdcount
; ++n
) {
1072 dpm
.skipDomainName();
1074 uint16_t dnstype
= dpm
.get16BitInt();
1075 if (dnstype
== type
) {
1081 /* type and class */
1085 const uint16_t ancount
= ntohs(dh
->ancount
);
1086 for(size_t n
= 0; n
< ancount
; ++n
) {
1087 dpm
.skipDomainName();
1089 uint16_t dnstype
= dpm
.get16BitInt();
1090 if (dnstype
== type
) {
1096 /* type and class */
1103 const uint16_t nscount
= ntohs(dh
->nscount
);
1104 for(size_t n
= 0; n
< nscount
; ++n
) {
1105 dpm
.skipDomainName();
1107 uint16_t dnstype
= dpm
.get16BitInt();
1108 if (dnstype
== type
) {
1114 /* type and class */
1121 const uint16_t arcount
= ntohs(dh
->arcount
);
1122 for(size_t n
= 0; n
< arcount
; ++n
) {
1123 dpm
.skipDomainName();
1125 uint16_t dnstype
= dpm
.get16BitInt();
1126 if (dnstype
== type
) {
1132 /* type and class */
1146 bool getEDNSUDPPayloadSizeAndZ(const char* packet
, size_t length
, uint16_t* payloadSize
, uint16_t* z
)
1148 if (length
< sizeof(dnsheader
)) {
1157 const dnsheader_aligned
dh(packet
);
1158 DNSPacketMangler
dpm(const_cast<char*>(packet
), length
);
1160 const uint16_t qdcount
= ntohs(dh
->qdcount
);
1161 for(size_t n
= 0; n
< qdcount
; ++n
) {
1162 dpm
.skipDomainName();
1163 /* type and class */
1166 const size_t numrecords
= ntohs(dh
->ancount
) + ntohs(dh
->nscount
) + ntohs(dh
->arcount
);
1167 for(size_t n
= 0; n
< numrecords
; ++n
) {
1168 dpm
.skipDomainName();
1169 const uint16_t dnstype
= dpm
.get16BitInt();
1170 const uint16_t dnsclass
= dpm
.get16BitInt();
1172 if(dnstype
== QType::OPT
) {
1173 /* skip extended rcode and version */
1175 *z
= dpm
.get16BitInt();
1176 *payloadSize
= dnsclass
;
1192 bool visitDNSPacket(const std::string_view
& packet
, const std::function
<bool(uint8_t, uint16_t, uint16_t, uint32_t, uint16_t, const char*)>& visitor
)
1194 if (packet
.size() < sizeof(dnsheader
)) {
1200 const dnsheader_aligned
dh(packet
.data());
1201 uint64_t numrecords
= ntohs(dh
->ancount
) + ntohs(dh
->nscount
) + ntohs(dh
->arcount
);
1202 PacketReader
reader(packet
);
1205 for (n
= 0; n
< ntohs(dh
->qdcount
) ; ++n
) {
1206 (void) reader
.getName();
1207 /* type and class */
1211 for (n
= 0; n
< numrecords
; ++n
) {
1212 (void) reader
.getName();
1214 uint8_t section
= n
< ntohs(dh
->ancount
) ? 1 : (n
< (ntohs(dh
->ancount
) + ntohs(dh
->nscount
)) ? 2 : 3);
1215 uint16_t dnstype
= reader
.get16BitInt();
1216 uint16_t dnsclass
= reader
.get16BitInt();
1218 if (dnstype
== QType::OPT
) {
1219 // not getting near that one with a stick
1223 uint32_t dnsttl
= reader
.get32BitInt();
1224 uint16_t contentLength
= reader
.get16BitInt();
1225 uint16_t pos
= reader
.getPosition();
1227 bool done
= visitor(section
, dnsclass
, dnstype
, dnsttl
, contentLength
, &packet
.at(pos
));
1232 reader
.skip(contentLength
);