]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsparser.cc
Merge pull request #7628 from tcely/patch-3
[thirdparty/pdns.git] / pdns / dnsparser.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsparser.hh"
23 #include "dnswriter.hh"
24 #include <boost/algorithm/string.hpp>
25 #include <boost/format.hpp>
26
27 #include "namespaces.hh"
28
29 class UnknownRecordContent : public DNSRecordContent
30 {
31 public:
32 UnknownRecordContent(const DNSRecord& dr, PacketReader& pr)
33 : d_dr(dr)
34 {
35 pr.copyRecord(d_record, dr.d_clen);
36 }
37
38 UnknownRecordContent(const string& zone)
39 {
40 // parse the input
41 vector<string> parts;
42 stringtok(parts, zone);
43 if(parts.size()!=3 && !(parts.size()==2 && equals(parts[1],"0")) )
44 throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got "+std::to_string(parts.size())+": "+zone );
45 const string& relevant=(parts.size() > 2) ? parts[2] : "";
46 unsigned int total=pdns_stou(parts[1]);
47 if(relevant.size() % 2 || relevant.size() / 2 != total)
48 throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant.size() % total).str());
49 string out;
50 out.reserve(total+1);
51 for(unsigned int n=0; n < total; ++n) {
52 int c;
53 sscanf(relevant.c_str()+2*n, "%02x", &c);
54 out.append(1, (char)c);
55 }
56
57 d_record.insert(d_record.end(), out.begin(), out.end());
58 }
59
60 string getZoneRepresentation(bool noDot) const override
61 {
62 ostringstream str;
63 str<<"\\# "<<(unsigned int)d_record.size()<<" ";
64 char hex[4];
65 for(size_t n=0; n<d_record.size(); ++n) {
66 snprintf(hex, sizeof(hex), "%02x", d_record.at(n));
67 str << hex;
68 }
69 return str.str();
70 }
71
72 void toPacket(DNSPacketWriter& pw) override
73 {
74 pw.xfrBlob(string(d_record.begin(),d_record.end()));
75 }
76
77 uint16_t getType() const override
78 {
79 return d_dr.d_type;
80 }
81 private:
82 DNSRecord d_dr;
83 vector<uint8_t> d_record;
84 };
85
86 shared_ptr<DNSRecordContent> DNSRecordContent::unserialize(const DNSName& qname, uint16_t qtype, const string& serialized)
87 {
88 dnsheader dnsheader;
89 memset(&dnsheader, 0, sizeof(dnsheader));
90 dnsheader.qdcount=htons(1);
91 dnsheader.ancount=htons(1);
92
93 vector<uint8_t> packet; // build pseudo packet
94
95 /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */
96
97 string encoded=qname.toDNSString();
98
99 packet.resize(sizeof(dnsheader) + 5 + encoded.size() + sizeof(struct dnsrecordheader) + serialized.size());
100
101 uint16_t pos=0;
102
103 memcpy(&packet[0], &dnsheader, sizeof(dnsheader)); pos+=sizeof(dnsheader);
104
105 char tmp[6]="\x0" "\x0\x1" "\x0\x1"; // root question for ns_t_a
106 memcpy(&packet[pos], &tmp, 5); pos+=5;
107
108 memcpy(&packet[pos], encoded.c_str(), encoded.size()); pos+=(uint16_t)encoded.size();
109
110 struct dnsrecordheader drh;
111 drh.d_type=htons(qtype);
112 drh.d_class=htons(QClass::IN);
113 drh.d_ttl=0;
114 drh.d_clen=htons(serialized.size());
115
116 memcpy(&packet[pos], &drh, sizeof(drh)); pos+=sizeof(drh);
117 memcpy(&packet[pos], serialized.c_str(), serialized.size()); pos+=(uint16_t)serialized.size();
118
119 MOADNSParser mdp(false, (char*)&*packet.begin(), (unsigned int)packet.size());
120 shared_ptr<DNSRecordContent> ret= mdp.d_answers.begin()->first.d_content;
121 return ret;
122 }
123
124 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr,
125 PacketReader& pr)
126 {
127 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
128
129 typemap_t::const_iterator i=getTypemap().find(make_pair(searchclass, dr.d_type));
130 if(i==getTypemap().end() || !i->second) {
131 return std::make_shared<UnknownRecordContent>(dr, pr);
132 }
133
134 return i->second(dr, pr);
135 }
136
137 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(uint16_t qtype, uint16_t qclass,
138 const string& content)
139 {
140 zmakermap_t::const_iterator i=getZmakermap().find(make_pair(qclass, qtype));
141 if(i==getZmakermap().end()) {
142 return std::make_shared<UnknownRecordContent>(content);
143 }
144
145 return i->second(content);
146 }
147
148 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr, PacketReader& pr, uint16_t oc) {
149 // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is
150 // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent.
151 // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses.
152 if (oc == Opcode::Update && dr.d_place == DNSResourceRecord::ANSWER && dr.d_class != 1)
153 return std::make_shared<UnknownRecordContent>(dr, pr);
154
155 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
156
157 typemap_t::const_iterator i=getTypemap().find(make_pair(searchclass, dr.d_type));
158 if(i==getTypemap().end() || !i->second) {
159 return std::make_shared<UnknownRecordContent>(dr, pr);
160 }
161
162 return i->second(dr, pr);
163 }
164
165
166 DNSRecordContent::typemap_t& DNSRecordContent::getTypemap()
167 {
168 static DNSRecordContent::typemap_t typemap;
169 return typemap;
170 }
171
172 DNSRecordContent::n2typemap_t& DNSRecordContent::getN2Typemap()
173 {
174 static DNSRecordContent::n2typemap_t n2typemap;
175 return n2typemap;
176 }
177
178 DNSRecordContent::t2namemap_t& DNSRecordContent::getT2Namemap()
179 {
180 static DNSRecordContent::t2namemap_t t2namemap;
181 return t2namemap;
182 }
183
184 DNSRecordContent::zmakermap_t& DNSRecordContent::getZmakermap()
185 {
186 static DNSRecordContent::zmakermap_t zmakermap;
187 return zmakermap;
188 }
189
190 DNSRecord::DNSRecord(const DNSResourceRecord& rr): d_name(rr.qname)
191 {
192 d_type = rr.qtype.getCode();
193 d_ttl = rr.ttl;
194 d_class = rr.qclass;
195 d_place = DNSResourceRecord::ANSWER;
196 d_clen = 0;
197 d_content = DNSRecordContent::mastermake(d_type, rr.qclass, rr.content);
198 }
199
200 // If you call this and you are not parsing a packet coming from a socket, you are doing it wrong.
201 DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& d) {
202 DNSResourceRecord rr;
203 rr.qname = d.d_name;
204 rr.qtype = QType(d.d_type);
205 rr.ttl = d.d_ttl;
206 rr.content = d.d_content->getZoneRepresentation(true);
207 rr.auth = false;
208 rr.qclass = d.d_class;
209 return rr;
210 }
211
212 void MOADNSParser::init(bool query, const std::string& packet)
213 {
214 if (packet.size() < sizeof(dnsheader))
215 throw MOADNSException("Packet shorter than minimal header");
216
217 memcpy(&d_header, packet.data(), sizeof(dnsheader));
218
219 if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
220 throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode));
221
222 d_header.qdcount=ntohs(d_header.qdcount);
223 d_header.ancount=ntohs(d_header.ancount);
224 d_header.nscount=ntohs(d_header.nscount);
225 d_header.arcount=ntohs(d_header.arcount);
226
227 if (query && (d_header.qdcount > 1))
228 throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
229
230 unsigned int n=0;
231
232 PacketReader pr(packet);
233 bool validPacket=false;
234 try {
235 d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then
236
237 for(n=0;n < d_header.qdcount; ++n) {
238 d_qname=pr.getName();
239 d_qtype=pr.get16BitInt();
240 d_qclass=pr.get16BitInt();
241 }
242
243 struct dnsrecordheader ah;
244 vector<unsigned char> record;
245 bool seenTSIG = false;
246 validPacket=true;
247 d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
248 for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
249 DNSRecord dr;
250
251 if(n < d_header.ancount)
252 dr.d_place=DNSResourceRecord::ANSWER;
253 else if(n < d_header.ancount + d_header.nscount)
254 dr.d_place=DNSResourceRecord::AUTHORITY;
255 else
256 dr.d_place=DNSResourceRecord::ADDITIONAL;
257
258 unsigned int recordStartPos=pr.getPosition();
259
260 DNSName name=pr.getName();
261
262 pr.getDnsrecordheader(ah);
263 dr.d_ttl=ah.d_ttl;
264 dr.d_type=ah.d_type;
265 dr.d_class=ah.d_class;
266
267 dr.d_name=name;
268 dr.d_clen=ah.d_clen;
269
270 if (query &&
271 !(d_qtype == QType::IXFR && dr.d_place == DNSResourceRecord::AUTHORITY && dr.d_type == QType::SOA) && // IXFR queries have a SOA in their AUTHORITY section
272 (dr.d_place == DNSResourceRecord::ANSWER || dr.d_place == DNSResourceRecord::AUTHORITY || (dr.d_type != QType::OPT && dr.d_type != QType::TSIG && dr.d_type != QType::SIG && dr.d_type != QType::TKEY) || ((dr.d_type == QType::TSIG || dr.d_type == QType::SIG || dr.d_type == QType::TKEY) && dr.d_class != QClass::ANY))) {
273 // cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
274 dr.d_content=std::make_shared<UnknownRecordContent>(dr, pr);
275 }
276 else {
277 // cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
278 dr.d_content=DNSRecordContent::mastermake(dr, pr, d_header.opcode);
279 }
280
281 d_answers.push_back(make_pair(dr, pr.getPosition() - sizeof(dnsheader)));
282
283 /* XXX: XPF records should be allowed after TSIG as soon as the actual XPF option code has been assigned:
284 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG && dr.d_type != QType::XPF)
285 */
286 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG) {
287 /* only XPF records are allowed after a TSIG */
288 throw MOADNSException("Packet ("+d_qname.toString()+"|#"+std::to_string(d_qtype)+") has an unexpected record ("+std::to_string(dr.d_type)+") after a TSIG one.");
289 }
290
291 if(dr.d_type == QType::TSIG && dr.d_class == QClass::ANY) {
292 if(seenTSIG || dr.d_place != DNSResourceRecord::ADDITIONAL) {
293 throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position.");
294 }
295 seenTSIG = true;
296 d_tsigPos = recordStartPos;
297 }
298 }
299
300 #if 0
301 if(pr.getPosition()!=packet.size()) {
302 throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " +
303 std::to_string(packet.size()) + ")");
304 }
305 #endif
306 }
307 catch(const std::out_of_range &re) {
308 if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
309 if(n < d_header.ancount) {
310 d_header.ancount=n; d_header.nscount = d_header.arcount = 0;
311 }
312 else if(n < d_header.ancount + d_header.nscount) {
313 d_header.nscount = n - d_header.ancount; d_header.arcount=0;
314 }
315 else {
316 d_header.arcount = n - d_header.ancount - d_header.nscount;
317 }
318 }
319 else {
320 throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+
321 std::to_string(d_header.rd)+
322 "), out of bounds: "+string(re.what()));
323 }
324 }
325 }
326
327
328 void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah)
329 {
330 unsigned int n;
331 unsigned char *p=reinterpret_cast<unsigned char*>(&ah);
332
333 for(n=0; n < sizeof(dnsrecordheader); ++n)
334 p[n]=d_content.at(d_pos++);
335
336 ah.d_type=ntohs(ah.d_type);
337 ah.d_class=ntohs(ah.d_class);
338 ah.d_clen=ntohs(ah.d_clen);
339 ah.d_ttl=ntohl(ah.d_ttl);
340
341 d_startrecordpos=d_pos; // needed for getBlob later on
342 d_recordlen=ah.d_clen;
343 }
344
345
346 void PacketReader::copyRecord(vector<unsigned char>& dest, uint16_t len)
347 {
348 dest.resize(len);
349 if(!len)
350 return;
351
352 for(uint16_t n=0;n<len;++n) {
353 dest.at(n)=d_content.at(d_pos++);
354 }
355 }
356
357 void PacketReader::copyRecord(unsigned char* dest, uint16_t len)
358 {
359 if(d_pos + len > d_content.size())
360 throw std::out_of_range("Attempt to copy outside of packet");
361
362 memcpy(dest, &d_content.at(d_pos), len);
363 d_pos+=len;
364 }
365
366 void PacketReader::xfr48BitInt(uint64_t& ret)
367 {
368 ret=0;
369 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
370 ret<<=8;
371 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
372 ret<<=8;
373 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
374 ret<<=8;
375 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
376 ret<<=8;
377 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
378 ret<<=8;
379 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
380 }
381
382 uint32_t PacketReader::get32BitInt()
383 {
384 uint32_t ret=0;
385 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
386 ret<<=8;
387 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
388 ret<<=8;
389 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
390 ret<<=8;
391 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
392
393 return ret;
394 }
395
396
397 uint16_t PacketReader::get16BitInt()
398 {
399 uint16_t ret=0;
400 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
401 ret<<=8;
402 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
403
404 return ret;
405 }
406
407 uint8_t PacketReader::get8BitInt()
408 {
409 return d_content.at(d_pos++);
410 }
411
412 DNSName PacketReader::getName()
413 {
414 unsigned int consumed;
415 try {
416 DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, 0 /* qtype */, 0 /* qclass */, &consumed, sizeof(dnsheader));
417
418 d_pos+=consumed;
419 return dn;
420 }
421 catch(const std::range_error& re) {
422 throw std::out_of_range(string("dnsname issue: ")+re.what());
423 }
424 catch(...) {
425 throw std::out_of_range("dnsname issue");
426 }
427 throw PDNSException("PacketReader::getName(): name is empty");
428 }
429
/* Escape raw character-string bytes for presentation format:
 *  - bytes below 32 or at/above 127 become a three-digit decimal escape "\DDD";
 *  - '"' and '\' are prefixed with a backslash;
 *  - every other byte is copied unchanged. */
static string txtEscape(const string &name)
{
  string escaped;
  for(const char ch : name) {
    const unsigned char byte = static_cast<unsigned char>(ch);
    if(byte < 32 || byte >= 127) {
      char digits[5];
      snprintf(digits, sizeof(digits), "\\%03u", static_cast<unsigned int>(byte));
      escaped += digits;
      continue;
    }
    if(ch == '"' || ch == '\\') {
      escaped += '\\';
    }
    escaped += ch;
  }
  return escaped;
}
449
450 // exceptions thrown here do not result in logging in the main pdns auth server - just so you know!
451 string PacketReader::getText(bool multi, bool lenField)
452 {
453 string ret;
454 ret.reserve(40);
455 while(d_pos < d_startrecordpos + d_recordlen ) {
456 if(!ret.empty()) {
457 ret.append(1,' ');
458 }
459 uint16_t labellen;
460 if(lenField)
461 labellen=static_cast<uint8_t>(d_content.at(d_pos++));
462 else
463 labellen=d_recordlen - (d_pos - d_startrecordpos);
464
465 ret.append(1,'"');
466 if(labellen) { // no need to do anything for an empty string
467 string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
468 ret.append(txtEscape(val)); // the end is one beyond the packet
469 }
470 ret.append(1,'"');
471 d_pos+=labellen;
472 if(!multi)
473 break;
474 }
475
476 return ret;
477 }
478
479 string PacketReader::getUnquotedText(bool lenField)
480 {
481 uint16_t stop_at;
482 if(lenField)
483 stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1;
484 else
485 stop_at = d_recordlen;
486
487 if(stop_at == d_pos)
488 return "";
489
490 d_pos++;
491 string ret(&d_content.at(d_pos), &d_content.at(stop_at));
492 d_pos = stop_at;
493 return ret;
494 }
495
496 void PacketReader::xfrBlob(string& blob)
497 try
498 {
499 if(d_recordlen && !(d_pos == (d_startrecordpos + d_recordlen))) {
500 if (d_pos > (d_startrecordpos + d_recordlen)) {
501 throw std::out_of_range("xfrBlob out of record range");
502 }
503 blob.assign(&d_content.at(d_pos), &d_content.at(d_startrecordpos + d_recordlen - 1 ) + 1);
504 }
505 else {
506 blob.clear();
507 }
508
509 d_pos = d_startrecordpos + d_recordlen;
510 }
511 catch(...)
512 {
513 throw std::out_of_range("xfrBlob out of range");
514 }
515
516 void PacketReader::xfrBlobNoSpaces(string& blob, int length) {
517 xfrBlob(blob, length);
518 }
519
520 void PacketReader::xfrBlob(string& blob, int length)
521 {
522 if(length) {
523 if (length < 0) {
524 throw std::out_of_range("xfrBlob out of range (negative length)");
525 }
526
527 blob.assign(&d_content.at(d_pos), &d_content.at(d_pos + length - 1 ) + 1 );
528
529 d_pos += length;
530 }
531 else {
532 blob.clear();
533 }
534 }
535
536
537 void PacketReader::xfrHexBlob(string& blob, bool keepReading)
538 {
539 xfrBlob(blob);
540 }
541
542 //FIXME400 remove this method completely
543 string simpleCompress(const string& elabel, const string& root)
544 {
545 string label=elabel;
546 // FIXME400: this relies on the semi-canonical escaped output from getName
547 if(strchr(label.c_str(), '\\')) {
548 boost::replace_all(label, "\\.", ".");
549 boost::replace_all(label, "\\032", " ");
550 boost::replace_all(label, "\\\\", "\\");
551 }
552 typedef vector<pair<unsigned int, unsigned int> > parts_t;
553 parts_t parts;
554 vstringtok(parts, label, ".");
555 string ret;
556 ret.reserve(label.size()+4);
557 for(parts_t::const_iterator i=parts.begin(); i!=parts.end(); ++i) {
558 if(!root.empty() && !strncasecmp(root.c_str(), label.c_str() + i->first, 1 + label.length() - i->first)) { // also match trailing 0, hence '1 +'
559 const unsigned char rootptr[2]={0xc0,0x11};
560 ret.append((const char *) rootptr, 2);
561 return ret;
562 }
563 ret.append(1, (char)(i->second - i->first));
564 ret.append(label.c_str() + i->first, i->second - i->first);
565 }
566 ret.append(1, (char)0);
567 return ret;
568 }
569
570
/** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs
 * If you survive that, feel free to read from the pointer.
 *
 * Walks a wire-format DNS message in place, starting right after the 12-byte
 * header. Every accessor first advances the offset through moveOffset(), which
 * throws std::out_of_range (without moving) when the new offset would pass the
 * end of the buffer, so reads and writes stay inside [packet, packet + length).
 * The mangler only stores a pointer: the buffer must outlive it. */
class DNSPacketMangler
{
public:
  explicit DNSPacketMangler(std::string& packet)
    : d_packet(&packet[0]), d_length(packet.length()), d_notyouroffset(12)
  {}
  DNSPacketMangler(char* packet, size_t length)
    : d_packet(packet), d_length(length), d_notyouroffset(12)
  {}

  /*! Advances past a wire-format domain name
   * The name is not checked for adherence to length restrictions.
   * Compression pointers are not followed.
   */
  void skipDomainName()
  {
    uint8_t len;
    while((len=get8BitInt())) {
      if(len >= 0xc0) { // compression pointer: one more offset byte, and the name ends here
        get8BitInt();
        return;
      }
      skipBytes(len);
    }
  }

  void skipBytes(uint16_t bytes)
  {
    moveOffset(bytes);
  }
  void rewindBytes(uint16_t by)
  {
    rewindOffset(by);
  }
  uint32_t get32BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(4);
    uint32_t ret;
    memcpy(&ret, p, sizeof(ret));
    return ntohl(ret);
  }
  uint16_t get16BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(2);
    uint16_t ret;
    memcpy(&ret, p, sizeof(ret));
    return ntohs(ret);
  }

  uint8_t get8BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(1);
    return static_cast<uint8_t>(*p);
  }

  /* Skip an RDATA section: a 16-bit length followed by that many bytes. */
  void skipRData()
  {
    const uint16_t toskip = get16BitInt();
    moveOffset(toskip);
  }

  /* Subtract 'decrease' from the 32-bit big-endian value at the current offset,
   * in place, and skip past it. The result saturates at 0: a TTL smaller than
   * 'decrease' must not wrap around to a ~136-year TTL. */
  void decreaseAndSkip32BitInt(uint32_t decrease)
  {
    const char *p = d_packet + getOffset();
    moveOffset(4);

    uint32_t tmp;
    memcpy(&tmp, p, sizeof(tmp));
    tmp = ntohl(tmp);
    tmp = (tmp > decrease) ? tmp - decrease : 0;
    tmp = htonl(tmp);
    memcpy(d_packet + getOffset() - 4, &tmp, sizeof(tmp));
  }
  /* Overwrite the 32-bit value at the current offset (stored big-endian) and skip past it. */
  void setAndSkip32BitInt(uint32_t value)
  {
    moveOffset(4);

    value = htonl(value);
    memcpy(d_packet + getOffset() - 4, &value, sizeof(value));
  }
  uint32_t getOffset() const
  {
    return d_notyouroffset;
  }
private:
  void moveOffset(uint16_t by)
  {
    // check before committing, so a failed move leaves the offset inside the buffer
    const uint64_t newOffset = static_cast<uint64_t>(d_notyouroffset) + by;
    if(newOffset > d_length)
      throw std::out_of_range("dns packet out of range: "+std::to_string(newOffset) +" > "
                              + std::to_string(d_length) );
    d_notyouroffset = static_cast<uint32_t>(newOffset);
  }
  void rewindOffset(uint16_t by)
  {
    if(d_notyouroffset < by)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(by));
    if(d_notyouroffset - by < 12)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset - by) +" < "
                              + std::to_string(12));
    d_notyouroffset -= by;
  }
  char* d_packet;
  size_t d_length;

  // Only moveOffset()/rewindOffset() modify this; everything else reads it via getOffset().
  // (A former 'const uint32_t& d_offset' alias made copies of the mangler share the
  // source object's offset.)
  uint32_t d_notyouroffset;
};
685
686 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
687 void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)> visitor)
688 {
689 if(length < sizeof(dnsheader))
690 return;
691 try
692 {
693 dnsheader dh;
694 memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
695 uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
696 DNSPacketMangler dpm(packet, length);
697
698 uint64_t n;
699 for(n=0; n < ntohs(dh.qdcount) ; ++n) {
700 dpm.skipDomainName();
701 /* type and class */
702 dpm.skipBytes(4);
703 }
704
705 for(n=0; n < numrecords; ++n) {
706 dpm.skipDomainName();
707
708 uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
709 uint16_t dnstype = dpm.get16BitInt();
710 uint16_t dnsclass = dpm.get16BitInt();
711
712 if(dnstype == QType::OPT) // not getting near that one with a stick
713 break;
714
715 uint32_t dnsttl = dpm.get32BitInt();
716 uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl);
717 if (newttl) {
718 dpm.rewindBytes(sizeof(newttl));
719 dpm.setAndSkip32BitInt(newttl);
720 }
721 dpm.skipRData();
722 }
723 }
724 catch(...)
725 {
726 return;
727 }
728 }
729
730 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
731 void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
732 {
733 if(length < sizeof(dnsheader))
734 return;
735 try
736 {
737 const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
738 const uint64_t dqcount = ntohs(dh->qdcount);
739 const uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
740 DNSPacketMangler dpm(packet, length);
741
742 uint64_t n;
743 for(n=0; n < dqcount; ++n) {
744 dpm.skipDomainName();
745 /* type and class */
746 dpm.skipBytes(4);
747 }
748 // cerr<<"Skipped "<<n<<" questions, now parsing "<<numrecords<<" records"<<endl;
749 for(n=0; n < numrecords; ++n) {
750 dpm.skipDomainName();
751
752 uint16_t dnstype = dpm.get16BitInt();
753 /* class */
754 dpm.skipBytes(2);
755
756 if(dnstype == QType::OPT) // not aging that one with a stick
757 break;
758
759 dpm.decreaseAndSkip32BitInt(seconds);
760 dpm.skipRData();
761 }
762 }
763 catch(...)
764 {
765 return;
766 }
767 }
768
769 void ageDNSPacket(std::string& packet, uint32_t seconds)
770 {
771 ageDNSPacket((char*)packet.c_str(), packet.length(), seconds);
772 }
773
774 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA)
775 {
776 uint32_t result = std::numeric_limits<uint32_t>::max();
777 if(length < sizeof(dnsheader)) {
778 return result;
779 }
780 try
781 {
782 const dnsheader* dh = (const dnsheader*) packet;
783 DNSPacketMangler dpm(const_cast<char*>(packet), length);
784
785 const uint16_t qdcount = ntohs(dh->qdcount);
786 for(size_t n = 0; n < qdcount; ++n) {
787 dpm.skipDomainName();
788 /* type and class */
789 dpm.skipBytes(4);
790 }
791 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
792 for(size_t n = 0; n < numrecords; ++n) {
793 dpm.skipDomainName();
794 const uint16_t dnstype = dpm.get16BitInt();
795 /* class */
796 const uint16_t dnsclass = dpm.get16BitInt();
797
798 if(dnstype == QType::OPT) {
799 break;
800 }
801
802 /* report it if we see a SOA record in the AUTHORITY section */
803 if(dnstype == QType::SOA && dnsclass == QClass::IN && seenAuthSOA != nullptr && n >= ntohs(dh->ancount) && n < (ntohs(dh->ancount) + ntohs(dh->nscount))) {
804 *seenAuthSOA = true;
805 }
806
807 const uint32_t ttl = dpm.get32BitInt();
808 if (result > ttl) {
809 result = ttl;
810 }
811
812 dpm.skipRData();
813 }
814 }
815 catch(...)
816 {
817 }
818 return result;
819 }
820
821 uint32_t getDNSPacketLength(const char* packet, size_t length)
822 {
823 uint32_t result = length;
824 if(length < sizeof(dnsheader)) {
825 return result;
826 }
827 try
828 {
829 const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
830 DNSPacketMangler dpm(const_cast<char*>(packet), length);
831
832 const uint16_t qdcount = ntohs(dh->qdcount);
833 for(size_t n = 0; n < qdcount; ++n) {
834 dpm.skipDomainName();
835 /* type and class */
836 dpm.skipBytes(4);
837 }
838 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
839 for(size_t n = 0; n < numrecords; ++n) {
840 dpm.skipDomainName();
841 /* type (2), class (2) and ttl (4) */
842 dpm.skipBytes(8);
843 dpm.skipRData();
844 }
845 result = dpm.getOffset();
846 }
847 catch(...)
848 {
849 }
850 return result;
851 }
852
853 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type)
854 {
855 uint16_t result = 0;
856 if(length < sizeof(dnsheader)) {
857 return result;
858 }
859 try
860 {
861 const dnsheader* dh = (const dnsheader*) packet;
862 DNSPacketMangler dpm(const_cast<char*>(packet), length);
863
864 const uint16_t qdcount = ntohs(dh->qdcount);
865 for(size_t n = 0; n < qdcount; ++n) {
866 dpm.skipDomainName();
867 if (section == 0) {
868 uint16_t dnstype = dpm.get16BitInt();
869 if (dnstype == type) {
870 result++;
871 }
872 /* class */
873 dpm.skipBytes(2);
874 } else {
875 /* type and class */
876 dpm.skipBytes(4);
877 }
878 }
879 const uint16_t ancount = ntohs(dh->ancount);
880 for(size_t n = 0; n < ancount; ++n) {
881 dpm.skipDomainName();
882 if (section == 1) {
883 uint16_t dnstype = dpm.get16BitInt();
884 if (dnstype == type) {
885 result++;
886 }
887 /* class */
888 dpm.skipBytes(2);
889 } else {
890 /* type and class */
891 dpm.skipBytes(4);
892 }
893 /* ttl */
894 dpm.skipBytes(4);
895 dpm.skipRData();
896 }
897 const uint16_t nscount = ntohs(dh->nscount);
898 for(size_t n = 0; n < nscount; ++n) {
899 dpm.skipDomainName();
900 if (section == 2) {
901 uint16_t dnstype = dpm.get16BitInt();
902 if (dnstype == type) {
903 result++;
904 }
905 /* class */
906 dpm.skipBytes(2);
907 } else {
908 /* type and class */
909 dpm.skipBytes(4);
910 }
911 /* ttl */
912 dpm.skipBytes(4);
913 dpm.skipRData();
914 }
915 const uint16_t arcount = ntohs(dh->arcount);
916 for(size_t n = 0; n < arcount; ++n) {
917 dpm.skipDomainName();
918 if (section == 3) {
919 uint16_t dnstype = dpm.get16BitInt();
920 if (dnstype == type) {
921 result++;
922 }
923 /* class */
924 dpm.skipBytes(2);
925 } else {
926 /* type and class */
927 dpm.skipBytes(4);
928 }
929 /* ttl */
930 dpm.skipBytes(4);
931 dpm.skipRData();
932 }
933 }
934 catch(...)
935 {
936 }
937 return result;
938 }
939
940 bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z)
941 {
942 if (length < sizeof(dnsheader)) {
943 return false;
944 }
945
946 *payloadSize = 0;
947 *z = 0;
948
949 try
950 {
951 const dnsheader* dh = (const dnsheader*) packet;
952 DNSPacketMangler dpm(const_cast<char*>(packet), length);
953
954 const uint16_t qdcount = ntohs(dh->qdcount);
955 for(size_t n = 0; n < qdcount; ++n) {
956 dpm.skipDomainName();
957 /* type and class */
958 dpm.skipBytes(4);
959 }
960 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
961 for(size_t n = 0; n < numrecords; ++n) {
962 dpm.skipDomainName();
963 const uint16_t dnstype = dpm.get16BitInt();
964 const uint16_t dnsclass = dpm.get16BitInt();
965
966 if(dnstype == QType::OPT) {
967 /* skip extended rcode and version */
968 dpm.skipBytes(2);
969 *z = dpm.get16BitInt();
970 *payloadSize = dnsclass;
971 return true;
972 }
973
974 /* TTL */
975 dpm.skipBytes(4);
976 dpm.skipRData();
977 }
978 }
979 catch(...)
980 {
981 }
982
983 return false;
984 }