From: Remi Gacogne Date: Thu, 18 Jan 2024 13:00:31 +0000 (+0100) Subject: dnsname: Use a view instead of pointer arithmetic in DNSName::packetParser() X-Git-Tag: dnsdist-1.9.0-rc1~2^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2e163d91c793764304f6f1c83dbcd5fbbd546226;p=thirdparty%2Fpdns.git dnsname: Use a view instead of pointer arithmetic in DNSName::packetParser() --- diff --git a/pdns/dnsname.cc b/pdns/dnsname.cc index c45dd4f1d9..5e01af3ce1 100644 --- a/pdns/dnsname.cc +++ b/pdns/dnsname.cc @@ -124,18 +124,21 @@ static void checkLabelLength(uint8_t length) } // this parses a DNS name until a compression pointer is found -const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start, const unsigned char* pos, const unsigned char* end, bool uncompress) +size_t DNSName::parsePacketUncompressed(const UnsignedCharView& view, size_t pos, bool uncompress) { + const size_t initialPos = pos; size_t totalLength = 0; unsigned char labellen = 0; do { - labellen = *pos; + labellen = view.at(pos); ++pos; + if (labellen == 0) { --pos; break; } + if (labellen >= 0xc0) { if (!uncompress) { throw std::range_error("Found compressed label, instructed not to follow"); @@ -143,7 +146,8 @@ const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start --pos; break; } - if (labellen & 0xc0) { + + if ((labellen & 0xc0) != 0) { throw std::range_error("Found an invalid label length in qname (only one of the first two bits is set)"); } checkLabelLength(labellen); @@ -151,13 +155,13 @@ const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start if (totalLength + labellen > s_maxDNSNameLength - 1) { throw std::range_error("name too long to append"); } - if (pos + labellen >= end) { + if (pos + labellen >= view.size()) { throw std::range_error("Found an invalid label length in qname"); } pos += labellen; totalLength += 1 + labellen; } - while (labellen != 0 && pos < end); + while (labellen != 0 && pos < view.size()); if (totalLength != 0) { auto existingSize = d_storage.size(); @@ -168,7 +172,7 @@ const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start } d_storage.reserve(existingSize + totalLength + 1); d_storage.resize(existingSize + totalLength); - memcpy(&d_storage.at(existingSize), start, totalLength); + memcpy(&d_storage.at(existingSize), &view.at(initialPos), totalLength); d_storage.append(1, static_cast(0)); } return pos; @@ -177,10 +181,6 @@ const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start // this should be the __only__ dns name parser in PowerDNS. void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset) { - const unsigned char* pos=(const unsigned char*)qpos; - unsigned char labellen; - const unsigned char *opos = pos; - if (offset >= len) { throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")"); } @@ -188,13 +188,14 @@ void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool unc if (offset < static_cast(minOffset)) { throw std::range_error("Trying to read before the beginning of the buffer ("+std::to_string(offset)+ " < "+std::to_string(minOffset)+")"); } + unsigned char labellen{0}; - const unsigned char* end = pos + len; - pos += offset; - - pos = parsePacketUncompressed(opos + offset, pos, end, uncompress); + UnsignedCharView view(qpos, len); + auto pos = parsePacketUncompressed(view, offset, uncompress); - if ((labellen=*pos++) && pos < end) { + labellen = view.at(pos); + pos++; + if (labellen != 0 && pos < view.size()) { if (labellen < 0xc0) { abort(); } @@ -204,7 +205,7 @@ void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool unc } labellen &= (~0xc0); - size_t newpos = (labellen << 8) + *pos; + size_t newpos = (labellen << 8) + view.at(pos); if (newpos >= offset) { throw std::range_error("Found a forward reference during label decompression"); @@ -218,31 +219,32 @@ void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool unc throw std::range_error("Abort label decompression after 100 redirects"); } - packetParser(reinterpret_cast(opos), len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset); + packetParser(qpos, len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset); pos++; } if (d_storage.empty()) { - d_storage.append(1, (char)0); // we just parsed the root + d_storage.append(1, static_cast(0)); // we just parsed the root } if (consumed != nullptr) { - *consumed = pos - opos - offset; + *consumed = pos - offset; } if (qtype != nullptr) { - if (pos + 2 > end) { - throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")"); + if (pos + 2 > view.size()) { + throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string(pos + 2)+ " > "+std::to_string(len)+")"); } - *qtype = (*pos)*256 + *(pos+1); + *qtype = view.at(pos)*256 + view.at(pos+1); } + pos += 2; if (qclass != nullptr) { - if (pos + 2 > end) { - throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")"); + if (pos + 2 > view.size()) { + throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string(pos + 2)+ " > "+std::to_string(len)+")"); } - *qclass = (*pos)*256 + *(pos+1); + *qclass = view.at(pos)*256 + view.at(pos+1); } } diff --git a/pdns/dnsname.hh b/pdns/dnsname.hh index 08a2030046..e7a9f4b4ea 100644 --- a/pdns/dnsname.hh +++ b/pdns/dnsname.hh @@ -216,8 +216,28 @@ public: private: string_t d_storage; + class UnsignedCharView + { + public: + UnsignedCharView(const char* data_, size_t size_): view(data_, size_) + { + } + const unsigned char& at(std::string_view::size_type pos) const + { + return reinterpret_cast(view.at(pos)); + } + + size_t size() const + { + return view.size(); + } + + private: + std::string_view view; + }; + void packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset); - const unsigned char* parsePacketUncompressed(const unsigned char* start, const unsigned char* pos, const unsigned char* end, bool uncompress); + size_t parsePacketUncompressed(const UnsignedCharView& view, size_t position, bool uncompress); static void appendEscapedLabel(std::string& appendTo, const char* orig, size_t len); static std::string unescapeLabel(const std::string& orig); static void throwSafeRangeError(const std::string& msg, const char* buf, size_t length);