]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsparser.cc
Merge pull request #6838 from mind04/autoserial
[thirdparty/pdns.git] / pdns / dnsparser.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsparser.hh"
23 #include "dnswriter.hh"
24 #include <boost/algorithm/string.hpp>
25 #include <boost/format.hpp>
26
27 #include "namespaces.hh"
28
29 class UnknownRecordContent : public DNSRecordContent
30 {
31 public:
32 UnknownRecordContent(const DNSRecord& dr, PacketReader& pr)
33 : d_dr(dr)
34 {
35 pr.copyRecord(d_record, dr.d_clen);
36 }
37
38 UnknownRecordContent(const string& zone)
39 {
40 // parse the input
41 vector<string> parts;
42 stringtok(parts, zone);
43 if(parts.size()!=3 && !(parts.size()==2 && equals(parts[1],"0")) )
44 throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got "+std::to_string(parts.size())+": "+zone );
45 const string& relevant=(parts.size() > 2) ? parts[2] : "";
46 unsigned int total=pdns_stou(parts[1]);
47 if(relevant.size() % 2 || relevant.size() / 2 != total)
48 throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant.size() % total).str());
49 string out;
50 out.reserve(total+1);
51 for(unsigned int n=0; n < total; ++n) {
52 int c;
53 sscanf(relevant.c_str()+2*n, "%02x", &c);
54 out.append(1, (char)c);
55 }
56
57 d_record.insert(d_record.end(), out.begin(), out.end());
58 }
59
60 string getZoneRepresentation(bool noDot) const override
61 {
62 ostringstream str;
63 str<<"\\# "<<(unsigned int)d_record.size()<<" ";
64 char hex[4];
65 for(size_t n=0; n<d_record.size(); ++n) {
66 snprintf(hex,sizeof(hex)-1, "%02x", d_record.at(n));
67 str << hex;
68 }
69 return str.str();
70 }
71
72 void toPacket(DNSPacketWriter& pw) override
73 {
74 pw.xfrBlob(string(d_record.begin(),d_record.end()));
75 }
76
77 uint16_t getType() const override
78 {
79 return d_dr.d_type;
80 }
81 private:
82 DNSRecord d_dr;
83 vector<uint8_t> d_record;
84 };
85
86 shared_ptr<DNSRecordContent> DNSRecordContent::unserialize(const DNSName& qname, uint16_t qtype, const string& serialized)
87 {
88 dnsheader dnsheader;
89 memset(&dnsheader, 0, sizeof(dnsheader));
90 dnsheader.qdcount=htons(1);
91 dnsheader.ancount=htons(1);
92
93 vector<uint8_t> packet; // build pseudo packet
94
95 /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */
96
97 string encoded=qname.toDNSString();
98
99 packet.resize(sizeof(dnsheader) + 5 + encoded.size() + sizeof(struct dnsrecordheader) + serialized.size());
100
101 uint16_t pos=0;
102
103 memcpy(&packet[0], &dnsheader, sizeof(dnsheader)); pos+=sizeof(dnsheader);
104
105 char tmp[6]="\x0" "\x0\x1" "\x0\x1"; // root question for ns_t_a
106 memcpy(&packet[pos], &tmp, 5); pos+=5;
107
108 memcpy(&packet[pos], encoded.c_str(), encoded.size()); pos+=(uint16_t)encoded.size();
109
110 struct dnsrecordheader drh;
111 drh.d_type=htons(qtype);
112 drh.d_class=htons(QClass::IN);
113 drh.d_ttl=0;
114 drh.d_clen=htons(serialized.size());
115
116 memcpy(&packet[pos], &drh, sizeof(drh)); pos+=sizeof(drh);
117 memcpy(&packet[pos], serialized.c_str(), serialized.size()); pos+=(uint16_t)serialized.size();
118
119 MOADNSParser mdp(false, (char*)&*packet.begin(), (unsigned int)packet.size());
120 shared_ptr<DNSRecordContent> ret= mdp.d_answers.begin()->first.d_content;
121 return ret;
122 }
123
124 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr,
125 PacketReader& pr)
126 {
127 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
128
129 typemap_t::const_iterator i=getTypemap().find(make_pair(searchclass, dr.d_type));
130 if(i==getTypemap().end() || !i->second) {
131 return std::make_shared<UnknownRecordContent>(dr, pr);
132 }
133
134 return i->second(dr, pr);
135 }
136
137 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(uint16_t qtype, uint16_t qclass,
138 const string& content)
139 {
140 zmakermap_t::const_iterator i=getZmakermap().find(make_pair(qclass, qtype));
141 if(i==getZmakermap().end()) {
142 return std::make_shared<UnknownRecordContent>(content);
143 }
144
145 return i->second(content);
146 }
147
148 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr, PacketReader& pr, uint16_t oc) {
149 // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is
150 // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent.
151 // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses.
152 if (oc == Opcode::Update && dr.d_place == DNSResourceRecord::ANSWER && dr.d_class != 1)
153 return std::make_shared<UnknownRecordContent>(dr, pr);
154
155 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
156
157 typemap_t::const_iterator i=getTypemap().find(make_pair(searchclass, dr.d_type));
158 if(i==getTypemap().end() || !i->second) {
159 return std::make_shared<UnknownRecordContent>(dr, pr);
160 }
161
162 return i->second(dr, pr);
163 }
164
165
166 DNSRecordContent::typemap_t& DNSRecordContent::getTypemap()
167 {
168 static DNSRecordContent::typemap_t typemap;
169 return typemap;
170 }
171
172 DNSRecordContent::n2typemap_t& DNSRecordContent::getN2Typemap()
173 {
174 static DNSRecordContent::n2typemap_t n2typemap;
175 return n2typemap;
176 }
177
178 DNSRecordContent::t2namemap_t& DNSRecordContent::getT2Namemap()
179 {
180 static DNSRecordContent::t2namemap_t t2namemap;
181 return t2namemap;
182 }
183
184 DNSRecordContent::zmakermap_t& DNSRecordContent::getZmakermap()
185 {
186 static DNSRecordContent::zmakermap_t zmakermap;
187 return zmakermap;
188 }
189
190 DNSRecord::DNSRecord(const DNSResourceRecord& rr): d_name(rr.qname)
191 {
192 d_type = rr.qtype.getCode();
193 d_ttl = rr.ttl;
194 d_class = rr.qclass;
195 d_place = DNSResourceRecord::ANSWER;
196 d_clen = 0;
197 d_content = DNSRecordContent::mastermake(d_type, rr.qclass, rr.content);
198 }
199
200 // If you call this and you are not parsing a packet coming from a socket, you are doing it wrong.
201 DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& d) {
202 DNSResourceRecord rr;
203 rr.qname = d.d_name;
204 rr.qtype = QType(d.d_type);
205 rr.ttl = d.d_ttl;
206 rr.content = d.d_content->getZoneRepresentation(true);
207 rr.auth = false;
208 rr.qclass = d.d_class;
209 return rr;
210 }
211
212 void MOADNSParser::init(bool query, const std::string& packet)
213 {
214 if (packet.size() < sizeof(dnsheader))
215 throw MOADNSException("Packet shorter than minimal header");
216
217 memcpy(&d_header, packet.data(), sizeof(dnsheader));
218
219 if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
220 throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode));
221
222 d_header.qdcount=ntohs(d_header.qdcount);
223 d_header.ancount=ntohs(d_header.ancount);
224 d_header.nscount=ntohs(d_header.nscount);
225 d_header.arcount=ntohs(d_header.arcount);
226
227 if (query && (d_header.qdcount > 1))
228 throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
229
230 unsigned int n=0;
231
232 PacketReader pr(packet);
233 bool validPacket=false;
234 try {
235 d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then
236
237 for(n=0;n < d_header.qdcount; ++n) {
238 d_qname=pr.getName();
239 d_qtype=pr.get16BitInt();
240 d_qclass=pr.get16BitInt();
241 }
242
243 struct dnsrecordheader ah;
244 vector<unsigned char> record;
245 bool seenTSIG = false;
246 validPacket=true;
247 d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
248 for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
249 DNSRecord dr;
250
251 if(n < d_header.ancount)
252 dr.d_place=DNSResourceRecord::ANSWER;
253 else if(n < d_header.ancount + d_header.nscount)
254 dr.d_place=DNSResourceRecord::AUTHORITY;
255 else
256 dr.d_place=DNSResourceRecord::ADDITIONAL;
257
258 unsigned int recordStartPos=pr.getPosition();
259
260 DNSName name=pr.getName();
261
262 pr.getDnsrecordheader(ah);
263 dr.d_ttl=ah.d_ttl;
264 dr.d_type=ah.d_type;
265 dr.d_class=ah.d_class;
266
267 dr.d_name=name;
268 dr.d_clen=ah.d_clen;
269
270 if (query &&
271 !(d_qtype == QType::IXFR && dr.d_place == DNSResourceRecord::AUTHORITY && dr.d_type == QType::SOA) && // IXFR queries have a SOA in their AUTHORITY section
272 (dr.d_place == DNSResourceRecord::ANSWER || dr.d_place == DNSResourceRecord::AUTHORITY || (dr.d_type != QType::OPT && dr.d_type != QType::TSIG && dr.d_type != QType::SIG && dr.d_type != QType::TKEY) || ((dr.d_type == QType::TSIG || dr.d_type == QType::SIG || dr.d_type == QType::TKEY) && dr.d_class != QClass::ANY))) {
273 // cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
274 dr.d_content=std::make_shared<UnknownRecordContent>(dr, pr);
275 }
276 else {
277 // cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
278 dr.d_content=DNSRecordContent::mastermake(dr, pr, d_header.opcode);
279 }
280
281 d_answers.push_back(make_pair(dr, pr.getPosition() - sizeof(dnsheader)));
282
283 /* XXX: XPF records should be allowed after TSIG as soon as the actual XPF option code has been assigned:
284 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG && dr.d_type != QType::XPF)
285 */
286 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG) {
287 /* only XPF records are allowed after a TSIG */
288 throw MOADNSException("Packet ("+d_qname.toString()+"|#"+std::to_string(d_qtype)+") has an unexpected record ("+std::to_string(dr.d_type)+") after a TSIG one.");
289 }
290
291 if(dr.d_type == QType::TSIG && dr.d_class == QClass::ANY) {
292 if(seenTSIG || dr.d_place != DNSResourceRecord::ADDITIONAL) {
293 throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position.");
294 }
295 seenTSIG = true;
296 d_tsigPos = recordStartPos;
297 }
298 }
299
300 #if 0
301 if(pr.getPosition()!=packet.size()) {
302 throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " +
303 std::to_string(packet.size()) + ")");
304 }
305 #endif
306 }
307 catch(const std::out_of_range &re) {
308 if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
309 if(n < d_header.ancount) {
310 d_header.ancount=n; d_header.nscount = d_header.arcount = 0;
311 }
312 else if(n < d_header.ancount + d_header.nscount) {
313 d_header.nscount = n - d_header.ancount; d_header.arcount=0;
314 }
315 else {
316 d_header.arcount = n - d_header.ancount - d_header.nscount;
317 }
318 }
319 else {
320 throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+
321 std::to_string(d_header.rd)+
322 "), out of bounds: "+string(re.what()));
323 }
324 }
325 }
326
327
328 void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah)
329 {
330 unsigned int n;
331 unsigned char *p=reinterpret_cast<unsigned char*>(&ah);
332
333 for(n=0; n < sizeof(dnsrecordheader); ++n)
334 p[n]=d_content.at(d_pos++);
335
336 ah.d_type=ntohs(ah.d_type);
337 ah.d_class=ntohs(ah.d_class);
338 ah.d_clen=ntohs(ah.d_clen);
339 ah.d_ttl=ntohl(ah.d_ttl);
340
341 d_startrecordpos=d_pos; // needed for getBlob later on
342 d_recordlen=ah.d_clen;
343 }
344
345
346 void PacketReader::copyRecord(vector<unsigned char>& dest, uint16_t len)
347 {
348 dest.resize(len);
349 if(!len)
350 return;
351
352 for(uint16_t n=0;n<len;++n) {
353 dest.at(n)=d_content.at(d_pos++);
354 }
355 }
356
357 void PacketReader::copyRecord(unsigned char* dest, uint16_t len)
358 {
359 if(d_pos + len > d_content.size())
360 throw std::out_of_range("Attempt to copy outside of packet");
361
362 memcpy(dest, &d_content.at(d_pos), len);
363 d_pos+=len;
364 }
365
366 void PacketReader::xfr48BitInt(uint64_t& ret)
367 {
368 ret=0;
369 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
370 ret<<=8;
371 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
372 ret<<=8;
373 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
374 ret<<=8;
375 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
376 ret<<=8;
377 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
378 ret<<=8;
379 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
380 }
381
382 uint32_t PacketReader::get32BitInt()
383 {
384 uint32_t ret=0;
385 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
386 ret<<=8;
387 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
388 ret<<=8;
389 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
390 ret<<=8;
391 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
392
393 return ret;
394 }
395
396
397 uint16_t PacketReader::get16BitInt()
398 {
399 uint16_t ret=0;
400 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
401 ret<<=8;
402 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
403
404 return ret;
405 }
406
407 uint8_t PacketReader::get8BitInt()
408 {
409 return d_content.at(d_pos++);
410 }
411
412 DNSName PacketReader::getName()
413 {
414 unsigned int consumed;
415 try {
416 DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, 0 /* qtype */, 0 /* qclass */, &consumed, sizeof(dnsheader));
417
418 d_pos+=consumed;
419 return dn;
420 }
421 catch(const std::range_error& re) {
422 throw std::out_of_range(string("dnsname issue: ")+re.what());
423 }
424 catch(...) {
425 throw std::out_of_range("dnsname issue");
426 }
427 throw PDNSException("PacketReader::getName(): name is empty");
428 }
429
/* Escapes a raw character-string for zone-file output: bytes outside printable ASCII (32..126)
 * become a three-digit decimal escape (\DDD); '"' and '\' get a backslash; everything else is copied. */
static string txtEscape(const string &name)
{
  string escaped;
  for(const char ch : name) {
    const unsigned char byte = static_cast<unsigned char>(ch);
    const bool printable = byte >= 32 && byte < 127;
    if(!printable) {
      escaped += '\\';
      escaped += static_cast<char>('0' + byte / 100);
      escaped += static_cast<char>('0' + (byte / 10) % 10);
      escaped += static_cast<char>('0' + byte % 10);
    }
    else if(ch == '"' || ch == '\\') {
      escaped += '\\';
      escaped += ch;
    }
    else {
      escaped += ch;
    }
  }
  return escaped;
}
449
450 // exceptions thrown here do not result in logging in the main pdns auth server - just so you know!
451 string PacketReader::getText(bool multi, bool lenField)
452 {
453 string ret;
454 ret.reserve(40);
455 while(d_pos < d_startrecordpos + d_recordlen ) {
456 if(!ret.empty()) {
457 ret.append(1,' ');
458 }
459 uint16_t labellen;
460 if(lenField)
461 labellen=static_cast<uint8_t>(d_content.at(d_pos++));
462 else
463 labellen=d_recordlen - (d_pos - d_startrecordpos);
464
465 ret.append(1,'"');
466 if(labellen) { // no need to do anything for an empty string
467 string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
468 ret.append(txtEscape(val)); // the end is one beyond the packet
469 }
470 ret.append(1,'"');
471 d_pos+=labellen;
472 if(!multi)
473 break;
474 }
475
476 return ret;
477 }
478
479 string PacketReader::getUnquotedText(bool lenField)
480 {
481 uint16_t stop_at;
482 if(lenField)
483 stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1;
484 else
485 stop_at = d_recordlen;
486
487 if(stop_at == d_pos)
488 return "";
489
490 d_pos++;
491 string ret(&d_content.at(d_pos), &d_content.at(stop_at));
492 d_pos = stop_at;
493 return ret;
494 }
495
496 void PacketReader::xfrBlob(string& blob)
497 try
498 {
499 if(d_recordlen && !(d_pos == (d_startrecordpos + d_recordlen)))
500 blob.assign(&d_content.at(d_pos), &d_content.at(d_startrecordpos + d_recordlen - 1 ) + 1);
501 else
502 blob.clear();
503
504 d_pos = d_startrecordpos + d_recordlen;
505 }
506 catch(...)
507 {
508 throw std::out_of_range("xfrBlob out of range");
509 }
510
511 void PacketReader::xfrBlobNoSpaces(string& blob, int length) {
512 xfrBlob(blob, length);
513 }
514
515 void PacketReader::xfrBlob(string& blob, int length)
516 {
517 if(length) {
518 blob.assign(&d_content.at(d_pos), &d_content.at(d_pos + length - 1 ) + 1 );
519
520 d_pos += length;
521 }
522 else
523 blob.clear();
524 }
525
526
527 void PacketReader::xfrHexBlob(string& blob, bool keepReading)
528 {
529 xfrBlob(blob);
530 }
531
532 //FIXME400 remove this method completely
533 string simpleCompress(const string& elabel, const string& root)
534 {
535 string label=elabel;
536 // FIXME400: this relies on the semi-canonical escaped output from getName
537 if(strchr(label.c_str(), '\\')) {
538 boost::replace_all(label, "\\.", ".");
539 boost::replace_all(label, "\\032", " ");
540 boost::replace_all(label, "\\\\", "\\");
541 }
542 typedef vector<pair<unsigned int, unsigned int> > parts_t;
543 parts_t parts;
544 vstringtok(parts, label, ".");
545 string ret;
546 ret.reserve(label.size()+4);
547 for(parts_t::const_iterator i=parts.begin(); i!=parts.end(); ++i) {
548 if(!root.empty() && !strncasecmp(root.c_str(), label.c_str() + i->first, 1 + label.length() - i->first)) { // also match trailing 0, hence '1 +'
549 const unsigned char rootptr[2]={0xc0,0x11};
550 ret.append((const char *) rootptr, 2);
551 return ret;
552 }
553 ret.append(1, (char)(i->second - i->first));
554 ret.append(label.c_str() + i->first, i->second - i->first);
555 }
556 ret.append(1, (char)0);
557 return ret;
558 }
559
560
/** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs
 * If you survive that, feel free to read from the pointer.
 * Walks a wire-format packet in place, starting right after the 12-byte header. Every read or skip that
 * would go past the end of the buffer throws std::out_of_range before any byte is touched. */
class DNSPacketMangler
{
public:
  // &packet[0] is the writable buffer; writing through c_str() is undefined behaviour
  explicit DNSPacketMangler(std::string& packet)
    : d_packet(&packet[0]), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset)
  {}
  DNSPacketMangler(char* packet, size_t length)
    : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset)
  {}

  /*! Advances past a wire-format domain name
   * The name is not checked for adherence to length restrictions.
   * Compression pointers are not followed.
   */
  void skipDomainName()
  {
    uint8_t len;
    while((len=get8BitInt())) {
      if(len >= 0xc0) { // compression pointer: one more byte, and it ends the name
        get8BitInt();
        return;
      }
      skipBytes(len);
    }
  }

  void skipBytes(uint16_t bytes)
  {
    moveOffset(bytes);
  }
  void rewindBytes(uint16_t by)
  {
    rewindOffset(by);
  }
  // Reads a big-endian 32-bit value and advances past it.
  uint32_t get32BitInt()
  {
    const char* p = d_packet + d_offset;
    moveOffset(4);
    uint32_t ret;
    memcpy(&ret, (void*)p, sizeof(ret));
    return ntohl(ret);
  }
  // Reads a big-endian 16-bit value and advances past it.
  uint16_t get16BitInt()
  {
    const char* p = d_packet + d_offset;
    moveOffset(2);
    uint16_t ret;
    memcpy(&ret, (void*)p, sizeof(ret));
    return ntohs(ret);
  }

  uint8_t get8BitInt()
  {
    const char* p = d_packet + d_offset;
    moveOffset(1);
    return *p;
  }

  // Reads an RDLENGTH field and skips that many bytes of RDATA.
  void skipRData()
  {
    int toskip = get16BitInt();
    moveOffset(toskip);
  }

  /* Subtracts decrease from the big-endian 32-bit value at the current offset, in place, and advances
   * past it. The result saturates at 0: a TTL older than its lifetime must become 0, not wrap to ~2^32. */
  void decreaseAndSkip32BitInt(uint32_t decrease)
  {
    const char *p = d_packet + d_offset;
    moveOffset(4);

    uint32_t tmp;
    memcpy(&tmp, (void*) p, sizeof(tmp));
    tmp = ntohl(tmp);
    if (tmp > decrease) {
      tmp -= decrease;
    }
    else {
      tmp = 0;
    }
    tmp = htonl(tmp);
    memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp));
  }
  // Overwrites the 32-bit value at the current offset with value (big-endian) and advances past it.
  void setAndSkip32BitInt(uint32_t value)
  {
    moveOffset(4);

    value = htonl(value);
    memcpy(d_packet + d_offset-4, (const char*)&value, sizeof(value));
  }
  uint32_t getOffset() const
  {
    return d_offset;
  }
private:
  void moveOffset(uint16_t by)
  {
    d_notyouroffset += by;
    if(d_notyouroffset > d_length)
      throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > "
      + std::to_string(d_length) );
  }
  void rewindOffset(uint16_t by)
  {
    if(d_notyouroffset < by)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(by));
    d_notyouroffset -= by;
    if(d_notyouroffset < 12)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(12));
  }
  char* d_packet;
  size_t d_length;

  uint32_t d_notyouroffset;  // only 'moveOffset' can touch this
  const uint32_t& d_offset;  // look.. but don't touch

};
675
676 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
677 void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)> visitor)
678 {
679 if(length < sizeof(dnsheader))
680 return;
681 try
682 {
683 dnsheader dh;
684 memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
685 uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
686 DNSPacketMangler dpm(packet, length);
687
688 uint64_t n;
689 for(n=0; n < ntohs(dh.qdcount) ; ++n) {
690 dpm.skipDomainName();
691 /* type and class */
692 dpm.skipBytes(4);
693 }
694
695 for(n=0; n < numrecords; ++n) {
696 dpm.skipDomainName();
697
698 uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
699 uint16_t dnstype = dpm.get16BitInt();
700 uint16_t dnsclass = dpm.get16BitInt();
701
702 if(dnstype == QType::OPT) // not getting near that one with a stick
703 break;
704
705 uint32_t dnsttl = dpm.get32BitInt();
706 uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl);
707 if (newttl) {
708 dpm.rewindBytes(sizeof(newttl));
709 dpm.setAndSkip32BitInt(newttl);
710 }
711 dpm.skipRData();
712 }
713 }
714 catch(...)
715 {
716 return;
717 }
718 }
719
720 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
721 void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
722 {
723 if(length < sizeof(dnsheader))
724 return;
725 try
726 {
727 const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
728 const uint64_t dqcount = ntohs(dh->qdcount);
729 const uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
730 DNSPacketMangler dpm(packet, length);
731
732 uint64_t n;
733 for(n=0; n < dqcount; ++n) {
734 dpm.skipDomainName();
735 /* type and class */
736 dpm.skipBytes(4);
737 }
738 // cerr<<"Skipped "<<n<<" questions, now parsing "<<numrecords<<" records"<<endl;
739 for(n=0; n < numrecords; ++n) {
740 dpm.skipDomainName();
741
742 uint16_t dnstype = dpm.get16BitInt();
743 /* class */
744 dpm.skipBytes(2);
745
746 if(dnstype == QType::OPT) // not aging that one with a stick
747 break;
748
749 dpm.decreaseAndSkip32BitInt(seconds);
750 dpm.skipRData();
751 }
752 }
753 catch(...)
754 {
755 return;
756 }
757 }
758
759 void ageDNSPacket(std::string& packet, uint32_t seconds)
760 {
761 ageDNSPacket((char*)packet.c_str(), packet.length(), seconds);
762 }
763
764 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA)
765 {
766 uint32_t result = std::numeric_limits<uint32_t>::max();
767 if(length < sizeof(dnsheader)) {
768 return result;
769 }
770 try
771 {
772 const dnsheader* dh = (const dnsheader*) packet;
773 DNSPacketMangler dpm(const_cast<char*>(packet), length);
774
775 const uint16_t qdcount = ntohs(dh->qdcount);
776 for(size_t n = 0; n < qdcount; ++n) {
777 dpm.skipDomainName();
778 /* type and class */
779 dpm.skipBytes(4);
780 }
781 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
782 for(size_t n = 0; n < numrecords; ++n) {
783 dpm.skipDomainName();
784 const uint16_t dnstype = dpm.get16BitInt();
785 /* class */
786 const uint16_t dnsclass = dpm.get16BitInt();
787
788 if(dnstype == QType::OPT) {
789 break;
790 }
791
792 /* report it if we see a SOA record in the AUTHORITY section */
793 if(dnstype == QType::SOA && dnsclass == QClass::IN && seenAuthSOA != nullptr && n >= ntohs(dh->ancount) && n < (ntohs(dh->ancount) + ntohs(dh->nscount))) {
794 *seenAuthSOA = true;
795 }
796
797 const uint32_t ttl = dpm.get32BitInt();
798 if (result > ttl) {
799 result = ttl;
800 }
801
802 dpm.skipRData();
803 }
804 }
805 catch(...)
806 {
807 }
808 return result;
809 }
810
811 uint32_t getDNSPacketLength(const char* packet, size_t length)
812 {
813 uint32_t result = length;
814 if(length < sizeof(dnsheader)) {
815 return result;
816 }
817 try
818 {
819 const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
820 DNSPacketMangler dpm(const_cast<char*>(packet), length);
821
822 const uint16_t qdcount = ntohs(dh->qdcount);
823 for(size_t n = 0; n < qdcount; ++n) {
824 dpm.skipDomainName();
825 /* type and class */
826 dpm.skipBytes(4);
827 }
828 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
829 for(size_t n = 0; n < numrecords; ++n) {
830 dpm.skipDomainName();
831 /* type (2), class (2) and ttl (4) */
832 dpm.skipBytes(8);
833 dpm.skipRData();
834 }
835 result = dpm.getOffset();
836 }
837 catch(...)
838 {
839 }
840 return result;
841 }
842
843 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type)
844 {
845 uint16_t result = 0;
846 if(length < sizeof(dnsheader)) {
847 return result;
848 }
849 try
850 {
851 const dnsheader* dh = (const dnsheader*) packet;
852 DNSPacketMangler dpm(const_cast<char*>(packet), length);
853
854 const uint16_t qdcount = ntohs(dh->qdcount);
855 for(size_t n = 0; n < qdcount; ++n) {
856 dpm.skipDomainName();
857 if (section == 0) {
858 uint16_t dnstype = dpm.get16BitInt();
859 if (dnstype == type) {
860 result++;
861 }
862 /* class */
863 dpm.skipBytes(2);
864 } else {
865 /* type and class */
866 dpm.skipBytes(4);
867 }
868 }
869 const uint16_t ancount = ntohs(dh->ancount);
870 for(size_t n = 0; n < ancount; ++n) {
871 dpm.skipDomainName();
872 if (section == 1) {
873 uint16_t dnstype = dpm.get16BitInt();
874 if (dnstype == type) {
875 result++;
876 }
877 /* class */
878 dpm.skipBytes(2);
879 } else {
880 /* type and class */
881 dpm.skipBytes(4);
882 }
883 /* ttl */
884 dpm.skipBytes(4);
885 dpm.skipRData();
886 }
887 const uint16_t nscount = ntohs(dh->nscount);
888 for(size_t n = 0; n < nscount; ++n) {
889 dpm.skipDomainName();
890 if (section == 2) {
891 uint16_t dnstype = dpm.get16BitInt();
892 if (dnstype == type) {
893 result++;
894 }
895 /* class */
896 dpm.skipBytes(2);
897 } else {
898 /* type and class */
899 dpm.skipBytes(4);
900 }
901 /* ttl */
902 dpm.skipBytes(4);
903 dpm.skipRData();
904 }
905 const uint16_t arcount = ntohs(dh->arcount);
906 for(size_t n = 0; n < arcount; ++n) {
907 dpm.skipDomainName();
908 if (section == 3) {
909 uint16_t dnstype = dpm.get16BitInt();
910 if (dnstype == type) {
911 result++;
912 }
913 /* class */
914 dpm.skipBytes(2);
915 } else {
916 /* type and class */
917 dpm.skipBytes(4);
918 }
919 /* ttl */
920 dpm.skipBytes(4);
921 dpm.skipRData();
922 }
923 }
924 catch(...)
925 {
926 }
927 return result;
928 }
929
930 bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z)
931 {
932 if (length < sizeof(dnsheader)) {
933 return false;
934 }
935
936 *payloadSize = 0;
937 *z = 0;
938
939 try
940 {
941 const dnsheader* dh = (const dnsheader*) packet;
942 DNSPacketMangler dpm(const_cast<char*>(packet), length);
943
944 const uint16_t qdcount = ntohs(dh->qdcount);
945 for(size_t n = 0; n < qdcount; ++n) {
946 dpm.skipDomainName();
947 /* type and class */
948 dpm.skipBytes(4);
949 }
950 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
951 for(size_t n = 0; n < numrecords; ++n) {
952 dpm.skipDomainName();
953 const uint16_t dnstype = dpm.get16BitInt();
954 const uint16_t dnsclass = dpm.get16BitInt();
955
956 if(dnstype == QType::OPT) {
957 /* skip extended rcode and version */
958 dpm.skipBytes(2);
959 *z = dpm.get16BitInt();
960 *payloadSize = dnsclass;
961 return true;
962 }
963
964 /* TTL */
965 dpm.skipBytes(4);
966 dpm.skipRData();
967 }
968 }
969 catch(...)
970 {
971 }
972
973 return false;
974 }