]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsparser.cc
rec: Only log qname parsing errors when 'log-common-errors' is set
[thirdparty/pdns.git] / pdns / dnsparser.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsparser.hh"
23 #include "dnswriter.hh"
24 #include <boost/algorithm/string.hpp>
25 #include <boost/format.hpp>
26
27 #include "namespaces.hh"
28
29 class UnknownRecordContent : public DNSRecordContent
30 {
31 public:
32 UnknownRecordContent(const DNSRecord& dr, PacketReader& pr)
33 : d_dr(dr)
34 {
35 pr.copyRecord(d_record, dr.d_clen);
36 }
37
38 UnknownRecordContent(const string& zone)
39 {
40 // parse the input
41 vector<string> parts;
42 stringtok(parts, zone);
43 if(parts.size()!=3 && !(parts.size()==2 && equals(parts[1],"0")) )
44 throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got "+std::to_string(parts.size())+": "+zone );
45 const string& relevant=(parts.size() > 2) ? parts[2] : "";
46 unsigned int total=pdns_stou(parts[1]);
47 if(relevant.size() % 2 || relevant.size() / 2 != total)
48 throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant.size() % total).str());
49 string out;
50 out.reserve(total+1);
51 for(unsigned int n=0; n < total; ++n) {
52 int c;
53 sscanf(relevant.c_str()+2*n, "%02x", &c);
54 out.append(1, (char)c);
55 }
56
57 d_record.insert(d_record.end(), out.begin(), out.end());
58 }
59
60 string getZoneRepresentation(bool noDot) const override
61 {
62 ostringstream str;
63 str<<"\\# "<<(unsigned int)d_record.size()<<" ";
64 char hex[4];
65 for(size_t n=0; n<d_record.size(); ++n) {
66 snprintf(hex, sizeof(hex), "%02x", d_record.at(n));
67 str << hex;
68 }
69 return str.str();
70 }
71
72 void toPacket(DNSPacketWriter& pw) override
73 {
74 pw.xfrBlob(string(d_record.begin(),d_record.end()));
75 }
76
77 uint16_t getType() const override
78 {
79 return d_dr.d_type;
80 }
81 private:
82 DNSRecord d_dr;
83 vector<uint8_t> d_record;
84 };
85
86 shared_ptr<DNSRecordContent> DNSRecordContent::unserialize(const DNSName& qname, uint16_t qtype, const string& serialized)
87 {
88 dnsheader dnsheader;
89 memset(&dnsheader, 0, sizeof(dnsheader));
90 dnsheader.qdcount=htons(1);
91 dnsheader.ancount=htons(1);
92
93 vector<uint8_t> packet; // build pseudo packet
94
95 /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */
96
97 string encoded=qname.toDNSString();
98
99 packet.resize(sizeof(dnsheader) + 5 + encoded.size() + sizeof(struct dnsrecordheader) + serialized.size());
100
101 uint16_t pos=0;
102
103 memcpy(&packet[0], &dnsheader, sizeof(dnsheader)); pos+=sizeof(dnsheader);
104
105 char tmp[6]="\x0" "\x0\x1" "\x0\x1"; // root question for ns_t_a
106 memcpy(&packet[pos], &tmp, 5); pos+=5;
107
108 memcpy(&packet[pos], encoded.c_str(), encoded.size()); pos+=(uint16_t)encoded.size();
109
110 struct dnsrecordheader drh;
111 drh.d_type=htons(qtype);
112 drh.d_class=htons(QClass::IN);
113 drh.d_ttl=0;
114 drh.d_clen=htons(serialized.size());
115
116 memcpy(&packet[pos], &drh, sizeof(drh)); pos+=sizeof(drh);
117 memcpy(&packet[pos], serialized.c_str(), serialized.size()); pos+=(uint16_t)serialized.size();
118
119 MOADNSParser mdp(false, (char*)&*packet.begin(), (unsigned int)packet.size());
120 shared_ptr<DNSRecordContent> ret= mdp.d_answers.begin()->first.d_content;
121 return ret;
122 }
123
124 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr,
125 PacketReader& pr)
126 {
127 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
128
129 typemap_t::const_iterator i=getTypemap().find(make_pair(searchclass, dr.d_type));
130 if(i==getTypemap().end() || !i->second) {
131 return std::make_shared<UnknownRecordContent>(dr, pr);
132 }
133
134 return i->second(dr, pr);
135 }
136
137 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(uint16_t qtype, uint16_t qclass,
138 const string& content)
139 {
140 zmakermap_t::const_iterator i=getZmakermap().find(make_pair(qclass, qtype));
141 if(i==getZmakermap().end()) {
142 return std::make_shared<UnknownRecordContent>(content);
143 }
144
145 return i->second(content);
146 }
147
148 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr, PacketReader& pr, uint16_t oc) {
149 // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is
150 // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent.
151 // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses.
152 if (oc == Opcode::Update && dr.d_place == DNSResourceRecord::ANSWER && dr.d_class != 1)
153 return std::make_shared<UnknownRecordContent>(dr, pr);
154
155 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
156
157 typemap_t::const_iterator i=getTypemap().find(make_pair(searchclass, dr.d_type));
158 if(i==getTypemap().end() || !i->second) {
159 return std::make_shared<UnknownRecordContent>(dr, pr);
160 }
161
162 return i->second(dr, pr);
163 }
164
165
166 DNSRecordContent::typemap_t& DNSRecordContent::getTypemap()
167 {
168 static DNSRecordContent::typemap_t typemap;
169 return typemap;
170 }
171
172 DNSRecordContent::n2typemap_t& DNSRecordContent::getN2Typemap()
173 {
174 static DNSRecordContent::n2typemap_t n2typemap;
175 return n2typemap;
176 }
177
178 DNSRecordContent::t2namemap_t& DNSRecordContent::getT2Namemap()
179 {
180 static DNSRecordContent::t2namemap_t t2namemap;
181 return t2namemap;
182 }
183
184 DNSRecordContent::zmakermap_t& DNSRecordContent::getZmakermap()
185 {
186 static DNSRecordContent::zmakermap_t zmakermap;
187 return zmakermap;
188 }
189
190 DNSRecord::DNSRecord(const DNSResourceRecord& rr): d_name(rr.qname)
191 {
192 d_type = rr.qtype.getCode();
193 d_ttl = rr.ttl;
194 d_class = rr.qclass;
195 d_place = DNSResourceRecord::ANSWER;
196 d_clen = 0;
197 d_content = DNSRecordContent::mastermake(d_type, rr.qclass, rr.content);
198 }
199
200 // If you call this and you are not parsing a packet coming from a socket, you are doing it wrong.
201 DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& d) {
202 DNSResourceRecord rr;
203 rr.qname = d.d_name;
204 rr.qtype = QType(d.d_type);
205 rr.ttl = d.d_ttl;
206 rr.content = d.d_content->getZoneRepresentation(true);
207 rr.auth = false;
208 rr.qclass = d.d_class;
209 return rr;
210 }
211
212 void MOADNSParser::init(bool query, const std::string& packet)
213 {
214 if (packet.size() < sizeof(dnsheader))
215 throw MOADNSException("Packet shorter than minimal header");
216
217 memcpy(&d_header, packet.data(), sizeof(dnsheader));
218
219 if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
220 throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode));
221
222 d_header.qdcount=ntohs(d_header.qdcount);
223 d_header.ancount=ntohs(d_header.ancount);
224 d_header.nscount=ntohs(d_header.nscount);
225 d_header.arcount=ntohs(d_header.arcount);
226
227 if (query && (d_header.qdcount > 1))
228 throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
229
230 unsigned int n=0;
231
232 PacketReader pr(packet);
233 bool validPacket=false;
234 try {
235 d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then
236
237 for(n=0;n < d_header.qdcount; ++n) {
238 d_qname=pr.getName();
239 d_qtype=pr.get16BitInt();
240 d_qclass=pr.get16BitInt();
241 }
242
243 struct dnsrecordheader ah;
244 vector<unsigned char> record;
245 bool seenTSIG = false;
246 validPacket=true;
247 d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
248 for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
249 DNSRecord dr;
250
251 if(n < d_header.ancount)
252 dr.d_place=DNSResourceRecord::ANSWER;
253 else if(n < d_header.ancount + d_header.nscount)
254 dr.d_place=DNSResourceRecord::AUTHORITY;
255 else
256 dr.d_place=DNSResourceRecord::ADDITIONAL;
257
258 unsigned int recordStartPos=pr.getPosition();
259
260 DNSName name=pr.getName();
261
262 pr.getDnsrecordheader(ah);
263 dr.d_ttl=ah.d_ttl;
264 dr.d_type=ah.d_type;
265 dr.d_class=ah.d_class;
266
267 dr.d_name=name;
268 dr.d_clen=ah.d_clen;
269
270 if (query &&
271 !(d_qtype == QType::IXFR && dr.d_place == DNSResourceRecord::AUTHORITY && dr.d_type == QType::SOA) && // IXFR queries have a SOA in their AUTHORITY section
272 (dr.d_place == DNSResourceRecord::ANSWER || dr.d_place == DNSResourceRecord::AUTHORITY || (dr.d_type != QType::OPT && dr.d_type != QType::TSIG && dr.d_type != QType::SIG && dr.d_type != QType::TKEY) || ((dr.d_type == QType::TSIG || dr.d_type == QType::SIG || dr.d_type == QType::TKEY) && dr.d_class != QClass::ANY))) {
273 // cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
274 dr.d_content=std::make_shared<UnknownRecordContent>(dr, pr);
275 }
276 else {
277 // cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
278 dr.d_content=DNSRecordContent::mastermake(dr, pr, d_header.opcode);
279 }
280
281 d_answers.push_back(make_pair(dr, pr.getPosition() - sizeof(dnsheader)));
282
283 /* XXX: XPF records should be allowed after TSIG as soon as the actual XPF option code has been assigned:
284 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG && dr.d_type != QType::XPF)
285 */
286 if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG) {
287 /* only XPF records are allowed after a TSIG */
288 throw MOADNSException("Packet ("+d_qname.toString()+"|#"+std::to_string(d_qtype)+") has an unexpected record ("+std::to_string(dr.d_type)+") after a TSIG one.");
289 }
290
291 if(dr.d_type == QType::TSIG && dr.d_class == QClass::ANY) {
292 if(seenTSIG || dr.d_place != DNSResourceRecord::ADDITIONAL) {
293 throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position.");
294 }
295 seenTSIG = true;
296 d_tsigPos = recordStartPos;
297 }
298 }
299
300 #if 0
301 if(pr.getPosition()!=packet.size()) {
302 throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " +
303 std::to_string(packet.size()) + ")");
304 }
305 #endif
306 }
307 catch(const std::out_of_range &re) {
308 if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
309 if(n < d_header.ancount) {
310 d_header.ancount=n; d_header.nscount = d_header.arcount = 0;
311 }
312 else if(n < d_header.ancount + d_header.nscount) {
313 d_header.nscount = n - d_header.ancount; d_header.arcount=0;
314 }
315 else {
316 d_header.arcount = n - d_header.ancount - d_header.nscount;
317 }
318 }
319 else {
320 throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+
321 std::to_string(d_header.rd)+
322 "), out of bounds: "+string(re.what()));
323 }
324 }
325 }
326
327
328 void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah)
329 {
330 unsigned int n;
331 unsigned char *p=reinterpret_cast<unsigned char*>(&ah);
332
333 for(n=0; n < sizeof(dnsrecordheader); ++n)
334 p[n]=d_content.at(d_pos++);
335
336 ah.d_type=ntohs(ah.d_type);
337 ah.d_class=ntohs(ah.d_class);
338 ah.d_clen=ntohs(ah.d_clen);
339 ah.d_ttl=ntohl(ah.d_ttl);
340
341 d_startrecordpos=d_pos; // needed for getBlob later on
342 d_recordlen=ah.d_clen;
343 }
344
345
346 void PacketReader::copyRecord(vector<unsigned char>& dest, uint16_t len)
347 {
348 dest.resize(len);
349 if(!len)
350 return;
351
352 for(uint16_t n=0;n<len;++n) {
353 dest.at(n)=d_content.at(d_pos++);
354 }
355 }
356
357 void PacketReader::copyRecord(unsigned char* dest, uint16_t len)
358 {
359 if(d_pos + len > d_content.size())
360 throw std::out_of_range("Attempt to copy outside of packet");
361
362 memcpy(dest, &d_content.at(d_pos), len);
363 d_pos+=len;
364 }
365
366 void PacketReader::xfr48BitInt(uint64_t& ret)
367 {
368 ret=0;
369 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
370 ret<<=8;
371 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
372 ret<<=8;
373 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
374 ret<<=8;
375 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
376 ret<<=8;
377 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
378 ret<<=8;
379 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
380 }
381
382 uint32_t PacketReader::get32BitInt()
383 {
384 uint32_t ret=0;
385 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
386 ret<<=8;
387 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
388 ret<<=8;
389 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
390 ret<<=8;
391 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
392
393 return ret;
394 }
395
396
397 uint16_t PacketReader::get16BitInt()
398 {
399 uint16_t ret=0;
400 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
401 ret<<=8;
402 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
403
404 return ret;
405 }
406
407 uint8_t PacketReader::get8BitInt()
408 {
409 return d_content.at(d_pos++);
410 }
411
412 DNSName PacketReader::getName()
413 {
414 unsigned int consumed;
415 try {
416 DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, 0 /* qtype */, 0 /* qclass */, &consumed, sizeof(dnsheader));
417
418 d_pos+=consumed;
419 return dn;
420 }
421 catch(const std::range_error& re) {
422 throw std::out_of_range(string("dnsname issue: ")+re.what());
423 }
424 catch(...) {
425 throw std::out_of_range("dnsname issue");
426 }
427 throw PDNSException("PacketReader::getName(): name is empty");
428 }
429
/*! Escapes a character-string for presentation format:
 *  - bytes below 32 or at/above 127 become "\DDD" (three decimal digits);
 *  - '"' and '\' are prefixed with a backslash;
 *  - everything else is copied verbatim.
 */
static string txtEscape(const string &name)
{
  string escaped;
  for (const char ch : name) {
    const auto byte = static_cast<unsigned char>(ch);
    if (byte < 32 || byte >= 127) {
      escaped += '\\';
      escaped += static_cast<char>('0' + byte / 100);
      escaped += static_cast<char>('0' + (byte / 10) % 10);
      escaped += static_cast<char>('0' + byte % 10);
    }
    else if (ch == '"' || ch == '\\') {
      escaped += '\\';
      escaped += ch;
    }
    else {
      escaped += ch;
    }
  }
  return escaped;
}
449
450 // exceptions thrown here do not result in logging in the main pdns auth server - just so you know!
451 string PacketReader::getText(bool multi, bool lenField)
452 {
453 string ret;
454 ret.reserve(40);
455 while(d_pos < d_startrecordpos + d_recordlen ) {
456 if(!ret.empty()) {
457 ret.append(1,' ');
458 }
459 uint16_t labellen;
460 if(lenField)
461 labellen=static_cast<uint8_t>(d_content.at(d_pos++));
462 else
463 labellen=d_recordlen - (d_pos - d_startrecordpos);
464
465 ret.append(1,'"');
466 if(labellen) { // no need to do anything for an empty string
467 string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
468 ret.append(txtEscape(val)); // the end is one beyond the packet
469 }
470 ret.append(1,'"');
471 d_pos+=labellen;
472 if(!multi)
473 break;
474 }
475
476 return ret;
477 }
478
479 string PacketReader::getUnquotedText(bool lenField)
480 {
481 uint16_t stop_at;
482 if(lenField)
483 stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1;
484 else
485 stop_at = d_recordlen;
486
487 /* think unsigned overflow */
488 if (stop_at < d_pos) {
489 throw std::out_of_range("getUnquotedText out of record range");
490 }
491
492 if(stop_at == d_pos)
493 return "";
494
495 d_pos++;
496 string ret(&d_content.at(d_pos), &d_content.at(stop_at));
497 d_pos = stop_at;
498 return ret;
499 }
500
501 void PacketReader::xfrBlob(string& blob)
502 try
503 {
504 if(d_recordlen && !(d_pos == (d_startrecordpos + d_recordlen))) {
505 if (d_pos > (d_startrecordpos + d_recordlen)) {
506 throw std::out_of_range("xfrBlob out of record range");
507 }
508 blob.assign(&d_content.at(d_pos), &d_content.at(d_startrecordpos + d_recordlen - 1 ) + 1);
509 }
510 else {
511 blob.clear();
512 }
513
514 d_pos = d_startrecordpos + d_recordlen;
515 }
516 catch(...)
517 {
518 throw std::out_of_range("xfrBlob out of range");
519 }
520
521 void PacketReader::xfrBlobNoSpaces(string& blob, int length) {
522 xfrBlob(blob, length);
523 }
524
525 void PacketReader::xfrBlob(string& blob, int length)
526 {
527 if(length) {
528 if (length < 0) {
529 throw std::out_of_range("xfrBlob out of range (negative length)");
530 }
531
532 blob.assign(&d_content.at(d_pos), &d_content.at(d_pos + length - 1 ) + 1 );
533
534 d_pos += length;
535 }
536 else {
537 blob.clear();
538 }
539 }
540
541
542 void PacketReader::xfrHexBlob(string& blob, bool keepReading)
543 {
544 xfrBlob(blob);
545 }
546
547 //FIXME400 remove this method completely
548 string simpleCompress(const string& elabel, const string& root)
549 {
550 string label=elabel;
551 // FIXME400: this relies on the semi-canonical escaped output from getName
552 if(strchr(label.c_str(), '\\')) {
553 boost::replace_all(label, "\\.", ".");
554 boost::replace_all(label, "\\032", " ");
555 boost::replace_all(label, "\\\\", "\\");
556 }
557 typedef vector<pair<unsigned int, unsigned int> > parts_t;
558 parts_t parts;
559 vstringtok(parts, label, ".");
560 string ret;
561 ret.reserve(label.size()+4);
562 for(parts_t::const_iterator i=parts.begin(); i!=parts.end(); ++i) {
563 if(!root.empty() && !strncasecmp(root.c_str(), label.c_str() + i->first, 1 + label.length() - i->first)) { // also match trailing 0, hence '1 +'
564 const unsigned char rootptr[2]={0xc0,0x11};
565 ret.append((const char *) rootptr, 2);
566 return ret;
567 }
568 ret.append(1, (char)(i->second - i->first));
569 ret.append(label.c_str() + i->first, i->second - i->first);
570 }
571 ret.append(1, (char)0);
572 return ret;
573 }
574
575
/** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs
 * If you survive that, feel free to read from the pointer */
/*! Cursor over a raw DNS message that skips, reads and rewrites fields in
 *  place. It starts just past the 12-byte header. Every forward move is
 *  bounds-checked against the packet length and throws std::out_of_range on
 *  overrun. No DNS-level validation is done.
 *
 *  Not copyable: d_offset is a reference to this object's own d_notyouroffset,
 *  so a copy would keep reporting the original's position.
 */
class DNSPacketMangler
{
public:
  explicit DNSPacketMangler(std::string& packet)
    : d_packet(&packet[0]), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset)
  {}
  DNSPacketMangler(char* packet, size_t length)
    : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset)
  {}

  DNSPacketMangler(const DNSPacketMangler&) = delete;
  DNSPacketMangler& operator=(const DNSPacketMangler&) = delete;

  /*! Advances past a wire-format domain name
   * The name is not checked for adherence to length restrictions.
   * Compression pointers are not followed.
   */
  void skipDomainName()
  {
    uint8_t len;
    while((len=get8BitInt())) {
      if(len >= 0xc0) { // compression pointer: one more byte, then the name ends
        get8BitInt();
        return;
      }
      skipBytes(len);
    }
  }

  void skipBytes(uint16_t bytes)
  {
    moveOffset(bytes);
  }
  void rewindBytes(uint16_t by)
  {
    rewindOffset(by);
  }
  //! Reads a big-endian 32-bit value and advances past it.
  uint32_t get32BitInt()
  {
    const char* p = d_packet + d_offset;
    moveOffset(4);
    uint32_t ret;
    memcpy(&ret, p, sizeof(ret));
    return ntohl(ret);
  }
  //! Reads a big-endian 16-bit value and advances past it.
  uint16_t get16BitInt()
  {
    const char* p = d_packet + d_offset;
    moveOffset(2);
    uint16_t ret;
    memcpy(&ret, p, sizeof(ret));
    return ntohs(ret);
  }

  uint8_t get8BitInt()
  {
    const char* p = d_packet + d_offset;
    moveOffset(1);
    return *p;
  }

  //! Reads an RDLENGTH and skips that many bytes of rdata.
  void skipRData()
  {
    int toskip = get16BitInt();
    moveOffset(toskip);
  }

  //! Subtracts \p decrease from the big-endian 32-bit value here, in place, and
  //! advances past it. Unsigned arithmetic: wraps if \p decrease is larger.
  void decreaseAndSkip32BitInt(uint32_t decrease)
  {
    const char *p = d_packet + d_offset;
    moveOffset(4);

    uint32_t tmp;
    memcpy(&tmp, p, sizeof(tmp));
    tmp = ntohl(tmp);
    tmp-=decrease;
    tmp = htonl(tmp);
    memcpy(d_packet + d_offset-4, &tmp, sizeof(tmp));
  }
  //! Overwrites the 32-bit value here with \p value (big-endian) and advances past it.
  void setAndSkip32BitInt(uint32_t value)
  {
    moveOffset(4);

    value = htonl(value);
    memcpy(d_packet + d_offset-4, &value, sizeof(value));
  }
  uint32_t getOffset() const
  {
    return d_offset;
  }
private:
  void moveOffset(uint16_t by)
  {
    d_notyouroffset += by;
    if(d_notyouroffset > d_length)
      throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > "
                              + std::to_string(d_length) );
  }
  void rewindOffset(uint16_t by)
  {
    if(d_notyouroffset < by)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(by));
    d_notyouroffset -= by;
    if(d_notyouroffset < 12)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(12));
  }
  char* d_packet;
  size_t d_length;

  uint32_t d_notyouroffset;  // only 'moveOffset' can touch this
  const uint32_t& d_offset;  // look.. but don't touch

};
690
691 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
692 void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)> visitor)
693 {
694 if(length < sizeof(dnsheader))
695 return;
696 try
697 {
698 dnsheader dh;
699 memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
700 uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
701 DNSPacketMangler dpm(packet, length);
702
703 uint64_t n;
704 for(n=0; n < ntohs(dh.qdcount) ; ++n) {
705 dpm.skipDomainName();
706 /* type and class */
707 dpm.skipBytes(4);
708 }
709
710 for(n=0; n < numrecords; ++n) {
711 dpm.skipDomainName();
712
713 uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
714 uint16_t dnstype = dpm.get16BitInt();
715 uint16_t dnsclass = dpm.get16BitInt();
716
717 if(dnstype == QType::OPT) // not getting near that one with a stick
718 break;
719
720 uint32_t dnsttl = dpm.get32BitInt();
721 uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl);
722 if (newttl) {
723 dpm.rewindBytes(sizeof(newttl));
724 dpm.setAndSkip32BitInt(newttl);
725 }
726 dpm.skipRData();
727 }
728 }
729 catch(...)
730 {
731 return;
732 }
733 }
734
735 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
736 void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
737 {
738 if(length < sizeof(dnsheader))
739 return;
740 try
741 {
742 const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
743 const uint64_t dqcount = ntohs(dh->qdcount);
744 const uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
745 DNSPacketMangler dpm(packet, length);
746
747 uint64_t n;
748 for(n=0; n < dqcount; ++n) {
749 dpm.skipDomainName();
750 /* type and class */
751 dpm.skipBytes(4);
752 }
753 // cerr<<"Skipped "<<n<<" questions, now parsing "<<numrecords<<" records"<<endl;
754 for(n=0; n < numrecords; ++n) {
755 dpm.skipDomainName();
756
757 uint16_t dnstype = dpm.get16BitInt();
758 /* class */
759 dpm.skipBytes(2);
760
761 if(dnstype == QType::OPT) // not aging that one with a stick
762 break;
763
764 dpm.decreaseAndSkip32BitInt(seconds);
765 dpm.skipRData();
766 }
767 }
768 catch(...)
769 {
770 return;
771 }
772 }
773
774 void ageDNSPacket(std::string& packet, uint32_t seconds)
775 {
776 ageDNSPacket((char*)packet.c_str(), packet.length(), seconds);
777 }
778
779 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA)
780 {
781 uint32_t result = std::numeric_limits<uint32_t>::max();
782 if(length < sizeof(dnsheader)) {
783 return result;
784 }
785 try
786 {
787 const dnsheader* dh = (const dnsheader*) packet;
788 DNSPacketMangler dpm(const_cast<char*>(packet), length);
789
790 const uint16_t qdcount = ntohs(dh->qdcount);
791 for(size_t n = 0; n < qdcount; ++n) {
792 dpm.skipDomainName();
793 /* type and class */
794 dpm.skipBytes(4);
795 }
796 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
797 for(size_t n = 0; n < numrecords; ++n) {
798 dpm.skipDomainName();
799 const uint16_t dnstype = dpm.get16BitInt();
800 /* class */
801 const uint16_t dnsclass = dpm.get16BitInt();
802
803 if(dnstype == QType::OPT) {
804 break;
805 }
806
807 /* report it if we see a SOA record in the AUTHORITY section */
808 if(dnstype == QType::SOA && dnsclass == QClass::IN && seenAuthSOA != nullptr && n >= ntohs(dh->ancount) && n < (ntohs(dh->ancount) + ntohs(dh->nscount))) {
809 *seenAuthSOA = true;
810 }
811
812 const uint32_t ttl = dpm.get32BitInt();
813 if (result > ttl) {
814 result = ttl;
815 }
816
817 dpm.skipRData();
818 }
819 }
820 catch(...)
821 {
822 }
823 return result;
824 }
825
826 uint32_t getDNSPacketLength(const char* packet, size_t length)
827 {
828 uint32_t result = length;
829 if(length < sizeof(dnsheader)) {
830 return result;
831 }
832 try
833 {
834 const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
835 DNSPacketMangler dpm(const_cast<char*>(packet), length);
836
837 const uint16_t qdcount = ntohs(dh->qdcount);
838 for(size_t n = 0; n < qdcount; ++n) {
839 dpm.skipDomainName();
840 /* type and class */
841 dpm.skipBytes(4);
842 }
843 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
844 for(size_t n = 0; n < numrecords; ++n) {
845 dpm.skipDomainName();
846 /* type (2), class (2) and ttl (4) */
847 dpm.skipBytes(8);
848 dpm.skipRData();
849 }
850 result = dpm.getOffset();
851 }
852 catch(...)
853 {
854 }
855 return result;
856 }
857
858 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type)
859 {
860 uint16_t result = 0;
861 if(length < sizeof(dnsheader)) {
862 return result;
863 }
864 try
865 {
866 const dnsheader* dh = (const dnsheader*) packet;
867 DNSPacketMangler dpm(const_cast<char*>(packet), length);
868
869 const uint16_t qdcount = ntohs(dh->qdcount);
870 for(size_t n = 0; n < qdcount; ++n) {
871 dpm.skipDomainName();
872 if (section == 0) {
873 uint16_t dnstype = dpm.get16BitInt();
874 if (dnstype == type) {
875 result++;
876 }
877 /* class */
878 dpm.skipBytes(2);
879 } else {
880 /* type and class */
881 dpm.skipBytes(4);
882 }
883 }
884 const uint16_t ancount = ntohs(dh->ancount);
885 for(size_t n = 0; n < ancount; ++n) {
886 dpm.skipDomainName();
887 if (section == 1) {
888 uint16_t dnstype = dpm.get16BitInt();
889 if (dnstype == type) {
890 result++;
891 }
892 /* class */
893 dpm.skipBytes(2);
894 } else {
895 /* type and class */
896 dpm.skipBytes(4);
897 }
898 /* ttl */
899 dpm.skipBytes(4);
900 dpm.skipRData();
901 }
902 const uint16_t nscount = ntohs(dh->nscount);
903 for(size_t n = 0; n < nscount; ++n) {
904 dpm.skipDomainName();
905 if (section == 2) {
906 uint16_t dnstype = dpm.get16BitInt();
907 if (dnstype == type) {
908 result++;
909 }
910 /* class */
911 dpm.skipBytes(2);
912 } else {
913 /* type and class */
914 dpm.skipBytes(4);
915 }
916 /* ttl */
917 dpm.skipBytes(4);
918 dpm.skipRData();
919 }
920 const uint16_t arcount = ntohs(dh->arcount);
921 for(size_t n = 0; n < arcount; ++n) {
922 dpm.skipDomainName();
923 if (section == 3) {
924 uint16_t dnstype = dpm.get16BitInt();
925 if (dnstype == type) {
926 result++;
927 }
928 /* class */
929 dpm.skipBytes(2);
930 } else {
931 /* type and class */
932 dpm.skipBytes(4);
933 }
934 /* ttl */
935 dpm.skipBytes(4);
936 dpm.skipRData();
937 }
938 }
939 catch(...)
940 {
941 }
942 return result;
943 }
944
945 bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z)
946 {
947 if (length < sizeof(dnsheader)) {
948 return false;
949 }
950
951 *payloadSize = 0;
952 *z = 0;
953
954 try
955 {
956 const dnsheader* dh = (const dnsheader*) packet;
957 DNSPacketMangler dpm(const_cast<char*>(packet), length);
958
959 const uint16_t qdcount = ntohs(dh->qdcount);
960 for(size_t n = 0; n < qdcount; ++n) {
961 dpm.skipDomainName();
962 /* type and class */
963 dpm.skipBytes(4);
964 }
965 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
966 for(size_t n = 0; n < numrecords; ++n) {
967 dpm.skipDomainName();
968 const uint16_t dnstype = dpm.get16BitInt();
969 const uint16_t dnsclass = dpm.get16BitInt();
970
971 if(dnstype == QType::OPT) {
972 /* skip extended rcode and version */
973 dpm.skipBytes(2);
974 *z = dpm.get16BitInt();
975 *payloadSize = dnsclass;
976 return true;
977 }
978
979 /* TTL */
980 dpm.skipBytes(4);
981 dpm.skipRData();
982 }
983 }
984 catch(...)
985 {
986 }
987
988 return false;
989 }