]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsparser.cc
dnsdist: Fix DNS over plain HTTP broken by `reloadAllCertificates()`
[thirdparty/pdns.git] / pdns / dnsparser.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsparser.hh"
23 #include "dnswriter.hh"
24 #include <boost/algorithm/string.hpp>
25 #include <boost/format.hpp>
26
27 #include "namespaces.hh"
28 #include "noinitvector.hh"
29
30 UnknownRecordContent::UnknownRecordContent(const string& zone)
31 {
32 // parse the input
33 vector<string> parts;
34 stringtok(parts, zone);
35 // we need exactly 3 parts, except if the length field is set to 0 then we only need 2
36 if (parts.size() != 3 && !(parts.size() == 2 && boost::equals(parts.at(1), "0"))) {
37 throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got " + std::to_string(parts.size()) + ": " + zone);
38 }
39
40 if (parts.at(0) != "\\#") {
41 throw MOADNSException("Unknown record was stored incorrectly, first part should be '\\#', got '" + parts.at(0) + "'");
42 }
43
44 const string& relevant = (parts.size() > 2) ? parts.at(2) : "";
45 auto total = pdns::checked_stoi<unsigned int>(parts.at(1));
46 if (relevant.size() % 2 || (relevant.size() / 2) != total) {
47 throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant.size() % total).str());
48 }
49
50 string out;
51 out.reserve(total + 1);
52
53 for (unsigned int n = 0; n < total; ++n) {
54 int c;
55 if (sscanf(&relevant.at(2*n), "%02x", &c) != 1) {
56 throw MOADNSException("unable to read data at position " + std::to_string(2 * n) + " from unknown record of size " + std::to_string(relevant.size()));
57 }
58 out.append(1, (char)c);
59 }
60
61 d_record.insert(d_record.end(), out.begin(), out.end());
62 }
63
64 string UnknownRecordContent::getZoneRepresentation(bool /* noDot */) const
65 {
66 ostringstream str;
67 str<<"\\# "<<(unsigned int)d_record.size()<<" ";
68 char hex[4];
69 for (unsigned char n : d_record) {
70 snprintf(hex, sizeof(hex), "%02x", n);
71 str << hex;
72 }
73 return str.str();
74 }
75
76 void UnknownRecordContent::toPacket(DNSPacketWriter& pw) const
77 {
78 pw.xfrBlob(string(d_record.begin(),d_record.end()));
79 }
80
81 shared_ptr<DNSRecordContent> DNSRecordContent::deserialize(const DNSName& qname, uint16_t qtype, const string& serialized)
82 {
83 dnsheader dnsheader;
84 memset(&dnsheader, 0, sizeof(dnsheader));
85 dnsheader.qdcount=htons(1);
86 dnsheader.ancount=htons(1);
87
88 PacketBuffer packet; // build pseudo packet
89 /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */
90 const auto& encoded = qname.getStorage();
91 packet.resize(sizeof(dnsheader) + 5 + encoded.size() + sizeof(struct dnsrecordheader) + serialized.size());
92
93 uint16_t pos=0;
94 memcpy(&packet[0], &dnsheader, sizeof(dnsheader)); pos+=sizeof(dnsheader);
95
96 constexpr std::array<uint8_t, 5> tmp= {'\x0', '\x0', '\x1', '\x0', '\x1' }; // root question for ns_t_a
97 memcpy(&packet[pos], tmp.data(), tmp.size()); pos += tmp.size();
98
99 memcpy(&packet[pos], encoded.c_str(), encoded.size()); pos+=(uint16_t)encoded.size();
100
101 struct dnsrecordheader drh;
102 drh.d_type=htons(qtype);
103 drh.d_class=htons(QClass::IN);
104 drh.d_ttl=0;
105 drh.d_clen=htons(serialized.size());
106
107 memcpy(&packet[pos], &drh, sizeof(drh)); pos+=sizeof(drh);
108 if (!serialized.empty()) {
109 memcpy(&packet[pos], serialized.c_str(), serialized.size());
110 pos += (uint16_t) serialized.size();
111 (void) pos;
112 }
113
114 DNSRecord dr;
115 dr.d_class = QClass::IN;
116 dr.d_type = qtype;
117 dr.d_name = qname;
118 dr.d_clen = serialized.size();
119 PacketReader pr(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()), packet.size() - serialized.size() - sizeof(dnsrecordheader));
120 /* needed to get the record boundaries right */
121 pr.getDnsrecordheader(drh);
122 auto content = DNSRecordContent::make(dr, pr, Opcode::Query);
123 return content;
124 }
125
126 std::shared_ptr<DNSRecordContent> DNSRecordContent::make(const DNSRecord& dr,
127 PacketReader& pr)
128 {
129 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
130
131 auto i = getTypemap().find(pair(searchclass, dr.d_type));
132 if(i==getTypemap().end() || !i->second) {
133 return std::make_shared<UnknownRecordContent>(dr, pr);
134 }
135
136 return i->second(dr, pr);
137 }
138
139 std::shared_ptr<DNSRecordContent> DNSRecordContent::make(uint16_t qtype, uint16_t qclass,
140 const string& content)
141 {
142 auto i = getZmakermap().find(pair(qclass, qtype));
143 if(i==getZmakermap().end()) {
144 return std::make_shared<UnknownRecordContent>(content);
145 }
146
147 return i->second(content);
148 }
149
150 std::shared_ptr<DNSRecordContent> DNSRecordContent::make(const DNSRecord& dr, PacketReader& pr, uint16_t oc)
151 {
152 // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is
153 // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent.
154 // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses.
155 if (oc == Opcode::Update && dr.d_place == DNSResourceRecord::ANSWER && dr.d_class != 1)
156 return std::make_shared<UnknownRecordContent>(dr, pr);
157
158 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
159
160 auto i = getTypemap().find(pair(searchclass, dr.d_type));
161 if(i==getTypemap().end() || !i->second) {
162 return std::make_shared<UnknownRecordContent>(dr, pr);
163 }
164
165 return i->second(dr, pr);
166 }
167
168 string DNSRecordContent::upgradeContent(const DNSName& qname, const QType& qtype, const string& content) {
169 // seamless upgrade for previously unsupported but now implemented types.
170 UnknownRecordContent unknown_content(content);
171 shared_ptr<DNSRecordContent> rc = DNSRecordContent::deserialize(qname, qtype.getCode(), unknown_content.serialize(qname));
172 return rc->getZoneRepresentation();
173 }
174
175 DNSRecordContent::typemap_t& DNSRecordContent::getTypemap()
176 {
177 static DNSRecordContent::typemap_t typemap;
178 return typemap;
179 }
180
181 DNSRecordContent::n2typemap_t& DNSRecordContent::getN2Typemap()
182 {
183 static DNSRecordContent::n2typemap_t n2typemap;
184 return n2typemap;
185 }
186
187 DNSRecordContent::t2namemap_t& DNSRecordContent::getT2Namemap()
188 {
189 static DNSRecordContent::t2namemap_t t2namemap;
190 return t2namemap;
191 }
192
193 DNSRecordContent::zmakermap_t& DNSRecordContent::getZmakermap()
194 {
195 static DNSRecordContent::zmakermap_t zmakermap;
196 return zmakermap;
197 }
198
199 bool DNSRecordContent::isRegisteredType(uint16_t rtype, uint16_t rclass)
200 {
201 return getTypemap().count(pair(rclass, rtype)) != 0;
202 }
203
204 DNSRecord::DNSRecord(const DNSResourceRecord& rr): d_name(rr.qname)
205 {
206 d_type = rr.qtype.getCode();
207 d_ttl = rr.ttl;
208 d_class = rr.qclass;
209 d_place = DNSResourceRecord::ANSWER;
210 d_clen = 0;
211 d_content = DNSRecordContent::make(d_type, rr.qclass, rr.content);
212 }
213
214 // If you call this and you are not parsing a packet coming from a socket, you are doing it wrong.
215 DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& wire)
216 {
217 DNSResourceRecord resourceRecord;
218 resourceRecord.qname = wire.d_name;
219 resourceRecord.qtype = QType(wire.d_type);
220 resourceRecord.ttl = wire.d_ttl;
221 resourceRecord.content = wire.getContent()->getZoneRepresentation(true);
222 resourceRecord.auth = false;
223 resourceRecord.qclass = wire.d_class;
224 return resourceRecord;
225 }
226
227 void MOADNSParser::init(bool query, const std::string_view& packet)
228 {
229 if (packet.size() < sizeof(dnsheader))
230 throw MOADNSException("Packet shorter than minimal header");
231
232 memcpy(&d_header, packet.data(), sizeof(dnsheader));
233
234 if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
235 throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode));
236
237 d_header.qdcount=ntohs(d_header.qdcount);
238 d_header.ancount=ntohs(d_header.ancount);
239 d_header.nscount=ntohs(d_header.nscount);
240 d_header.arcount=ntohs(d_header.arcount);
241
242 if (query && (d_header.qdcount > 1))
243 throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
244
245 unsigned int n=0;
246
247 PacketReader pr(packet);
248 bool validPacket=false;
249 try {
250 d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then
251
252 for(n=0;n < d_header.qdcount; ++n) {
253 d_qname=pr.getName();
254 d_qtype=pr.get16BitInt();
255 d_qclass=pr.get16BitInt();
256 }
257
258 struct dnsrecordheader ah;
259 vector<unsigned char> record;
260 bool seenTSIG = false;
261 validPacket=true;
262 d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
263 for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
264 DNSRecord dr;
265
266 if(n < d_header.ancount)
267 dr.d_place=DNSResourceRecord::ANSWER;
268 else if(n < d_header.ancount + d_header.nscount)
269 dr.d_place=DNSResourceRecord::AUTHORITY;
270 else
271 dr.d_place=DNSResourceRecord::ADDITIONAL;
272
273 unsigned int recordStartPos=pr.getPosition();
274
275 DNSName name=pr.getName();
276
277 pr.getDnsrecordheader(ah);
278 dr.d_ttl=ah.d_ttl;
279 dr.d_type=ah.d_type;
280 dr.d_class=ah.d_class;
281
282 dr.d_name = std::move(name);
283 dr.d_clen = ah.d_clen;
284
285 if (query &&
286 !(d_qtype == QType::IXFR && dr.d_place == DNSResourceRecord::AUTHORITY && dr.d_type == QType::SOA) && // IXFR queries have a SOA in their AUTHORITY section
287 (dr.d_place == DNSResourceRecord::ANSWER || dr.d_place == DNSResourceRecord::AUTHORITY || (dr.d_type != QType::OPT && dr.d_type != QType::TSIG && dr.d_type != QType::SIG && dr.d_type != QType::TKEY) || ((dr.d_type == QType::TSIG || dr.d_type == QType::SIG || dr.d_type == QType::TKEY) && dr.d_class != QClass::ANY))) {
288 // cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
289 dr.setContent(std::make_shared<UnknownRecordContent>(dr, pr));
290 }
291 else {
292 // cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
293 dr.setContent(DNSRecordContent::make(dr, pr, d_header.opcode));
294 }
295
296 /* XXX: XPF records should be allowed after TSIG as soon as the actual XPF option code has been assigned:
297 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG && dr.d_type != QType::XPF)
298 */
299 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG) {
300 /* only XPF records are allowed after a TSIG */
301 throw MOADNSException("Packet ("+d_qname.toString()+"|#"+std::to_string(d_qtype)+") has an unexpected record ("+std::to_string(dr.d_type)+") after a TSIG one.");
302 }
303
304 if(dr.d_type == QType::TSIG && dr.d_class == QClass::ANY) {
305 if(seenTSIG || dr.d_place != DNSResourceRecord::ADDITIONAL) {
306 throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position.");
307 }
308 seenTSIG = true;
309 d_tsigPos = recordStartPos;
310 }
311
312 d_answers.emplace_back(std::move(dr), pr.getPosition() - sizeof(dnsheader));
313 }
314
315 #if 0
316 if(pr.getPosition()!=packet.size()) {
317 throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " +
318 std::to_string(packet.size()) + ")");
319 }
320 #endif
321 }
322 catch(const std::out_of_range &re) {
323 if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
324 if(n < d_header.ancount) {
325 d_header.ancount=n; d_header.nscount = d_header.arcount = 0;
326 }
327 else if(n < d_header.ancount + d_header.nscount) {
328 d_header.nscount = n - d_header.ancount; d_header.arcount=0;
329 }
330 else {
331 d_header.arcount = n - d_header.ancount - d_header.nscount;
332 }
333 }
334 else {
335 throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+
336 std::to_string(d_header.rd)+
337 "), out of bounds: "+string(re.what()));
338 }
339 }
340 }
341
342 bool MOADNSParser::hasEDNS() const
343 {
344 if (d_header.arcount == 0 || d_answers.empty()) {
345 return false;
346 }
347
348 for (const auto& record : d_answers) {
349 if (record.first.d_place == DNSResourceRecord::ADDITIONAL && record.first.d_type == QType::OPT) {
350 return true;
351 }
352 }
353
354 return false;
355 }
356
357 void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah)
358 {
359 unsigned char *p = reinterpret_cast<unsigned char*>(&ah);
360
361 for(unsigned int n = 0; n < sizeof(dnsrecordheader); ++n) {
362 p[n] = d_content.at(d_pos++);
363 }
364
365 ah.d_type = ntohs(ah.d_type);
366 ah.d_class = ntohs(ah.d_class);
367 ah.d_clen = ntohs(ah.d_clen);
368 ah.d_ttl = ntohl(ah.d_ttl);
369
370 d_startrecordpos = d_pos; // needed for getBlob later on
371 d_recordlen = ah.d_clen;
372 }
373
374
375 void PacketReader::copyRecord(vector<unsigned char>& dest, uint16_t len)
376 {
377 if (len == 0) {
378 return;
379 }
380 if ((d_pos + len) > d_content.size()) {
381 throw std::out_of_range("Attempt to copy outside of packet");
382 }
383
384 dest.resize(len);
385
386 for (uint16_t n = 0; n < len; ++n) {
387 dest.at(n) = d_content.at(d_pos++);
388 }
389 }
390
391 void PacketReader::copyRecord(unsigned char* dest, uint16_t len)
392 {
393 if (d_pos + len > d_content.size()) {
394 throw std::out_of_range("Attempt to copy outside of packet");
395 }
396
397 memcpy(dest, &d_content.at(d_pos), len);
398 d_pos += len;
399 }
400
401 void PacketReader::xfrNodeOrLocatorID(NodeOrLocatorID& ret)
402 {
403 if (d_pos + sizeof(ret) > d_content.size()) {
404 throw std::out_of_range("Attempt to read 64 bit value outside of packet");
405 }
406 memcpy(&ret.content, &d_content.at(d_pos), sizeof(ret.content));
407 d_pos += sizeof(ret);
408 }
409
410 void PacketReader::xfr48BitInt(uint64_t& ret)
411 {
412 ret=0;
413 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
414 ret<<=8;
415 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
416 ret<<=8;
417 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
418 ret<<=8;
419 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
420 ret<<=8;
421 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
422 ret<<=8;
423 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
424 }
425
426 uint32_t PacketReader::get32BitInt()
427 {
428 uint32_t ret=0;
429 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
430 ret<<=8;
431 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
432 ret<<=8;
433 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
434 ret<<=8;
435 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
436
437 return ret;
438 }
439
440
441 uint16_t PacketReader::get16BitInt()
442 {
443 uint16_t ret=0;
444 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
445 ret<<=8;
446 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
447
448 return ret;
449 }
450
451 uint8_t PacketReader::get8BitInt()
452 {
453 return d_content.at(d_pos++);
454 }
455
456 DNSName PacketReader::getName()
457 {
458 unsigned int consumed;
459 try {
460 DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, nullptr /* qtype */, nullptr /* qclass */, &consumed, sizeof(dnsheader));
461
462 d_pos+=consumed;
463 return dn;
464 }
465 catch(const std::range_error& re) {
466 throw std::out_of_range(string("dnsname issue: ")+re.what());
467 }
468 catch(...) {
469 throw std::out_of_range("dnsname issue");
470 }
471 throw PDNSException("PacketReader::getName(): name is empty");
472 }
473
/**
 * Escape a character-string for zone output: bytes outside 32..126 become a
 * backslash plus three decimal digits (e.g. "\009"), while '"' and '\' get a
 * backslash prefix. Everything else is copied unchanged.
 */
static string txtEscape(const string &name)
{
  string escaped;
  escaped.reserve(name.size());
  for (const char ch : name) {
    const auto byte = static_cast<unsigned char>(ch);
    if (byte < 32 || byte >= 127) {
      escaped += '\\';
      escaped += static_cast<char>('0' + byte / 100);
      escaped += static_cast<char>('0' + (byte / 10) % 10);
      escaped += static_cast<char>('0' + byte % 10);
      continue;
    }
    if (ch == '"' || ch == '\\') {
      escaped += '\\';
    }
    escaped += ch;
  }
  return escaped;
}
493
494 // exceptions thrown here do not result in logging in the main pdns auth server - just so you know!
495 string PacketReader::getText(bool multi, bool lenField)
496 {
497 string ret;
498 ret.reserve(40);
499 while(d_pos < d_startrecordpos + d_recordlen ) {
500 if(!ret.empty()) {
501 ret.append(1,' ');
502 }
503 uint16_t labellen;
504 if(lenField)
505 labellen=static_cast<uint8_t>(d_content.at(d_pos++));
506 else
507 labellen=d_recordlen - (d_pos - d_startrecordpos);
508
509 ret.append(1,'"');
510 if(labellen) { // no need to do anything for an empty string
511 string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
512 ret.append(txtEscape(val)); // the end is one beyond the packet
513 }
514 ret.append(1,'"');
515 d_pos+=labellen;
516 if(!multi)
517 break;
518 }
519
520 if (ret.empty() && !lenField) {
521 // all lenField == false cases (CAA and URI at the time of this writing) want that emptiness to be explicit
522 return "\"\"";
523 }
524 return ret;
525 }
526
527 string PacketReader::getUnquotedText(bool lenField)
528 {
529 uint16_t stop_at;
530 if(lenField)
531 stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1;
532 else
533 stop_at = d_recordlen;
534
535 /* think unsigned overflow */
536 if (stop_at < d_pos) {
537 throw std::out_of_range("getUnquotedText out of record range");
538 }
539
540 if(stop_at == d_pos)
541 return "";
542
543 d_pos++;
544 string ret(d_content.substr(d_pos, stop_at-d_pos));
545 d_pos = stop_at;
546 return ret;
547 }
548
549 void PacketReader::xfrBlob(string& blob)
550 {
551 try {
552 if(d_recordlen && !(d_pos == (d_startrecordpos + d_recordlen))) {
553 if (d_pos > (d_startrecordpos + d_recordlen)) {
554 throw std::out_of_range("xfrBlob out of record range");
555 }
556 blob.assign(&d_content.at(d_pos), &d_content.at(d_startrecordpos + d_recordlen - 1 ) + 1);
557 }
558 else {
559 blob.clear();
560 }
561
562 d_pos = d_startrecordpos + d_recordlen;
563 }
564 catch(...)
565 {
566 throw std::out_of_range("xfrBlob out of range");
567 }
568 }
569
570 void PacketReader::xfrBlobNoSpaces(string& blob, int length) {
571 xfrBlob(blob, length);
572 }
573
574 void PacketReader::xfrBlob(string& blob, int length)
575 {
576 if(length) {
577 if (length < 0) {
578 throw std::out_of_range("xfrBlob out of range (negative length)");
579 }
580
581 blob.assign(&d_content.at(d_pos), &d_content.at(d_pos + length - 1 ) + 1 );
582
583 d_pos += length;
584 }
585 else {
586 blob.clear();
587 }
588 }
589
590 void PacketReader::xfrSvcParamKeyVals(set<SvcParam> &kvs) {
591 while (d_pos < (d_startrecordpos + d_recordlen)) {
592 if (d_pos + 2 > (d_startrecordpos + d_recordlen)) {
593 throw std::out_of_range("incomplete key");
594 }
595 uint16_t keyInt;
596 xfr16BitInt(keyInt);
597 auto key = static_cast<SvcParam::SvcParamKey>(keyInt);
598 uint16_t len;
599 xfr16BitInt(len);
600
601 if (d_pos + len > (d_startrecordpos + d_recordlen)) {
602 throw std::out_of_range("record is shorter than SVCB lengthfield implies");
603 }
604
605 switch (key)
606 {
607 case SvcParam::mandatory: {
608 if (len % 2 != 0) {
609 throw std::out_of_range("mandatory SvcParam has invalid length");
610 }
611 if (len == 0) {
612 throw std::out_of_range("empty 'mandatory' values");
613 }
614 std::set<SvcParam::SvcParamKey> paramKeys;
615 size_t stop = d_pos + len;
616 while (d_pos < stop) {
617 uint16_t keyval;
618 xfr16BitInt(keyval);
619 paramKeys.insert(static_cast<SvcParam::SvcParamKey>(keyval));
620 }
621 kvs.insert(SvcParam(key, std::move(paramKeys)));
622 break;
623 }
624 case SvcParam::alpn: {
625 size_t stop = d_pos + len;
626 std::vector<string> alpns;
627 while (d_pos < stop) {
628 string alpn;
629 uint8_t alpnLen = 0;
630 xfr8BitInt(alpnLen);
631 if (alpnLen == 0) {
632 throw std::out_of_range("alpn length of 0");
633 }
634 xfrBlob(alpn, alpnLen);
635 alpns.push_back(alpn);
636 }
637 kvs.insert(SvcParam(key, std::move(alpns)));
638 break;
639 }
640 case SvcParam::no_default_alpn: {
641 if (len != 0) {
642 throw std::out_of_range("invalid length for no-default-alpn");
643 }
644 kvs.insert(SvcParam(key));
645 break;
646 }
647 case SvcParam::port: {
648 if (len != 2) {
649 throw std::out_of_range("invalid length for port");
650 }
651 uint16_t port;
652 xfr16BitInt(port);
653 kvs.insert(SvcParam(key, port));
654 break;
655 }
656 case SvcParam::ipv4hint: /* fall-through */
657 case SvcParam::ipv6hint: {
658 size_t addrLen = (key == SvcParam::ipv4hint ? 4 : 16);
659 if (len % addrLen != 0) {
660 throw std::out_of_range("invalid length for " + SvcParam::keyToString(key));
661 }
662 vector<ComboAddress> addresses;
663 auto stop = d_pos + len;
664 while (d_pos < stop)
665 {
666 ComboAddress addr;
667 xfrCAWithoutPort(key, addr);
668 addresses.push_back(addr);
669 }
670 kvs.insert(SvcParam(key, std::move(addresses)));
671 break;
672 }
673 case SvcParam::ech: {
674 std::string blob;
675 blob.reserve(len);
676 xfrBlobNoSpaces(blob, len);
677 kvs.insert(SvcParam(key, blob));
678 break;
679 }
680 default: {
681 std::string blob;
682 blob.reserve(len);
683 xfrBlob(blob, len);
684 kvs.insert(SvcParam(key, blob));
685 break;
686 }
687 }
688 }
689 }
690
691
692 void PacketReader::xfrHexBlob(string& blob, bool /* keepReading */)
693 {
694 xfrBlob(blob);
695 }
696
697 //FIXME400 remove this method completely
698 string simpleCompress(const string& elabel, const string& root)
699 {
700 string label=elabel;
701 // FIXME400: this relies on the semi-canonical escaped output from getName
702 if(strchr(label.c_str(), '\\')) {
703 boost::replace_all(label, "\\.", ".");
704 boost::replace_all(label, "\\032", " ");
705 boost::replace_all(label, "\\\\", "\\");
706 }
707 typedef vector<pair<unsigned int, unsigned int> > parts_t;
708 parts_t parts;
709 vstringtok(parts, label, ".");
710 string ret;
711 ret.reserve(label.size()+4);
712 for(const auto & part : parts) {
713 if(!root.empty() && !strncasecmp(root.c_str(), label.c_str() + part.first, 1 + label.length() - part.first)) { // also match trailing 0, hence '1 +'
714 const unsigned char rootptr[2]={0xc0,0x11};
715 ret.append((const char *) rootptr, 2);
716 return ret;
717 }
718 ret.append(1, (char)(part.second - part.first));
719 ret.append(label.c_str() + part.first, part.second - part.first);
720 }
721 ret.append(1, (char)0);
722 return ret;
723 }
724
725 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
726 void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor)
727 {
728 if(length < sizeof(dnsheader))
729 return;
730 try
731 {
732 dnsheader dh;
733 memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
734 uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
735 DNSPacketMangler dpm(packet, length);
736
737 uint64_t n;
738 for(n=0; n < ntohs(dh.qdcount) ; ++n) {
739 dpm.skipDomainName();
740 /* type and class */
741 dpm.skipBytes(4);
742 }
743
744 for(n=0; n < numrecords; ++n) {
745 dpm.skipDomainName();
746
747 uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
748 uint16_t dnstype = dpm.get16BitInt();
749 uint16_t dnsclass = dpm.get16BitInt();
750
751 if(dnstype == QType::OPT) // not getting near that one with a stick
752 break;
753
754 uint32_t dnsttl = dpm.get32BitInt();
755 uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl);
756 if (newttl) {
757 dpm.rewindBytes(sizeof(newttl));
758 dpm.setAndSkip32BitInt(newttl);
759 }
760 dpm.skipRData();
761 }
762 }
763 catch(...)
764 {
765 return;
766 }
767 }
768
769 static bool checkIfPacketContainsRecords(const PacketBuffer& packet, const std::unordered_set<QType>& qtypes)
770 {
771 auto length = packet.size();
772 if (length < sizeof(dnsheader)) {
773 return false;
774 }
775
776 try {
777 const dnsheader_aligned dh(packet.data());
778 DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(packet.data())), length);
779
780 const uint16_t qdcount = ntohs(dh->qdcount);
781 for (size_t n = 0; n < qdcount; ++n) {
782 dpm.skipDomainName();
783 /* type and class */
784 dpm.skipBytes(4);
785 }
786 const size_t recordsCount = static_cast<size_t>(ntohs(dh->ancount)) + ntohs(dh->nscount) + ntohs(dh->arcount);
787 for (size_t n = 0; n < recordsCount; ++n) {
788 dpm.skipDomainName();
789 uint16_t dnstype = dpm.get16BitInt();
790 uint16_t dnsclass = dpm.get16BitInt();
791 if (dnsclass == QClass::IN && qtypes.count(dnstype) > 0) {
792 return true;
793 }
794 /* ttl */
795 dpm.skipBytes(4);
796 dpm.skipRData();
797 }
798 }
799 catch (...) {
800 }
801
802 return false;
803 }
804
805 static int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, PacketBuffer& newContent, const std::unordered_set<QType>& qtypes)
806 {
807 static const std::unordered_set<QType>& safeTypes{QType::A, QType::AAAA, QType::DHCID, QType::TXT, QType::OPT, QType::HINFO, QType::DNSKEY, QType::CDNSKEY, QType::DS, QType::CDS, QType::DLV, QType::SSHFP, QType::KEY, QType::CERT, QType::TLSA, QType::SMIMEA, QType::OPENPGPKEY, QType::SVCB, QType::HTTPS, QType::NSEC3, QType::CSYNC, QType::NSEC3PARAM, QType::LOC, QType::NID, QType::L32, QType::L64, QType::EUI48, QType::EUI64, QType::URI, QType::CAA};
808
809 if (initialPacket.size() < sizeof(dnsheader)) {
810 return EINVAL;
811 }
812 try {
813 const dnsheader_aligned dh(initialPacket.data());
814
815 if (ntohs(dh->qdcount) == 0)
816 return ENOENT;
817 auto packetView = std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size());
818
819 PacketReader pr(packetView);
820
821 size_t idx = 0;
822 DNSName rrname;
823 uint16_t qdcount = ntohs(dh->qdcount);
824 uint16_t ancount = ntohs(dh->ancount);
825 uint16_t nscount = ntohs(dh->nscount);
826 uint16_t arcount = ntohs(dh->arcount);
827 uint16_t rrtype;
828 uint16_t rrclass;
829 string blob;
830 struct dnsrecordheader ah;
831
832 rrname = pr.getName();
833 rrtype = pr.get16BitInt();
834 rrclass = pr.get16BitInt();
835
836 GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
837 pw.getHeader()->id=dh->id;
838 pw.getHeader()->qr=dh->qr;
839 pw.getHeader()->aa=dh->aa;
840 pw.getHeader()->tc=dh->tc;
841 pw.getHeader()->rd=dh->rd;
842 pw.getHeader()->ra=dh->ra;
843 pw.getHeader()->ad=dh->ad;
844 pw.getHeader()->cd=dh->cd;
845 pw.getHeader()->rcode=dh->rcode;
846
847 /* consume remaining qd if any */
848 if (qdcount > 1) {
849 for(idx = 1; idx < qdcount; idx++) {
850 rrname = pr.getName();
851 rrtype = pr.get16BitInt();
852 rrclass = pr.get16BitInt();
853 (void) rrtype;
854 (void) rrclass;
855 }
856 }
857
858 /* copy AN */
859 for (idx = 0; idx < ancount; idx++) {
860 rrname = pr.getName();
861 pr.getDnsrecordheader(ah);
862 pr.xfrBlob(blob);
863
864 if (qtypes.find(ah.d_type) == qtypes.end()) {
865 // if this is not a safe type
866 if (safeTypes.find(ah.d_type) == safeTypes.end()) {
867 // "unsafe" types might countain compressed data, so cancel rewrite
868 newContent.clear();
869 return EIO;
870 }
871 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
872 pw.xfrBlob(blob);
873 }
874 }
875
876 /* copy NS */
877 for (idx = 0; idx < nscount; idx++) {
878 rrname = pr.getName();
879 pr.getDnsrecordheader(ah);
880 pr.xfrBlob(blob);
881
882 if (qtypes.find(ah.d_type) == qtypes.end()) {
883 if (safeTypes.find(ah.d_type) == safeTypes.end()) {
884 // "unsafe" types might countain compressed data, so cancel rewrite
885 newContent.clear();
886 return EIO;
887 }
888 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
889 pw.xfrBlob(blob);
890 }
891 }
892 /* copy AR */
893 for (idx = 0; idx < arcount; idx++) {
894 rrname = pr.getName();
895 pr.getDnsrecordheader(ah);
896 pr.xfrBlob(blob);
897
898 if (qtypes.find(ah.d_type) == qtypes.end()) {
899 if (safeTypes.find(ah.d_type) == safeTypes.end()) {
900 // "unsafe" types might countain compressed data, so cancel rewrite
901 newContent.clear();
902 return EIO;
903 }
904 pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
905 pw.xfrBlob(blob);
906 }
907 }
908 pw.commit();
909
910 }
911 catch (...)
912 {
913 newContent.clear();
914 return EIO;
915 }
916 return 0;
917 }
918
919 void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::unordered_set<QType>& qtypes)
920 {
921 return clearDNSPacketRecordTypes(reinterpret_cast<PacketBuffer&>(packet), qtypes);
922 }
923
924 void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::unordered_set<QType>& qtypes)
925 {
926 if (!checkIfPacketContainsRecords(packet, qtypes)) {
927 return;
928 }
929
930 PacketBuffer newContent;
931
932 auto result = rewritePacketWithoutRecordTypes(packet, newContent, qtypes);
933 if (!result) {
934 packet = std::move(newContent);
935 }
936 }
937
938 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
939 void ageDNSPacket(char* packet, size_t length, uint32_t seconds, const dnsheader_aligned& aligned_dh)
940 {
941 if (length < sizeof(dnsheader)) {
942 return;
943 }
944 try {
945 const dnsheader* dhp = aligned_dh.get();
946 const uint64_t dqcount = ntohs(dhp->qdcount);
947 const uint64_t numrecords = ntohs(dhp->ancount) + ntohs(dhp->nscount) + ntohs(dhp->arcount);
948 DNSPacketMangler dpm(packet, length);
949
950 for (uint64_t rec = 0; rec < dqcount; ++rec) {
951 dpm.skipDomainName();
952 /* type and class */
953 dpm.skipBytes(4);
954 }
955
956 for(uint64_t rec = 0; rec < numrecords; ++rec) {
957 dpm.skipDomainName();
958
959 uint16_t dnstype = dpm.get16BitInt();
960 /* class */
961 dpm.skipBytes(2);
962
963 if (dnstype != QType::OPT) { // not aging that one with a stick
964 dpm.decreaseAndSkip32BitInt(seconds);
965 } else {
966 dpm.skipBytes(4);
967 }
968 dpm.skipRData();
969 }
970 }
971 catch(...) {
972 }
973 }
974
975 void ageDNSPacket(std::string& packet, uint32_t seconds, const dnsheader_aligned& aligned_dh)
976 {
977 ageDNSPacket(packet.data(), packet.length(), seconds, aligned_dh);
978 }
979
980 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA)
981 {
982 uint32_t result = std::numeric_limits<uint32_t>::max();
983 if(length < sizeof(dnsheader)) {
984 return result;
985 }
986 try
987 {
988 const dnsheader_aligned dh(packet);
989 DNSPacketMangler dpm(const_cast<char*>(packet), length);
990
991 const uint16_t qdcount = ntohs(dh->qdcount);
992 for(size_t n = 0; n < qdcount; ++n) {
993 dpm.skipDomainName();
994 /* type and class */
995 dpm.skipBytes(4);
996 }
997 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
998 for(size_t n = 0; n < numrecords; ++n) {
999 dpm.skipDomainName();
1000 const uint16_t dnstype = dpm.get16BitInt();
1001 /* class */
1002 const uint16_t dnsclass = dpm.get16BitInt();
1003
1004 if(dnstype == QType::OPT) {
1005 break;
1006 }
1007
1008 /* report it if we see a SOA record in the AUTHORITY section */
1009 if(dnstype == QType::SOA && dnsclass == QClass::IN && seenAuthSOA != nullptr && n >= ntohs(dh->ancount) && n < (ntohs(dh->ancount) + ntohs(dh->nscount))) {
1010 *seenAuthSOA = true;
1011 }
1012
1013 const uint32_t ttl = dpm.get32BitInt();
1014 if (result > ttl) {
1015 result = ttl;
1016 }
1017
1018 dpm.skipRData();
1019 }
1020 }
1021 catch(...)
1022 {
1023 }
1024 return result;
1025 }
1026
1027 uint32_t getDNSPacketLength(const char* packet, size_t length)
1028 {
1029 uint32_t result = length;
1030 if(length < sizeof(dnsheader)) {
1031 return result;
1032 }
1033 try
1034 {
1035 const dnsheader_aligned dh(packet);
1036 DNSPacketMangler dpm(const_cast<char*>(packet), length);
1037
1038 const uint16_t qdcount = ntohs(dh->qdcount);
1039 for(size_t n = 0; n < qdcount; ++n) {
1040 dpm.skipDomainName();
1041 /* type and class */
1042 dpm.skipBytes(4);
1043 }
1044 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1045 for(size_t n = 0; n < numrecords; ++n) {
1046 dpm.skipDomainName();
1047 /* type (2), class (2) and ttl (4) */
1048 dpm.skipBytes(8);
1049 dpm.skipRData();
1050 }
1051 result = dpm.getOffset();
1052 }
1053 catch(...)
1054 {
1055 }
1056 return result;
1057 }
1058
1059 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type)
1060 {
1061 uint16_t result = 0;
1062 if(length < sizeof(dnsheader)) {
1063 return result;
1064 }
1065 try
1066 {
1067 const dnsheader_aligned dh(packet);
1068 DNSPacketMangler dpm(const_cast<char*>(packet), length);
1069
1070 const uint16_t qdcount = ntohs(dh->qdcount);
1071 for(size_t n = 0; n < qdcount; ++n) {
1072 dpm.skipDomainName();
1073 if (section == 0) {
1074 uint16_t dnstype = dpm.get16BitInt();
1075 if (dnstype == type) {
1076 result++;
1077 }
1078 /* class */
1079 dpm.skipBytes(2);
1080 } else {
1081 /* type and class */
1082 dpm.skipBytes(4);
1083 }
1084 }
1085 const uint16_t ancount = ntohs(dh->ancount);
1086 for(size_t n = 0; n < ancount; ++n) {
1087 dpm.skipDomainName();
1088 if (section == 1) {
1089 uint16_t dnstype = dpm.get16BitInt();
1090 if (dnstype == type) {
1091 result++;
1092 }
1093 /* class */
1094 dpm.skipBytes(2);
1095 } else {
1096 /* type and class */
1097 dpm.skipBytes(4);
1098 }
1099 /* ttl */
1100 dpm.skipBytes(4);
1101 dpm.skipRData();
1102 }
1103 const uint16_t nscount = ntohs(dh->nscount);
1104 for(size_t n = 0; n < nscount; ++n) {
1105 dpm.skipDomainName();
1106 if (section == 2) {
1107 uint16_t dnstype = dpm.get16BitInt();
1108 if (dnstype == type) {
1109 result++;
1110 }
1111 /* class */
1112 dpm.skipBytes(2);
1113 } else {
1114 /* type and class */
1115 dpm.skipBytes(4);
1116 }
1117 /* ttl */
1118 dpm.skipBytes(4);
1119 dpm.skipRData();
1120 }
1121 const uint16_t arcount = ntohs(dh->arcount);
1122 for(size_t n = 0; n < arcount; ++n) {
1123 dpm.skipDomainName();
1124 if (section == 3) {
1125 uint16_t dnstype = dpm.get16BitInt();
1126 if (dnstype == type) {
1127 result++;
1128 }
1129 /* class */
1130 dpm.skipBytes(2);
1131 } else {
1132 /* type and class */
1133 dpm.skipBytes(4);
1134 }
1135 /* ttl */
1136 dpm.skipBytes(4);
1137 dpm.skipRData();
1138 }
1139 }
1140 catch(...)
1141 {
1142 }
1143 return result;
1144 }
1145
1146 bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z)
1147 {
1148 if (length < sizeof(dnsheader)) {
1149 return false;
1150 }
1151
1152 *payloadSize = 0;
1153 *z = 0;
1154
1155 try
1156 {
1157 const dnsheader_aligned dh(packet);
1158 DNSPacketMangler dpm(const_cast<char*>(packet), length);
1159
1160 const uint16_t qdcount = ntohs(dh->qdcount);
1161 for(size_t n = 0; n < qdcount; ++n) {
1162 dpm.skipDomainName();
1163 /* type and class */
1164 dpm.skipBytes(4);
1165 }
1166 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1167 for(size_t n = 0; n < numrecords; ++n) {
1168 dpm.skipDomainName();
1169 const uint16_t dnstype = dpm.get16BitInt();
1170 const uint16_t dnsclass = dpm.get16BitInt();
1171
1172 if(dnstype == QType::OPT) {
1173 /* skip extended rcode and version */
1174 dpm.skipBytes(2);
1175 *z = dpm.get16BitInt();
1176 *payloadSize = dnsclass;
1177 return true;
1178 }
1179
1180 /* TTL */
1181 dpm.skipBytes(4);
1182 dpm.skipRData();
1183 }
1184 }
1185 catch(...)
1186 {
1187 }
1188
1189 return false;
1190 }
1191
1192 bool visitDNSPacket(const std::string_view& packet, const std::function<bool(uint8_t, uint16_t, uint16_t, uint32_t, uint16_t, const char*)>& visitor)
1193 {
1194 if (packet.size() < sizeof(dnsheader)) {
1195 return false;
1196 }
1197
1198 try
1199 {
1200 const dnsheader_aligned dh(packet.data());
1201 uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1202 PacketReader reader(packet);
1203
1204 uint64_t n;
1205 for (n = 0; n < ntohs(dh->qdcount) ; ++n) {
1206 (void) reader.getName();
1207 /* type and class */
1208 reader.skip(4);
1209 }
1210
1211 for (n = 0; n < numrecords; ++n) {
1212 (void) reader.getName();
1213
1214 uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
1215 uint16_t dnstype = reader.get16BitInt();
1216 uint16_t dnsclass = reader.get16BitInt();
1217
1218 if (dnstype == QType::OPT) {
1219 // not getting near that one with a stick
1220 break;
1221 }
1222
1223 uint32_t dnsttl = reader.get32BitInt();
1224 uint16_t contentLength = reader.get16BitInt();
1225 uint16_t pos = reader.getPosition();
1226
1227 bool done = visitor(section, dnsclass, dnstype, dnsttl, contentLength, &packet.at(pos));
1228 if (done) {
1229 return true;
1230 }
1231
1232 reader.skip(contentLength);
1233 }
1234 }
1235 catch (...) {
1236 return false;
1237 }
1238
1239 return true;
1240 }