]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsname: Optimize parsing of uncompressed labels
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 28 Dec 2023 16:07:01 +0000 (17:07 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 23 Jan 2024 11:14:43 +0000 (12:14 +0100)
The gist of this change is to stop allocating and copying per label
when parsing DNSNames from the wire format, as long as we do not
encounter a compression pointer, so that we only allocate and copy
once for as many labels as possible.

This has a noticeable impact in some of our speedtest results:

| Test | Before | After |
| --- | --- | --- |
| 'parse 'empty-query'' | 7282032.6 runs/s, 0.14 us/run | 13519722.8 runs/s, 0.07 us/run |
| 'parse 'empty-query' bare' | 7512588.4 runs/s, 0.13 us/run | 14421770.5 runs/s, 0.07 us/run |
| 'parse 'typical-referral' bare | 917539.2 runs/s, 1.09 us/run | 1151581.7 runs/s, 0.87 us/run |
| 'parse 'typical-referral'' | 626927.3 runs/s, 1.60 us/run | 711754.3 runs/s, 1.40 us/run |

The improvement is quite clear when the number of labels increases:

| Number of labels | Before | After |
| --- | --- | --- |
| 1 | 16280173.9 runs/s, 0.06 us/run | 15798338.6 runs/s, 0.06 us/run |
| 2 | 11591389.8 runs/s, 0.09 us/run | 15677266.9 runs/s, 0.06 us/run |
| 3 | 9008087.9 runs/s, 0.11 us/run | 14705491.1 runs/s, 0.07 us/run |
| 4 | 7391707.9 runs/s, 0.14 us/run | 14368828.1 runs/s, 0.07 us/run |
| 5 | 6172025.9 runs/s, 0.16 us/run | 14326900.3 runs/s, 0.07 us/run |
| 6 | 5396152.4 runs/s, 0.19 us/run | 13585892.7 runs/s, 0.07 us/run |
| 7 | 4763488.4 runs/s, 0.21 us/run | 12824105.9 runs/s, 0.08 us/run |
| 8 | 4323804.8 runs/s, 0.23 us/run | 12494736.6 runs/s, 0.08 us/run |
| 9 | 3877356.8 runs/s, 0.26 us/run | 12308737.6 runs/s, 0.08 us/run |
| ... | ... | ... |
| 127 | 360564.0 runs/s, 2.77 us/run | 2782692.4 runs/s, 0.36 us/run |

pdns/dnsname.cc
pdns/dnsname.hh

index 4534788b0da3704037dc6524bfe0e49be8a83bba..c45dd4f1d9474261f2eb41c92207720bc3848205 100644 (file)
@@ -113,6 +113,67 @@ DNSName::DNSName(const char* pos, size_t len, size_t offset, bool uncompress, ui
   packetParser(pos, len, offset, uncompress, qtype, qclass, consumed, 0, minOffset);
 }
 
+static void checkLabelLength(uint8_t length)
+{
+  if (length == 0) {
+    throw std::range_error("no such thing as an empty label to append");
+  }
+  if (length > 63) {
+    throw std::range_error("label too long to append");
+  }
+}
+
+// this parses a DNS name until a compression pointer is found
+const unsigned char* DNSName::parsePacketUncompressed(const unsigned char* start, const unsigned char* pos, const unsigned char* end, bool uncompress)
+{
+  size_t totalLength = 0;
+  unsigned char labellen = 0;
+
+  do {
+    labellen = *pos;
+    ++pos;
+    if (labellen == 0) {
+      --pos;
+      break;
+    }
+    if (labellen >= 0xc0) {
+      if (!uncompress) {
+        throw std::range_error("Found compressed label, instructed not to follow");
+      }
+      --pos;
+      break;
+    }
+    if (labellen & 0xc0) {
+      throw std::range_error("Found an invalid label length in qname (only one of the first two bits is set)");
+    }
+    checkLabelLength(labellen);
+    // reserve one byte for the label length
+    if (totalLength + labellen > s_maxDNSNameLength - 1) {
+      throw std::range_error("name too long to append");
+    }
+    if (pos + labellen >= end) {
+      throw std::range_error("Found an invalid label length in qname");
+    }
+    pos += labellen;
+    totalLength += 1 + labellen;
+  }
+  while (labellen != 0 && pos < end);
+
+  if (totalLength != 0) {
+    auto existingSize = d_storage.size();
+    if (existingSize > 0) {
+      // remove the last label count, we are about to override it */
+      --existingSize;
+      d_storage.resize(existingSize);
+    }
+    d_storage.reserve(existingSize + totalLength + 1);
+    d_storage.resize(existingSize + totalLength);
+    memcpy(&d_storage.at(existingSize), start, totalLength);
+    d_storage.append(1, static_cast<char>(0));
+  }
+  return pos;
+}
+
 // this should be the __only__ dns name parser in PowerDNS.
 void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset)
 {
@@ -123,62 +184,65 @@ void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool unc
   if (offset >= len) {
     throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")");
   }
+
   if (offset < static_cast<size_t>(minOffset)) {
     throw std::range_error("Trying to read before the beginning of the buffer ("+std::to_string(offset)+ " < "+std::to_string(minOffset)+")");
   }
 
   const unsigned char* end = pos + len;
   pos += offset;
-  while((labellen=*pos++) && pos < end) { // "scan and copy"
-    if(labellen >= 0xc0) {
-      if(!uncompress)
-        throw std::range_error("Found compressed label, instructed not to follow");
 
-      labellen &= (~0xc0);
-      size_t newpos = (labellen << 8) + *(const unsigned char*)pos;
+  pos = parsePacketUncompressed(opos + offset, pos, end, uncompress);
 
-      if (newpos < offset) {
-        if (newpos < minOffset) {
-          throw std::range_error("Invalid label position during decompression ("+std::to_string(newpos)+ " < "+std::to_string(minOffset)+")");
-        }
-        if (++depth > 100) {
-          throw std::range_error("Abort label decompression after 100 redirects");
-        }
-        packetParser((const char*)opos, len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset);
-      } else {
-        throw std::range_error("Found a forward reference during label decompression");
-      }
-      pos++;
-      break;
-    } else if(labellen & 0xc0) {
-      throw std::range_error("Found an invalid label length in qname (only one of the first two bits is set)");
+  if ((labellen=*pos++) && pos < end) {
+    if (labellen < 0xc0) {
+      abort();
     }
-    if (pos + labellen < end) {
-      appendRawLabel((const char*)pos, labellen);
+
+    if (!uncompress) {
+      throw std::range_error("Found compressed label, instructed not to follow");
     }
-    else {
-      throw std::range_error("Found an invalid label length in qname");
+
+    labellen &= (~0xc0);
+    size_t newpos = (labellen << 8) + *pos;
+
+    if (newpos >= offset) {
+      throw std::range_error("Found a forward reference during label decompression");
+    }
+
+    if (newpos < minOffset) {
+      throw std::range_error("Invalid label position during decompression ("+std::to_string(newpos)+ " < "+std::to_string(minOffset)+")");
     }
-    pos+=labellen;
+
+    if (++depth > 100) {
+      throw std::range_error("Abort label decompression after 100 redirects");
+    }
+
+    packetParser(reinterpret_cast<const char*>(opos), len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset);
+
+    pos++;
   }
+
   if (d_storage.empty()) {
     d_storage.append(1, (char)0); // we just parsed the root
   }
+
   if (consumed != nullptr) {
     *consumed = pos - opos - offset;
   }
+
   if (qtype != nullptr) {
     if (pos + 2 > end) {
       throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")");
     }
-    *qtype=(*(const unsigned char*)pos)*256 + *((const unsigned char*)pos+1);
+    *qtype = (*pos)*256 + *(pos+1);
   }
-  pos+=2;
-  if(qclass) {
+  pos += 2;
+  if (qclass != nullptr) {
     if (pos + 2 > end) {
       throw std::range_error("Trying to read qclass past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")");
     }
-    *qclass=(*(const unsigned char*)pos)*256 + *((const unsigned char*)pos+1);
+    *qclass = (*pos)*256 + *(pos+1);
   }
 }
 
@@ -294,12 +358,12 @@ DNSName DNSName::makeRelative(const DNSName& zone) const
   return ret;
 }
 
-void DNSName::makeUsRelative(const DNSName& zone) 
+void DNSName::makeUsRelative(const DNSName& zone)
 {
   if (isPartOf(zone)) {
     d_storage.erase(d_storage.size()-zone.d_storage.size());
     d_storage.append(1, (char)0); // put back the trailing 0
-  } 
+  }
   else {
     clear();
   }
@@ -355,20 +419,19 @@ void DNSName::appendRawLabel(const std::string& label)
 
 void DNSName::appendRawLabel(const char* start, unsigned int length)
 {
-  if (length==0) {
-    throw std::range_error("no such thing as an empty label to append");
-  }
-  if (length > 63) {
-    throw std::range_error("label too long to append");
-  }
-  if (d_storage.size() + length > s_maxDNSNameLength - 1) { // reserve one byte for the label length
+  checkLabelLength(length);
+
+  // reserve one byte for the label length
+  if (d_storage.size() + length > s_maxDNSNameLength - 1) {
     throw std::range_error("name too long to append");
   }
 
   if (d_storage.empty()) {
+    d_storage.reserve(1 + length + 1);
     d_storage.append(1, (char)length);
   }
   else {
+    d_storage.reserve(d_storage.size() + length + 1);
     *d_storage.rbegin()=(char)length;
   }
   d_storage.append(start, length);
@@ -377,26 +440,27 @@ void DNSName::appendRawLabel(const char* start, unsigned int length)
 
 void DNSName::prependRawLabel(const std::string& label)
 {
-  if (label.empty()) {
-    throw std::range_error("no such thing as an empty label to prepend");
-  }
-  if (label.size() > 63) {
-    throw std::range_error("label too long to prepend");
-  }
-  if (d_storage.size() + label.size() > s_maxDNSNameLength - 1) { // reserve one byte for the label length
+  checkLabelLength(label.size());
+
+  // reserve one byte for the label length
+  if (d_storage.size() + label.size() > s_maxDNSNameLength - 1) {
     throw std::range_error("name too long to prepend");
   }
 
   if (d_storage.empty()) {
+    d_storage.reserve(1 + label.size() + 1);
     d_storage.append(1, (char)0);
   }
+  else {
+    d_storage.reserve(d_storage.size() + 1 + label.size());
+  }
 
   string_t prep(1, (char)label.size());
   prep.append(label.c_str(), label.size());
   d_storage = prep+d_storage;
 }
 
-bool DNSName::slowCanonCompare(const DNSName& rhs) const 
+bool DNSName::slowCanonCompare(const DNSName& rhs) const
 {
   auto ours=getRawLabels(), rhsLabels = rhs.getRawLabels();
   return std::lexicographical_compare(ours.rbegin(), ours.rend(), rhsLabels.rbegin(), rhsLabels.rend(), CIStringCompare());
index 64c14732fe7c21d2bcc1f7381ad90c46e8baaa80..08a2030046aba499daffb7508d3b421e668d74a7 100644 (file)
@@ -217,6 +217,7 @@ private:
   string_t d_storage;
 
   void packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset);
+  const unsigned char* parsePacketUncompressed(const unsigned char* start, const unsigned char* pos, const unsigned char* end, bool uncompress);
   static void appendEscapedLabel(std::string& appendTo, const char* orig, size_t len);
   static std::string unescapeLabel(const std::string& orig);
   static void throwSafeRangeError(const std::string& msg, const char* buf, size_t length);