2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
27 #include <unordered_set>
31 // #include <netinet/in.h>
35 #include "dnswriter.hh"
37 #include "noinitvector.hh"
38 #include "pdnsexception.hh"
40 #include "svc-records.hh"
42 /** DNS records have three representations:
44 2) parsed in a class, ready for use
47 We should implement bidirectional transitions between 1&2 and 2&3.
48 Currently we have: 1 -> 2
51 We can add: 2 -> 1 easily by reversing the packetwriter
52 And we might be able to reverse 2 -> 3 as well
55 #include "namespaces.hh"
// Exception type named after the MOADNS parser in this header; presumably thrown
// for malformed DNS messages — confirm at the throw sites.
// NOTE(review): the single-argument constructor below is not marked `explicit`,
// so a string converts implicitly to MOADNSException; consider `explicit`.
57 class MOADNSException : public runtime_error
60 MOADNSException(const string& str) : runtime_error(str) // message is forwarded to runtime_error (what())
// Build a reader over `content` that starts at byte offset `initialPos`
// (by default sizeof(dnsheader), i.e. just past the fixed header).
// Throws std::out_of_range when `content` is longer than 65535 bytes, because the
// start position and record length are stored in uint16_t members.
// NOTE(review): d_content is a std::string_view, so the reader does not own the
// bytes; the caller must keep the buffer alive while the reader is in use.
70 PacketReader(const std::string_view& content, uint16_t initialPos=sizeof(dnsheader))
71 : d_pos(initialPos), d_startrecordpos(initialPos), d_content(content)
73 if(content.size() > std::numeric_limits<uint16_t>::max())
74 throw std::out_of_range("packet too large");
// Narrowing is safe here: the check above guarantees the size fits in uint16_t.
76 d_recordlen = (uint16_t) content.size();
// Low-level integer readers; their definitions are not part of this section.
80 uint32_t get32BitInt();
81 uint16_t get16BitInt();
// xfr* ("transfer") members: in this reader they fill their argument from the
// packet. NOTE(review): the naming presumably mirrors the writer side, so record
// classes can describe their layout once for both parsing and writing — confirm.
84 void xfrNodeOrLocatorID(NodeOrLocatorID& val);
85 void xfr48BitInt(uint64_t& val);
87 void xfr32BitInt(uint32_t& val)
92 void xfrIP(uint32_t& val)
98 void xfrIP6(std::string &val) {
// Read a raw address into `val`. The width is chosen by `version`: 4 bytes for 4,
// 16 bytes for 6. Any other version throws runtime_error("invalid IP protocol").
// This method does not read a port.
102 void xfrCAWithoutPort(uint8_t version, ComboAddress &val) {
104 if (version == 4) xfrBlob(blob, 4);
105 else if (version == 6) xfrBlob(blob, 16);
106 else throw runtime_error("invalid IP protocol");
107 val = makeComboAddressFromRaw(version, blob);
// Store a port into `val` through the IPv4 view of the address.
// NOTE(review): sin_port expects network byte order. Whether `port` is already in
// that order depends on how it is read, which is not visible here — confirm.
110 void xfrCAPort(ComboAddress &val) {
113 val.sin4.sin_port = port;
116 void xfrTime(uint32_t& val)
122 void xfr16BitInt(uint16_t& val)
127 void xfrType(uint16_t& val)
133 void xfr8BitInt(uint8_t& val)
// The compress/noDot parameters are unnamed, so the reader ignores them; they
// exist only so the signature matches the writer-style call sites.
138 void xfrName(DNSName& name, bool /* compress */ = false, bool /* noDot */ = false)
// Fill `text` with getText(multi, lenField).
143 void xfrText(string &text, bool multi=false, bool lenField=true)
145 text=getText(multi, lenField);
// Fill `text` with getUnquotedText(lenField).
148 void xfrUnquotedText(string &text, bool lenField){
149 text=getUnquotedText(lenField);
// Out-of-line blob/SVCB helpers (definitions not visible here).
152 void xfrBlob(string& blob);
153 void xfrBlobNoSpaces(string& blob, int len);
154 void xfrBlob(string& blob, int length);
155 void xfrHexBlob(string& blob, bool keepReading=false);
156 void xfrSvcParamKeyVals(set<SvcParam> &kvs);
158 void getDnsrecordheader(struct dnsrecordheader &ah);
// Presumably copy `len` bytes of record data into `dest` (vector or raw-buffer
// overload); UnknownRecordContent uses the vector form with the record's d_clen.
159 void copyRecord(vector<unsigned char>& dest, uint16_t len);
160 void copyRecord(unsigned char* dest, uint16_t len);
163 string getText(bool multi, bool lenField);
164 string getUnquotedText(bool lenField);
bool eof()
{
  // NOTE(review): this always reports "at end". It does not look at the read
  // position or the content length, so callers cannot use it to detect
  // leftover bytes.
  return true;
}
// NOTE(review): returning `const string` by value prevents callers from moving
// out of the result; a plain `string` return type would be preferable.
168 const string getRemaining() const {
// Current read offset into the packet.
172 uint16_t getPosition() const
// Advance the read position by `n` bytes (body not visible here).
177 void skip(uint16_t n)
184 uint16_t d_startrecordpos; // needed for getBlob later on
185 uint16_t d_recordlen; // ditto
186 uint16_t not_used; // Aligns the whole class on 8-byte boundaries
// Non-owning view of the packet bytes. Because it is a const member, the
// implicitly declared copy-assignment operator of the class is deleted.
187 const std::string_view d_content;
// Polymorphic base for parsed record data. Concrete record types must implement
// the pure virtuals getZoneRepresentation(), toPacket() and getType().
192 class DNSRecordContent
// Factories. The first two build content from wire format (the second also takes
// the message opcode). The third builds it from zone-file text for qtype/qclass.
195 static std::shared_ptr<DNSRecordContent> make(const DNSRecord& dr, PacketReader& pr);
196 static std::shared_ptr<DNSRecordContent> make(const DNSRecord& dr, PacketReader& pr, uint16_t opcode);
197 static std::shared_ptr<DNSRecordContent> make(uint16_t qtype, uint16_t qclass, const string& zone);
198 static string upgradeContent(const DNSName& qname, const QType& qtype, const string& content);
200 virtual std::string getZoneRepresentation(bool noDot=false) const = 0;
201 virtual ~DNSRecordContent() {}
202 virtual void toPacket(DNSPacketWriter& pw) const = 0;
203 // returns the wire format of the content, possibly including compressed pointers pointing to the owner name (unless canonic or lowerCase are set)
204 string serialize(const DNSName& qname, bool canonic=false, bool lowerCase=false) const
206 vector<uint8_t> packet;
// Scratch writer with the root name as question; the third argument (1) is
// presumably the question type — confirm against DNSPacketWriter.
207 DNSPacketWriter pw(packet, g_rootdnsname, 1);
212 pw.setLowercase(true);
214 pw.startRecord(qname, this->getType());
218 pw.getRecordPayload(record); // needs to be called before commit()
// Two contents are equal when they have the same dynamic type and the same
// zone-file text.
// NOTE(review): this renders both contents to text on every call, and the string
// comparison is case-sensitive.
222 virtual bool operator==(const DNSRecordContent& rhs) const
224 return typeid(*this)==typeid(rhs) && this->getZoneRepresentation() == rhs.getZoneRepresentation();
227 // parse the content in wire format, possibly including compressed pointers pointing to the owner name
228 static shared_ptr<DNSRecordContent> deserialize(const DNSName& qname, uint16_t qtype, const string& serialized);
// No-op hook. NOTE(review): it is not virtual, so a derived class's own
// doRecordCheck() is not reached through a DNSRecordContent reference.
230 void doRecordCheck(const struct DNSRecord&){}
// Signatures of the per-type factories stored in the registry maps.
232 typedef std::shared_ptr<DNSRecordContent> makerfunc_t(const struct DNSRecord& dr, PacketReader& pr);
233 typedef std::shared_ptr<DNSRecordContent> zmakerfunc_t(const string& str);
// Register the wire-format factory `f` and the zone-text factory `z` for
// (class, type), and record the type's mnemonic in both name maps.
// NOTE(review): the factory maps are assigned with operator[], which overwrites.
// The name maps use emplace(), which never replaces an existing entry. Whether
// null f/z are skipped depends on lines not visible here.
235 static void regist(uint16_t cl, uint16_t ty, makerfunc_t* f, zmakerfunc_t* z, const char* name)
238 getTypemap()[pair(cl,ty)]=f;
240 getZmakermap()[pair(cl,ty)]=z;
242 getT2Namemap().emplace(pair(cl, ty), name);
243 getN2Typemap().emplace(name, pair(cl, ty));
// Remove both factories for (class, type).
// NOTE(review): the name maps are left untouched, so NumberToType() and
// TypeToNumber() still resolve the mnemonic after unregistering.
246 static void unregist(uint16_t cl, uint16_t ty)
248 auto key = pair(cl, ty);
249 getTypemap().erase(key);
250 getZmakermap().erase(key);
// Return true if `name` uses the RFC 3597 generic type mnemonic "TYPE<n>".
// The prefix is matched case-insensitively, so "TYPE65", "type65" and "Type65"
// all qualify. This is consistent with TypeToNumber(), which upper-cases names
// before looking up registered mnemonics. The numeric suffix is not validated
// here; that is left to the caller.
static bool isUnknownType(const std::string& name)
{
  static const char upper[] = "TYPE";
  static const char lower[] = "type";
  if (name.size() < 4) {
    return false;
  }
  for (size_t idx = 0; idx < 4; ++idx) {
    if (name[idx] != upper[idx] && name[idx] != lower[idx]) {
      return false;
    }
  }
  return true;
}
// Map a type mnemonic to its number. Registered names are tried first; the lookup
// upper-cases `name`, so names must be registered in upper case. Names accepted by
// isUnknownType() are then parsed as the generic "TYPE<n>" form from character 4
// on. Anything else throws runtime_error.
// NOTE(review): the class half of the stored (class, type) pair is ignored here.
// checked_stoi presumably throws on a non-numeric or out-of-range suffix — confirm.
258 static uint16_t TypeToNumber(const string& name)
260 n2typemap_t::const_iterator iter = getN2Typemap().find(toUpper(name));
261 if(iter != getN2Typemap().end())
262 return iter->second.second;
264 if (isUnknownType(name)) {
265 return pdns::checked_stoi<uint16_t>(name.substr(4));
268 throw runtime_error("Unknown DNS type '"+name+"'");
// Map (type, class) to its registered mnemonic. Unregistered types are rendered
// in the RFC 3597 generic form "TYPE<n>" instead of throwing.
// NOTE(review): returning `const string` by value prevents moving from the result.
271 static const string NumberToType(uint16_t num, uint16_t classnum = QClass::IN)
273 auto iter = getT2Namemap().find(pair(classnum, num));
274 if(iter == getT2Namemap().end())
275 return "TYPE" + std::to_string(num);
276 // throw runtime_error("Unknown DNS type with numerical id "+std::to_string(num));
281 * \brief Return whether we have implemented a content representation for this type
283 static bool isRegisteredType(uint16_t rtype, uint16_t rclass = QClass::IN);
// The record type number this content represents.
285 virtual uint16_t getType() const = 0;
// Registry storage. The factory maps and t2namemap_t are keyed by (class, type);
// n2typemap_t is keyed by mnemonic. The accessors are defined out of line.
288 typedef std::map<std::pair<uint16_t, uint16_t>, makerfunc_t* > typemap_t;
289 typedef std::map<std::pair<uint16_t, uint16_t>, zmakerfunc_t* > zmakermap_t;
290 typedef std::map<std::pair<uint16_t, uint16_t>, string > t2namemap_t;
291 typedef std::map<string, std::pair<uint16_t, uint16_t> > n2typemap_t;
292 static typemap_t& getTypemap();
293 static t2namemap_t& getT2Namemap();
294 static n2typemap_t& getN2Typemap();
295 static zmakermap_t& getZmakermap();
// Build a DNSRecord from a DNSResourceRecord (defined out of line).
303 explicit DNSRecord(const DNSResourceRecord& rr);
// Convenience constructor. `name` is parsed into a DNSName. Visible defaults:
// class IN, TTL 86400, clen 0, ANSWER section.
304 DNSRecord(const std::string& name,
305 std::shared_ptr<DNSRecordContent> content,
307 const uint16_t qclass = QClass::IN,
308 const uint32_t ttl = 86400,
309 const uint16_t clen = 0,
310 const DNSResourceRecord::Place place = DNSResourceRecord::ANSWER) :
311 d_name(DNSName(name)),
312 d_content(std::move(content)),
// Parsed record data, held through a pointer-to-const so it can be shared.
321 std::shared_ptr<const DNSRecordContent> d_content;
327 DNSResourceRecord::Place d_place{DNSResourceRecord::ANSWER};
// Multi-line, human-readable dump of the record; every line is prefixed with `indent`.
// NOTE(review): d_content is dereferenced without a null check, so this must not
// be called on a record that has no content.
329 [[nodiscard]] std::string print(const std::string& indent = "") const
332 s << indent << "Content = " << d_content->getZoneRepresentation() << std::endl;
333 s << indent << "Type = " << d_type << std::endl;
334 s << indent << "Class = " << d_class << std::endl;
335 s << indent << "TTL = " << d_ttl << std::endl;
336 s << indent << "clen = " << d_clen << std::endl;
337 s << indent << "Place = " << std::to_string(d_place) << std::endl;
// Replace the content; the first overload copies the shared pointer and the second moves it.
341 void setContent(const std::shared_ptr<const DNSRecordContent>& content)
346 void setContent(std::shared_ptr<const DNSRecordContent>&& content)
348 d_content = std::move(content);
349 [[nodiscard]] const std::shared_ptr<const DNSRecordContent>& getContent() const
// Order by (name, type, class, ttl). Ties are then broken using the lower-cased
// zone representations of both contents; the final comparison is on lines not
// visible here.
356 bool operator<(const DNSRecord& rhs) const
358 if(std::tie(d_name, d_type, d_class, d_ttl) < std::tie(rhs.d_name, rhs.d_type, rhs.d_class, rhs.d_ttl))
361 if(std::tie(d_name, d_type, d_class, d_ttl) != std::tie(rhs.d_name, rhs.d_type, rhs.d_class, rhs.d_ttl))
366 lzrp=toLower(d_content->getZoneRepresentation());
368 rzrp=toLower(rhs.d_content->getZoneRepresentation());
373 // this orders in canonical order and keeps the SOA record on top
374 static bool prettyCompare(const DNSRecord& a, const DNSRecord& b)
// SOA is mapped to type 0, so it sorts before every other type at the same owner
// name. Names are compared first, in canonical order.
376 auto aType = (a.d_type == QType::SOA) ? 0 : a.d_type;
377 auto bType = (b.d_type == QType::SOA) ? 0 : b.d_type;
379 if(a.d_name.canonCompare(b.d_name))
381 if(b.d_name.canonCompare(a.d_name))
384 if(std::tie(aType, a.d_class, a.d_ttl) < std::tie(bType, b.d_class, b.d_ttl))
387 if(std::tie(aType, a.d_class, a.d_ttl) != std::tie(bType, b.d_class, b.d_ttl))
392 lzrp = a.d_content->getZoneRepresentation();
394 rzrp = b.d_content->getZoneRepresentation();
// Non-Recursor builds do extra work here on lines not visible in this section.
399 #if !defined(RECURSOR)
404 return toLower(lzrp) < toLower(rzrp);
// Records are equal when type, class and name match and the contents compare
// equal (DNSRecordContent::operator==).
// NOTE(review): d_ttl takes part in operator< but not in the visible condition
// here, so !(a < b) && !(b < a) does not imply a == b.
408 bool operator==(const DNSRecord& rhs) const
410 if (d_type != rhs.d_type || d_class != rhs.d_class || d_name != rhs.d_name) {
414 return *d_content == *rhs.d_content;
// Members of a record-wrapper struct whose declaration is not visible in this section.
// NOTE(review): scopeMask is presumably an ECS scope prefix length — confirm.
421 uint8_t scopeMask{0};
423 DNSName wildcardname;
425 bool disabled{false};
// Content that stores RDATA as raw bytes. Presumably this is the fallback for
// record types without a dedicated implementation (RFC 3597).
429 class UnknownRecordContent : public DNSRecordContent
// Copy the record's d_clen bytes of RDATA verbatim from the reader.
432 UnknownRecordContent(const DNSRecord& dr, PacketReader& pr)
435 pr.copyRecord(d_record, dr.d_clen);
// Build from zone-file text; presumably the RFC 3597 generic \# syntax — confirm.
438 UnknownRecordContent(const string& zone);
// NOTE(review): unlike the base declaration, this override has no default for
// noDot, so calls through an UnknownRecordContent reference must pass it.
440 string getZoneRepresentation(bool noDot) const override;
441 void toPacket(DNSPacketWriter& pw) const override;
442 uint16_t getType() const override
// Raw RDATA bytes as copied from the packet.
447 const vector<uint8_t>& getRawContent() const
454 vector<uint8_t> d_record;
457 //! This class can be used to parse incoming packets; it is NOT copyable (it derives from boost::noncopyable)
458 class MOADNSParser : public boost::noncopyable
461 //! Parse from a string
462 MOADNSParser(bool query, const string& buffer): d_tsigPos(0)
467 //! Parse from a pointer and length
468 MOADNSParser(bool query, const char *packet, unsigned int len) : d_tsigPos(0)
470 init(query, std::string_view(packet, len));
474 uint16_t d_qclass, d_qtype;
// Each answer pairs a record with a 16-bit value filled in by init().
// NOTE(review): presumably the record's position in the packet — confirm.
478 typedef vector<pair<DNSRecord, uint16_t > > answers_t;
480 //! All answers contained in this packet (everything *but* the question section)
483 uint16_t getTSIGPos() const
488 bool hasEDNS() const;
// Shared parsing entry point; the pointer/length constructor forwards to it.
491 void init(bool query, const std::string_view& packet);
// Free helpers that inspect or edit wire-format DNS messages; all are defined out of line.
495 string simpleCompress(const string& label, const string& root="");
// NOTE(review): presumably reduce record TTLs by `seconds` in place — confirm.
496 void ageDNSPacket(char* packet, size_t length, uint32_t seconds, const dnsheader_aligned&);
497 void ageDNSPacket(std::string& packet, uint32_t seconds, const dnsheader_aligned&);
// The visitor returns the new TTL. Its arguments are presumably (section, class,
// type, ttl), matching visitDNSPacket below — confirm.
498 void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor);
// Presumably strip records of the given types. The raw-pointer overload takes
// `length` by non-const reference, so it can update the caller's packet length.
499 void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::unordered_set<QType>& qtypes);
500 void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::unordered_set<QType>& qtypes);
501 void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::unordered_set<QType>& qtypes);
// `seenAuthSOA` is an optional out-parameter (nullptr by default).
502 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA=nullptr);
503 uint32_t getDNSPacketLength(const char* packet, size_t length);
504 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type);
// Outputs via `payloadSize` and `z`. NOTE(review): presumably returns false when
// no EDNS OPT record is found — confirm.
505 bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z);
506 /* call the visitor for every record in the answer, authority and additional sections, passing the section, class, type, ttl, rdatalength and rdata
507 to the visitor. Stops whenever the visitor returns false or at the end of the packet */
508 bool visitDNSPacket(const std::string_view& packet, const std::function<bool(uint8_t, uint16_t, uint16_t, uint32_t, uint16_t, const char*)>& visitor);
// Return dr's content downcast to T, or nullptr when the content is not a T
// (std::dynamic_pointer_cast semantics). The result shares ownership with `dr`.
511 std::shared_ptr<const T> getRR(const DNSRecord& dr)
513 return std::dynamic_pointer_cast<const T>(dr.getContent());
516 /** Simple DNSPacketMangler. Ritual is: get a pointer into the packet and moveOffset() to beyond your needs
517 * If you survive that, feel free to read from the pointer */
// NOTE(review): d_offset is a reference bound to d_notyouroffset. An implicitly
// copied mangler would keep referring to the source object's offset, and copy
// assignment is implicitly deleted. Check whether copying is disabled elsewhere.
518 class DNSPacketMangler
// Edit the string's buffer in place. Only packet.data() is kept, so the string must
// outlive the mangler and must not be resized meanwhile. Both constructors start
// at offset 12, the size of the fixed DNS header.
521 explicit DNSPacketMangler(std::string& packet)
522 : d_packet(packet.data()), d_length(packet.length()), d_notyouroffset(12), d_offset(d_notyouroffset)
// Edit a caller-owned buffer of `length` bytes in place.
524 DNSPacketMangler(char* packet, size_t length)
525 : d_packet(packet), d_length(length), d_notyouroffset(12), d_offset(d_notyouroffset)
// NOTE(review): only leading bytes of 0xC0 or more are special-cased in the visible
// code. Bytes 0x40-0xBF (reserved/extended label types) are presumably treated as
// plain label lengths — confirm.
528 /*! Advances past a wire-format domain name
529 * The name is not checked for adherence to length restrictions.
530 * Compression pointers are not followed.
532 void skipDomainName()
535 while((len=get8BitInt())) {
536 if(len >= 0xc0) { // compression pointer: top two bits set (RFC 1035 section 4.1.4)
// Advance / rewind by a byte count; presumably delegates to moveOffset / rewindOffset.
544 void skipBytes(uint16_t bytes)
548 void rewindBytes(uint16_t by)
// Integer readers. The pointer is taken at the current offset and the value is
// copied out with memcpy, which avoids unaligned access. The offset advance and
// any byte-order conversion are on lines not visible here.
552 uint32_t get32BitInt()
554 const char* p = d_packet + d_offset;
557 memcpy(&ret, p, sizeof(ret));
560 uint16_t get16BitInt()
562 const char* p = d_packet + d_offset;
565 memcpy(&ret, p, sizeof(ret));
// Fragment of another reader whose signature is not visible (presumably the 8-bit one).
571 const char* p = d_packet + d_offset;
// Presumably the record's 16-bit RDLENGTH, used to skip its RDATA.
578 int toskip = get16BitInt();
// Read the 32-bit value at the current offset and lower it by `decrease`. The
// result is written back to the 4 bytes before the already-advanced offset.
// NOTE(review): the `tmp > decrease` guard presumably prevents unsigned
// wrap-around; the other branch is not visible.
582 void decreaseAndSkip32BitInt(uint32_t decrease)
584 const char *p = d_packet + d_offset;
588 memcpy(&tmp, p, sizeof(tmp));
590 if (tmp > decrease) {
596 memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp));
// Overwrite the 32-bit field at the current offset with `value` in network byte
// order (htonl). The write targets the 4 bytes before the already-advanced offset.
599 void setAndSkip32BitInt(uint32_t value)
603 value = htonl(value);
604 memcpy(d_packet + d_offset-4, (const char*)&value, sizeof(value));
607 uint32_t getOffset() const
// Advance the offset by `by`; throws std::out_of_range if it passes d_length.
// NOTE(review): the offset is updated before the check, so after the exception the
// object is left with an out-of-range offset.
613 void moveOffset(uint16_t by)
615 d_notyouroffset += by;
616 if(d_notyouroffset > d_length)
617 throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > "
618 + std::to_string(d_length) );
// Move the offset back by `by`. The first check prevents unsigned underflow; the
// second refuses to rewind into the 12-byte header.
// NOTE(review): the second check runs after the subtraction, so on that exception
// the offset has already been moved below 12.
621 void rewindOffset(uint16_t by)
623 if(d_notyouroffset < by)
624 throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
625 + std::to_string(by));
626 d_notyouroffset -= by;
627 if(d_notyouroffset < 12)
628 throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
629 + std::to_string(12));
635 uint32_t d_notyouroffset; // only moveOffset() and rewindOffset() modify this
636 const uint32_t& d_offset; // look.. but don't touch