]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsparser.cc
dnsdist: Handle trailing data correctly when adding OPT or ECS info
[thirdparty/pdns.git] / pdns / dnsparser.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsparser.hh"
23 #include "dnswriter.hh"
24 #include <boost/algorithm/string.hpp>
25 #include <boost/format.hpp>
26
27 #include "namespaces.hh"
28
29 class UnknownRecordContent : public DNSRecordContent
30 {
31 public:
32 UnknownRecordContent(const DNSRecord& dr, PacketReader& pr)
33 : d_dr(dr)
34 {
35 pr.copyRecord(d_record, dr.d_clen);
36 }
37
38 UnknownRecordContent(const string& zone)
39 {
40 // parse the input
41 vector<string> parts;
42 stringtok(parts, zone);
43 if(parts.size()!=3 && !(parts.size()==2 && equals(parts[1],"0")) )
44 throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got "+std::to_string(parts.size())+": "+zone );
45 const string& relevant=(parts.size() > 2) ? parts[2] : "";
46 unsigned int total=pdns_stou(parts[1]);
47 if(relevant.size() % 2 || relevant.size() / 2 != total)
48 throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant.size() % total).str());
49 string out;
50 out.reserve(total+1);
51 for(unsigned int n=0; n < total; ++n) {
52 int c;
53 sscanf(relevant.c_str()+2*n, "%02x", &c);
54 out.append(1, (char)c);
55 }
56
57 d_record.insert(d_record.end(), out.begin(), out.end());
58 }
59
60 string getZoneRepresentation(bool noDot) const override
61 {
62 ostringstream str;
63 str<<"\\# "<<(unsigned int)d_record.size()<<" ";
64 char hex[4];
65 for(size_t n=0; n<d_record.size(); ++n) {
66 snprintf(hex,sizeof(hex)-1, "%02x", d_record.at(n));
67 str << hex;
68 }
69 return str.str();
70 }
71
72 void toPacket(DNSPacketWriter& pw) override
73 {
74 pw.xfrBlob(string(d_record.begin(),d_record.end()));
75 }
76
77 uint16_t getType() const override
78 {
79 return d_dr.d_type;
80 }
81 private:
82 DNSRecord d_dr;
83 vector<uint8_t> d_record;
84 };
85
86 shared_ptr<DNSRecordContent> DNSRecordContent::unserialize(const DNSName& qname, uint16_t qtype, const string& serialized)
87 {
88 dnsheader dnsheader;
89 memset(&dnsheader, 0, sizeof(dnsheader));
90 dnsheader.qdcount=htons(1);
91 dnsheader.ancount=htons(1);
92
93 vector<uint8_t> packet; // build pseudo packet
94
95 /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */
96
97 string encoded=qname.toDNSString();
98
99 packet.resize(sizeof(dnsheader) + 5 + encoded.size() + sizeof(struct dnsrecordheader) + serialized.size());
100
101 uint16_t pos=0;
102
103 memcpy(&packet[0], &dnsheader, sizeof(dnsheader)); pos+=sizeof(dnsheader);
104
105 char tmp[6]="\x0" "\x0\x1" "\x0\x1"; // root question for ns_t_a
106 memcpy(&packet[pos], &tmp, 5); pos+=5;
107
108 memcpy(&packet[pos], encoded.c_str(), encoded.size()); pos+=(uint16_t)encoded.size();
109
110 struct dnsrecordheader drh;
111 drh.d_type=htons(qtype);
112 drh.d_class=htons(QClass::IN);
113 drh.d_ttl=0;
114 drh.d_clen=htons(serialized.size());
115
116 memcpy(&packet[pos], &drh, sizeof(drh)); pos+=sizeof(drh);
117 memcpy(&packet[pos], serialized.c_str(), serialized.size()); pos+=(uint16_t)serialized.size();
118
119 MOADNSParser mdp(false, (char*)&*packet.begin(), (unsigned int)packet.size());
120 shared_ptr<DNSRecordContent> ret= mdp.d_answers.begin()->first.d_content;
121 return ret;
122 }
123
124 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr,
125 PacketReader& pr)
126 {
127 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
128
129 typemap_t::const_iterator i=getTypemap().find(make_pair(searchclass, dr.d_type));
130 if(i==getTypemap().end() || !i->second) {
131 return std::make_shared<UnknownRecordContent>(dr, pr);
132 }
133
134 return i->second(dr, pr);
135 }
136
137 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(uint16_t qtype, uint16_t qclass,
138 const string& content)
139 {
140 zmakermap_t::const_iterator i=getZmakermap().find(make_pair(qclass, qtype));
141 if(i==getZmakermap().end()) {
142 return std::make_shared<UnknownRecordContent>(content);
143 }
144
145 return i->second(content);
146 }
147
148 std::shared_ptr<DNSRecordContent> DNSRecordContent::mastermake(const DNSRecord &dr, PacketReader& pr, uint16_t oc) {
149 // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is
150 // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent.
151 // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses.
152 if (oc == Opcode::Update && dr.d_place == DNSResourceRecord::ANSWER && dr.d_class != 1)
153 return std::make_shared<UnknownRecordContent>(dr, pr);
154
155 uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
156
157 typemap_t::const_iterator i=getTypemap().find(make_pair(searchclass, dr.d_type));
158 if(i==getTypemap().end() || !i->second) {
159 return std::make_shared<UnknownRecordContent>(dr, pr);
160 }
161
162 return i->second(dr, pr);
163 }
164
165
166 DNSRecordContent::typemap_t& DNSRecordContent::getTypemap()
167 {
168 static DNSRecordContent::typemap_t typemap;
169 return typemap;
170 }
171
172 DNSRecordContent::n2typemap_t& DNSRecordContent::getN2Typemap()
173 {
174 static DNSRecordContent::n2typemap_t n2typemap;
175 return n2typemap;
176 }
177
178 DNSRecordContent::t2namemap_t& DNSRecordContent::getT2Namemap()
179 {
180 static DNSRecordContent::t2namemap_t t2namemap;
181 return t2namemap;
182 }
183
184 DNSRecordContent::zmakermap_t& DNSRecordContent::getZmakermap()
185 {
186 static DNSRecordContent::zmakermap_t zmakermap;
187 return zmakermap;
188 }
189
190 DNSRecord::DNSRecord(const DNSResourceRecord& rr): d_name(rr.qname)
191 {
192 d_type = rr.qtype.getCode();
193 d_ttl = rr.ttl;
194 d_class = rr.qclass;
195 d_place = DNSResourceRecord::ANSWER;
196 d_clen = 0;
197 d_content = DNSRecordContent::mastermake(d_type, rr.qclass, rr.content);
198 }
199
200 // If you call this and you are not parsing a packet coming from a socket, you are doing it wrong.
201 DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& d) {
202 DNSResourceRecord rr;
203 rr.qname = d.d_name;
204 rr.qtype = QType(d.d_type);
205 rr.ttl = d.d_ttl;
206 rr.content = d.d_content->getZoneRepresentation(true);
207 rr.auth = false;
208 rr.qclass = d.d_class;
209 return rr;
210 }
211
/* Parses a complete DNS message into d_header (counts converted to host order), the question
 * (d_qname/d_qtype/d_qclass, last one wins if several) and d_answers.
 * When `query` is true, most records are stored as opaque UnknownRecordContent instead of being
 * parsed. Only these are parsed: OPT in ADDITIONAL, TSIG/SIG/TKEY of class ANY in ADDITIONAL,
 * and the AUTHORITY SOA of an IXFR query.
 * Throws MOADNSException for short packets, unsupported opcodes, queries with QD > 1, misplaced
 * TSIG records and out-of-bounds reads. The exception is a packet with the TC bit set whose question
 * section was parsed: there an out-of-bounds read is tolerated and the section counts are trimmed
 * to the records actually parsed. */
void MOADNSParser::init(bool query, const std::string& packet)
{
  if (packet.size() < sizeof(dnsheader))
    throw MOADNSException("Packet shorter than minimal header");

  memcpy(&d_header, packet.data(), sizeof(dnsheader));

  if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
    throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode));

  // from here on the section counts in d_header are in host byte order
  d_header.qdcount=ntohs(d_header.qdcount);
  d_header.ancount=ntohs(d_header.ancount);
  d_header.nscount=ntohs(d_header.nscount);
  d_header.arcount=ntohs(d_header.arcount);

  if (query && (d_header.qdcount > 1))
    throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");

  // declared outside the try so the truncation handler knows how many records were parsed
  unsigned int n=0;

  PacketReader pr(packet);
  // set once the question section is parsed; only then is a truncated (TC) packet tolerated
  bool validPacket=false;
  try {
    d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then

    for(n=0;n < d_header.qdcount; ++n) {
      d_qname=pr.getName();
      d_qtype=pr.get16BitInt();
      d_qclass=pr.get16BitInt();
    }

    struct dnsrecordheader ah;
    vector<unsigned char> record; // NOTE(review): not used in this function
    bool seenTSIG = false;
    validPacket=true;
    d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
    for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
      DNSRecord dr;

      // records are numbered across the three sections: answer, then authority, then additional
      if(n < d_header.ancount)
        dr.d_place=DNSResourceRecord::ANSWER;
      else if(n < d_header.ancount + d_header.nscount)
        dr.d_place=DNSResourceRecord::AUTHORITY;
      else
        dr.d_place=DNSResourceRecord::ADDITIONAL;

      // remembered so that a TSIG record's start offset can be stored in d_tsigPos
      unsigned int recordStartPos=pr.getPosition();

      DNSName name=pr.getName();

      pr.getDnsrecordheader(ah);
      dr.d_ttl=ah.d_ttl;
      dr.d_type=ah.d_type;
      dr.d_class=ah.d_class;

      dr.d_name=name;
      dr.d_clen=ah.d_clen;

      // In queries, keep everything opaque except the few record kinds listed in the header comment.
      if (query &&
          !(d_qtype == QType::IXFR && dr.d_place == DNSResourceRecord::AUTHORITY && dr.d_type == QType::SOA) && // IXFR queries have a SOA in their AUTHORITY section
          (dr.d_place == DNSResourceRecord::ANSWER || dr.d_place == DNSResourceRecord::AUTHORITY || (dr.d_type != QType::OPT && dr.d_type != QType::TSIG && dr.d_type != QType::SIG && dr.d_type != QType::TKEY) || ((dr.d_type == QType::TSIG || dr.d_type == QType::SIG || dr.d_type == QType::TKEY) && dr.d_class != QClass::ANY))) {
        // cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
        dr.d_content=std::make_shared<UnknownRecordContent>(dr, pr);
      }
      else {
        // cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
        dr.d_content=DNSRecordContent::mastermake(dr, pr, d_header.opcode);
      }

      // second member: reader position just past this record, relative to the end of the header
      d_answers.push_back(make_pair(dr, pr.getPosition() - sizeof(dnsheader)));

      /* XXX: XPF records should be allowed after TSIG as soon as the actual XPF option code has been assigned:
         if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG && dr.d_type != QType::XPF)
      */
      // NOTE: as written, ANY additional record following a TSIG is rejected (XPF included).
      if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG) {
        /* only XPF records are allowed after a TSIG */
        throw MOADNSException("Packet ("+d_qname.toString()+"|#"+std::to_string(d_qtype)+") has an unexpected record ("+std::to_string(dr.d_type)+") after a TSIG one.");
      }

      // a TSIG (class ANY) must appear at most once and only in the additional section
      if(dr.d_type == QType::TSIG && dr.d_class == QClass::ANY) {
        if(seenTSIG || dr.d_place != DNSResourceRecord::ADDITIONAL) {
          throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position.");
        }
        seenTSIG = true;
        d_tsigPos = recordStartPos;
      }
    }

#if 0
    if(pr.getPosition()!=packet.size()) {
      throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " +
                            std::to_string(packet.size()) + ")");
    }
#endif
  }
  catch(const std::out_of_range &re) {
    if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
      // n is the index of the record that could not be read; keep only the records before it
      if(n < d_header.ancount) {
        d_header.ancount=n; d_header.nscount = d_header.arcount = 0;
      }
      else if(n < d_header.ancount + d_header.nscount) {
        d_header.nscount = n - d_header.ancount; d_header.arcount=0;
      }
      else {
        d_header.arcount = n - d_header.ancount - d_header.nscount;
      }
    }
    else {
      throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+
                            std::to_string(d_header.rd)+
                            "), out of bounds: "+string(re.what()));
    }
  }
}
326
327
328 void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah)
329 {
330 unsigned int n;
331 unsigned char *p=reinterpret_cast<unsigned char*>(&ah);
332
333 for(n=0; n < sizeof(dnsrecordheader); ++n)
334 p[n]=d_content.at(d_pos++);
335
336 ah.d_type=ntohs(ah.d_type);
337 ah.d_class=ntohs(ah.d_class);
338 ah.d_clen=ntohs(ah.d_clen);
339 ah.d_ttl=ntohl(ah.d_ttl);
340
341 d_startrecordpos=d_pos; // needed for getBlob later on
342 d_recordlen=ah.d_clen;
343 }
344
345
346 void PacketReader::copyRecord(vector<unsigned char>& dest, uint16_t len)
347 {
348 dest.resize(len);
349 if(!len)
350 return;
351
352 for(uint16_t n=0;n<len;++n) {
353 dest.at(n)=d_content.at(d_pos++);
354 }
355 }
356
357 void PacketReader::copyRecord(unsigned char* dest, uint16_t len)
358 {
359 if(d_pos + len > d_content.size())
360 throw std::out_of_range("Attempt to copy outside of packet");
361
362 memcpy(dest, &d_content.at(d_pos), len);
363 d_pos+=len;
364 }
365
366 void PacketReader::xfr48BitInt(uint64_t& ret)
367 {
368 ret=0;
369 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
370 ret<<=8;
371 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
372 ret<<=8;
373 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
374 ret<<=8;
375 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
376 ret<<=8;
377 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
378 ret<<=8;
379 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
380 }
381
382 uint32_t PacketReader::get32BitInt()
383 {
384 uint32_t ret=0;
385 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
386 ret<<=8;
387 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
388 ret<<=8;
389 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
390 ret<<=8;
391 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
392
393 return ret;
394 }
395
396
397 uint16_t PacketReader::get16BitInt()
398 {
399 uint16_t ret=0;
400 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
401 ret<<=8;
402 ret+=static_cast<uint8_t>(d_content.at(d_pos++));
403
404 return ret;
405 }
406
407 uint8_t PacketReader::get8BitInt()
408 {
409 return d_content.at(d_pos++);
410 }
411
412 DNSName PacketReader::getName()
413 {
414 unsigned int consumed;
415 try {
416 DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, 0 /* qtype */, 0 /* qclass */, &consumed, sizeof(dnsheader));
417
418 d_pos+=consumed;
419 return dn;
420 }
421 catch(const std::range_error& re) {
422 throw std::out_of_range(string("dnsname issue: ")+re.what());
423 }
424 catch(...) {
425 throw std::out_of_range("dnsname issue");
426 }
427 throw PDNSException("PacketReader::getName(): name is empty");
428 }
429
/* Escapes a character-string for zone-file output: bytes outside the printable ASCII range
 * (< 32 or >= 127) become a backslash plus three decimal digits, and '"' and '\' are
 * backslash-escaped. Every other byte is copied unchanged. */
static string txtEscape(const string &name)
{
  string escaped;
  for (const char ch : name) {
    const auto byte = static_cast<unsigned char>(ch);
    if (byte < 32 || byte >= 127) {
      char decimal[5];
      snprintf(decimal, sizeof(decimal), "\\%03u", static_cast<unsigned int>(byte));
      escaped += decimal;
      continue;
    }
    if (ch == '"' || ch == '\\') {
      escaped += '\\';
    }
    escaped += ch;
  }
  return escaped;
}
449
450 // exceptions thrown here do not result in logging in the main pdns auth server - just so you know!
451 string PacketReader::getText(bool multi, bool lenField)
452 {
453 string ret;
454 ret.reserve(40);
455 while(d_pos < d_startrecordpos + d_recordlen ) {
456 if(!ret.empty()) {
457 ret.append(1,' ');
458 }
459 uint16_t labellen;
460 if(lenField)
461 labellen=static_cast<uint8_t>(d_content.at(d_pos++));
462 else
463 labellen=d_recordlen - (d_pos - d_startrecordpos);
464
465 ret.append(1,'"');
466 if(labellen) { // no need to do anything for an empty string
467 string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
468 ret.append(txtEscape(val)); // the end is one beyond the packet
469 }
470 ret.append(1,'"');
471 d_pos+=labellen;
472 if(!multi)
473 break;
474 }
475
476 return ret;
477 }
478
479 string PacketReader::getUnquotedText(bool lenField)
480 {
481 uint16_t stop_at;
482 if(lenField)
483 stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1;
484 else
485 stop_at = d_recordlen;
486
487 if(stop_at == d_pos)
488 return "";
489
490 d_pos++;
491 string ret(&d_content.at(d_pos), &d_content.at(stop_at));
492 d_pos = stop_at;
493 return ret;
494 }
495
496 void PacketReader::xfrBlob(string& blob)
497 try
498 {
499 if(d_recordlen && !(d_pos == (d_startrecordpos + d_recordlen)))
500 blob.assign(&d_content.at(d_pos), &d_content.at(d_startrecordpos + d_recordlen - 1 ) + 1);
501 else
502 blob.clear();
503
504 d_pos = d_startrecordpos + d_recordlen;
505 }
506 catch(...)
507 {
508 throw std::out_of_range("xfrBlob out of range");
509 }
510
511 void PacketReader::xfrBlobNoSpaces(string& blob, int length) {
512 xfrBlob(blob, length);
513 }
514
515 void PacketReader::xfrBlob(string& blob, int length)
516 {
517 if(length) {
518 blob.assign(&d_content.at(d_pos), &d_content.at(d_pos + length - 1 ) + 1 );
519
520 d_pos += length;
521 }
522 else
523 blob.clear();
524 }
525
526
527 void PacketReader::xfrHexBlob(string& blob, bool keepReading)
528 {
529 xfrBlob(blob);
530 }
531
532 //FIXME400 remove this method completely
533 string simpleCompress(const string& elabel, const string& root)
534 {
535 string label=elabel;
536 // FIXME400: this relies on the semi-canonical escaped output from getName
537 if(strchr(label.c_str(), '\\')) {
538 boost::replace_all(label, "\\.", ".");
539 boost::replace_all(label, "\\032", " ");
540 boost::replace_all(label, "\\\\", "\\");
541 }
542 typedef vector<pair<unsigned int, unsigned int> > parts_t;
543 parts_t parts;
544 vstringtok(parts, label, ".");
545 string ret;
546 ret.reserve(label.size()+4);
547 for(parts_t::const_iterator i=parts.begin(); i!=parts.end(); ++i) {
548 if(!root.empty() && !strncasecmp(root.c_str(), label.c_str() + i->first, 1 + label.length() - i->first)) { // also match trailing 0, hence '1 +'
549 const unsigned char rootptr[2]={0xc0,0x11};
550 ret.append((const char *) rootptr, 2);
551 return ret;
552 }
553 ret.append(1, (char)(i->second - i->first));
554 ret.append(label.c_str() + i->first, i->second - i->first);
555 }
556 ret.append(1, (char)0);
557 return ret;
558 }
559
560
/** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs
 * If you survive that, feel free to read from the pointer.
 *
 * Walks a raw DNS message starting right after the 12-byte header. Every accessor first advances
 * the cursor, which throws std::out_of_range if that would pass the end of the buffer. Only then
 * does it read or write the bytes it just stepped over, so no access runs past the buffer. */
class DNSPacketMangler
{
public:
  /* The mangler may write into the packet (TTL edits), so it needs a mutable pointer into the
   * string's buffer. Writing through c_str() is undefined behaviour; &packet[0] is not. */
  explicit DNSPacketMangler(std::string& packet)
    : d_packet(&packet[0]), d_length(packet.length()), d_notyouroffset(12)
  {}
  DNSPacketMangler(char* packet, size_t length)
    : d_packet(packet), d_length(length), d_notyouroffset(12)
  {}

  /* Skips an encoded name: length-prefixed labels ended either by a zero byte or by a two-byte
   * compression pointer (first byte >= 0xc0). */
  void skipLabel()
  {
    uint8_t len;
    while((len=get8BitInt())) {
      if(len >= 0xc0) { // extended label
        get8BitInt();
        return;
      }
      skipBytes(len);
    }
  }
  void skipBytes(uint16_t bytes)
  {
    moveOffset(bytes);
  }
  void rewindBytes(uint16_t by)
  {
    rewindOffset(by);
  }
  /* Reads a 32-bit big-endian value. */
  uint32_t get32BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(4);
    uint32_t ret;
    memcpy(&ret, p, sizeof(ret));
    return ntohl(ret);
  }
  /* Reads a 16-bit big-endian value. */
  uint16_t get16BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(2);
    uint16_t ret;
    memcpy(&ret, p, sizeof(ret));
    return ntohs(ret);
  }

  uint8_t get8BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(1);
    return static_cast<uint8_t>(*p);
  }

  /* Skips a record's rdata, whose 16-bit length is read at the current position. */
  void skipRData()
  {
    int toskip = get16BitInt();
    moveOffset(toskip);
  }

  /* Lowers the 32-bit big-endian value at the cursor by `decrease` in place and skips it.
   * The result saturates at 0 so that an aged TTL cannot wrap around to a huge value. */
  void decreaseAndSkip32BitInt(uint32_t decrease)
  {
    const char *p = d_packet + getOffset();
    moveOffset(4);

    uint32_t tmp;
    memcpy(&tmp, p, sizeof(tmp));
    tmp = ntohl(tmp);
    tmp = (tmp > decrease) ? (tmp - decrease) : 0;
    tmp = htonl(tmp);
    memcpy(d_packet + getOffset() - 4, &tmp, sizeof(tmp));
  }
  /* Overwrites the 32-bit value at the cursor with `value` (stored big-endian) and skips it. */
  void setAndSkip32BitInt(uint32_t value)
  {
    moveOffset(4);

    value = htonl(value);
    memcpy(d_packet + getOffset() - 4, &value, sizeof(value));
  }
  uint32_t getOffset() const
  {
    return d_notyouroffset;
  }
private:
  void moveOffset(uint16_t by)
  {
    d_notyouroffset += by;
    if(d_notyouroffset > d_length)
      throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > "
                              + std::to_string(d_length) );
  }
  void rewindOffset(uint16_t by)
  {
    if(d_notyouroffset < by)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(by));
    d_notyouroffset -= by;
    if(d_notyouroffset < 12)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(12));
  }
  char* d_packet;
  size_t d_length;

  // Cursor; only moveOffset() and rewindOffset() change it. It is a plain value (not aliased by a
  // reference member), so a copied mangler gets its own independent cursor.
  uint32_t d_notyouroffset;
};
670
671 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
672 void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)> visitor)
673 {
674 if(length < sizeof(dnsheader))
675 return;
676 try
677 {
678 dnsheader dh;
679 memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
680 uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
681 DNSPacketMangler dpm(packet, length);
682
683 uint64_t n;
684 for(n=0; n < ntohs(dh.qdcount) ; ++n) {
685 dpm.skipLabel();
686 /* type and class */
687 dpm.skipBytes(4);
688 }
689
690 for(n=0; n < numrecords; ++n) {
691 dpm.skipLabel();
692
693 uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
694 uint16_t dnstype = dpm.get16BitInt();
695 uint16_t dnsclass = dpm.get16BitInt();
696
697 if(dnstype == QType::OPT) // not getting near that one with a stick
698 break;
699
700 uint32_t dnsttl = dpm.get32BitInt();
701 uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl);
702 if (newttl) {
703 dpm.rewindBytes(sizeof(newttl));
704 dpm.setAndSkip32BitInt(newttl);
705 }
706 dpm.skipRData();
707 }
708 }
709 catch(...)
710 {
711 return;
712 }
713 }
714
715 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
716 void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
717 {
718 if(length < sizeof(dnsheader))
719 return;
720 try
721 {
722 const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
723 const uint64_t dqcount = ntohs(dh->qdcount);
724 const uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
725 DNSPacketMangler dpm(packet, length);
726
727 uint64_t n;
728 for(n=0; n < dqcount; ++n) {
729 dpm.skipLabel();
730 /* type and class */
731 dpm.skipBytes(4);
732 }
733 // cerr<<"Skipped "<<n<<" questions, now parsing "<<numrecords<<" records"<<endl;
734 for(n=0; n < numrecords; ++n) {
735 dpm.skipLabel();
736
737 uint16_t dnstype = dpm.get16BitInt();
738 /* class */
739 dpm.skipBytes(2);
740
741 if(dnstype == QType::OPT) // not aging that one with a stick
742 break;
743
744 dpm.decreaseAndSkip32BitInt(seconds);
745 dpm.skipRData();
746 }
747 }
748 catch(...)
749 {
750 return;
751 }
752 }
753
754 void ageDNSPacket(std::string& packet, uint32_t seconds)
755 {
756 ageDNSPacket((char*)packet.c_str(), packet.length(), seconds);
757 }
758
759 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA)
760 {
761 uint32_t result = std::numeric_limits<uint32_t>::max();
762 if(length < sizeof(dnsheader)) {
763 return result;
764 }
765 try
766 {
767 const dnsheader* dh = (const dnsheader*) packet;
768 DNSPacketMangler dpm(const_cast<char*>(packet), length);
769
770 const uint16_t qdcount = ntohs(dh->qdcount);
771 for(size_t n = 0; n < qdcount; ++n) {
772 dpm.skipLabel();
773 /* type and class */
774 dpm.skipBytes(4);
775 }
776 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
777 for(size_t n = 0; n < numrecords; ++n) {
778 dpm.skipLabel();
779 const uint16_t dnstype = dpm.get16BitInt();
780 /* class */
781 const uint16_t dnsclass = dpm.get16BitInt();
782
783 if(dnstype == QType::OPT) {
784 break;
785 }
786
787 /* report it if we see a SOA record in the AUTHORITY section */
788 if(dnstype == QType::SOA && dnsclass == QClass::IN && seenAuthSOA != nullptr && n >= ntohs(dh->ancount) && n < (ntohs(dh->ancount) + ntohs(dh->nscount))) {
789 *seenAuthSOA = true;
790 }
791
792 const uint32_t ttl = dpm.get32BitInt();
793 if (result > ttl) {
794 result = ttl;
795 }
796
797 dpm.skipRData();
798 }
799 }
800 catch(...)
801 {
802 }
803 return result;
804 }
805
806 uint32_t getDNSPacketLength(const char* packet, size_t length)
807 {
808 uint32_t result = length;
809 if(length < sizeof(dnsheader)) {
810 return result;
811 }
812 try
813 {
814 const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
815 DNSPacketMangler dpm(const_cast<char*>(packet), length);
816
817 const uint16_t qdcount = ntohs(dh->qdcount);
818 for(size_t n = 0; n < qdcount; ++n) {
819 dpm.skipLabel();
820 /* type and class */
821 dpm.skipBytes(4);
822 }
823 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
824 for(size_t n = 0; n < numrecords; ++n) {
825 dpm.skipLabel();
826 /* type (2), class (2) and ttl (4) */
827 dpm.skipBytes(8);
828 dpm.skipRData();
829 }
830 result = dpm.getOffset();
831 }
832 catch(...)
833 {
834 }
835 return result;
836 }
837
838 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type)
839 {
840 uint16_t result = 0;
841 if(length < sizeof(dnsheader)) {
842 return result;
843 }
844 try
845 {
846 const dnsheader* dh = (const dnsheader*) packet;
847 DNSPacketMangler dpm(const_cast<char*>(packet), length);
848
849 const uint16_t qdcount = ntohs(dh->qdcount);
850 for(size_t n = 0; n < qdcount; ++n) {
851 dpm.skipLabel();
852 if (section == 0) {
853 uint16_t dnstype = dpm.get16BitInt();
854 if (dnstype == type) {
855 result++;
856 }
857 /* class */
858 dpm.skipBytes(2);
859 } else {
860 /* type and class */
861 dpm.skipBytes(4);
862 }
863 }
864 const uint16_t ancount = ntohs(dh->ancount);
865 for(size_t n = 0; n < ancount; ++n) {
866 dpm.skipLabel();
867 if (section == 1) {
868 uint16_t dnstype = dpm.get16BitInt();
869 if (dnstype == type) {
870 result++;
871 }
872 /* class */
873 dpm.skipBytes(2);
874 } else {
875 /* type and class */
876 dpm.skipBytes(4);
877 }
878 /* ttl */
879 dpm.skipBytes(4);
880 dpm.skipRData();
881 }
882 const uint16_t nscount = ntohs(dh->nscount);
883 for(size_t n = 0; n < nscount; ++n) {
884 dpm.skipLabel();
885 if (section == 2) {
886 uint16_t dnstype = dpm.get16BitInt();
887 if (dnstype == type) {
888 result++;
889 }
890 /* class */
891 dpm.skipBytes(2);
892 } else {
893 /* type and class */
894 dpm.skipBytes(4);
895 }
896 /* ttl */
897 dpm.skipBytes(4);
898 dpm.skipRData();
899 }
900 const uint16_t arcount = ntohs(dh->arcount);
901 for(size_t n = 0; n < arcount; ++n) {
902 dpm.skipLabel();
903 if (section == 3) {
904 uint16_t dnstype = dpm.get16BitInt();
905 if (dnstype == type) {
906 result++;
907 }
908 /* class */
909 dpm.skipBytes(2);
910 } else {
911 /* type and class */
912 dpm.skipBytes(4);
913 }
914 /* ttl */
915 dpm.skipBytes(4);
916 dpm.skipRData();
917 }
918 }
919 catch(...)
920 {
921 }
922 return result;
923 }
924
925 bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z)
926 {
927 if (length < sizeof(dnsheader)) {
928 return false;
929 }
930
931 *payloadSize = 0;
932 *z = 0;
933
934 try
935 {
936 const dnsheader* dh = (const dnsheader*) packet;
937 DNSPacketMangler dpm(const_cast<char*>(packet), length);
938
939 const uint16_t qdcount = ntohs(dh->qdcount);
940 for(size_t n = 0; n < qdcount; ++n) {
941 dpm.skipLabel();
942 /* type and class */
943 dpm.skipBytes(4);
944 }
945 const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
946 for(size_t n = 0; n < numrecords; ++n) {
947 dpm.skipLabel();
948 const uint16_t dnstype = dpm.get16BitInt();
949 const uint16_t dnsclass = dpm.get16BitInt();
950
951 if(dnstype == QType::OPT) {
952 /* skip extended rcode and version */
953 dpm.skipBytes(2);
954 *z = dpm.get16BitInt();
955 *payloadSize = dnsclass;
956 return true;
957 }
958
959 /* TTL */
960 dpm.skipBytes(4);
961 dpm.skipRData();
962 }
963 }
964 catch(...)
965 {
966 }
967
968 return false;
969 }