]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsname: Use a view instead of pointer arithmetic in DNSName::packetParser()
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 18 Jan 2024 13:00:31 +0000 (14:00 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 23 Jan 2024 11:20:51 +0000 (12:20 +0100)
pdns/dnsname.cc
pdns/dnsname.hh

index c45dd4f1d9474261f2eb41c92207720bc3848205..5e01af3ce11fc835feb3918e01b3f0c2529d6b29 100644 (file)
@@ -124,18 +124,21 @@ static void checkLabelLength(uint8_t length)
 }
 
 // this parses a DNS name until a compression pointer is found
-const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start, const unsigned char* pos, const unsigned char* end, bool uncompress)
+size_t DNSName::parsePacketUncompressed(const UnsignedCharView& view, size_t pos, bool uncompress)
 {
+  const size_t initialPos = pos;
   size_t totalLength = 0;
   unsigned char labellen = 0;
 
   do {
-    labellen = *pos;
+    labellen = view.at(pos);
     ++pos;
+
     if (labellen == 0) {
       --pos;
       break;
     }
+
     if (labellen >= 0xc0) {
       if (!uncompress) {
         throw std::range_error("Found compressed label, instructed not to follow");
@@ -143,7 +146,8 @@ const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start
       --pos;
       break;
     }
-    if (labellen & 0xc0) {
+
+    if ((labellen & 0xc0) != 0) {
       throw std::range_error("Found an invalid label length in qname (only one of the first two bits is set)");
     }
     checkLabelLength(labellen);
@@ -151,13 +155,13 @@ const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start
     if (totalLength + labellen > s_maxDNSNameLength - 1) {
       throw std::range_error("name too long to append");
     }
-    if (pos + labellen >= end) {
+    if (pos + labellen >= view.size()) {
       throw std::range_error("Found an invalid label length in qname");
     }
     pos += labellen;
     totalLength += 1 + labellen;
   }
-  while (labellen != 0 && pos < end);
+  while (labellen != 0 && pos < view.size());
 
   if (totalLength != 0) {
     auto existingSize = d_storage.size();
@@ -168,7 +172,7 @@ const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start
     }
     d_storage.reserve(existingSize + totalLength + 1);
     d_storage.resize(existingSize + totalLength);
-    memcpy(&d_storage.at(existingSize), start, totalLength);
+    memcpy(&d_storage.at(existingSize), &view.at(initialPos), totalLength);
     d_storage.append(1, static_cast<char>(0));
   }
   return pos;
@@ -177,10 +181,6 @@ const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start
 // this should be the __only__ dns name parser in PowerDNS.
 void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset)
 {
-  const unsigned char* pos=(const unsigned char*)qpos;
-  unsigned char labellen;
-  const unsigned char *opos = pos;
-
   if (offset >= len) {
     throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")");
   }
@@ -188,13 +188,14 @@ void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool unc
   if (offset < static_cast<size_t>(minOffset)) {
     throw std::range_error("Trying to read before the beginning of the buffer ("+std::to_string(offset)+ " < "+std::to_string(minOffset)+")");
   }
+  unsigned char labellen{0};
 
-  const unsigned char* end = pos + len;
-  pos += offset;
-
-  pos = parsePacketUncompressed(opos + offset, pos, end, uncompress);
+  UnsignedCharView view(qpos, len);
+  auto pos = parsePacketUncompressed(view, offset, uncompress);
 
-  if ((labellen=*pos++) && pos < end) {
+  labellen = view.at(pos);
+  pos++;
+  if (labellen != 0 && pos < view.size()) {
     if (labellen < 0xc0) {
       abort();
     }
@@ -204,7 +205,7 @@ void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool unc
     }
 
     labellen &= (~0xc0);
-    size_t newpos = (labellen << 8) + *pos;
+    size_t newpos = (labellen << 8) + view.at(pos);
 
     if (newpos >= offset) {
       throw std::range_error("Found a forward reference during label decompression");
@@ -218,31 +219,32 @@ void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool unc
       throw std::range_error("Abort label decompression after 100 redirects");
     }
 
-    packetParser(reinterpret_cast<const char*>(opos), len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset);
+    packetParser(qpos, len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset);
 
     pos++;
   }
 
   if (d_storage.empty()) {
-    d_storage.append(1, (char)0); // we just parsed the root
+    d_storage.append(1, static_cast<char>(0)); // we just parsed the root
   }
 
   if (consumed != nullptr) {
-    *consumed = pos - opos - offset;
+    *consumed = pos - offset;
   }
 
   if (qtype != nullptr) {
-    if (pos + 2 > end) {
-      throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")");
+    if (pos + 2 > view.size()) {
+      throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string(pos + 2)+ " > "+std::to_string(len)+")");
     }
-    *qtype = (*pos)*256 + *(pos+1);
+    *qtype = view.at(pos)*256 + view.at(pos+1);
   }
+
   pos += 2;
   if (qclass != nullptr) {
-    if (pos + 2 > end) {
-      throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")");
+    if (pos + 2 > view.size()) {
+      throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string(pos + 2)+ " > "+std::to_string(len)+")");
     }
-    *qclass = (*pos)*256 + *(pos+1);
+    *qclass = view.at(pos)*256 + view.at(pos+1);
   }
 }
 
index 08a2030046aba499daffb7508d3b421e668d74a7..e7a9f4b4eafd1149386b5a7019a649c16235cca1 100644 (file)
@@ -216,8 +216,28 @@ public:
 private:
   string_t d_storage;
 
+  class UnsignedCharView
+  {
+  public:
+    UnsignedCharView(const char* data_, size_t size_): view(data_, size_)
+    {
+    }
+    const unsigned char& at(std::string_view::size_type pos) const
+    {
+      return reinterpret_cast<const unsigned char&>(view.at(pos));
+    }
+
+    size_t size() const
+    {
+      return view.size();
+    }
+
+  private:
+    std::string_view view;
+  };
+
   void packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset);
-  const unsigned char* parsePacketUncompressed(const unsigned char* start, const unsigned char* pos, const unsigned char* end, bool uncompress);
+  size_t parsePacketUncompressed(const UnsignedCharView& view, size_t position, bool uncompress);
   static void appendEscapedLabel(std::string& appendTo, const char* orig, size_t len);
   static std::string unescapeLabel(const std::string& orig);
   static void throwSafeRangeError(const std::string& msg, const char* buf, size_t length);