]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Rewrite of hashQuestion that avoids muliple burtle calls and use safer idiom
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Thu, 26 Aug 2021 13:49:55 +0000 (15:49 +0200)
committerOtto <otto.moerbeek@open-xchange.com>
Fri, 22 Oct 2021 07:40:33 +0000 (09:40 +0200)
pdns/dns.cc
pdns/dns.hh
pdns/pdns_recursor.cc
pdns/test-dnsname_cc.cc

index 03dfff715d7845b2f8e664a5af329cd4e04623c6..01a63d8013154a9ee2a47e361e8a5e645e2e116d 100644 (file)
@@ -80,22 +80,24 @@ std::string Opcode::to_s(uint8_t opcode) {
 }
 
 // goal is to hash based purely on the question name, and turn error into 'default'
-uint32_t hashQuestion(const char* packet, uint16_t len, uint32_t init)
+uint32_t hashQuestion(const uint8_t* packet, uint16_t packet_len, uint32_t init)
 {
-  if(len < 12) 
+  if (packet_len < sizeof(dnsheader)) {
     return init;
-  
-  uint32_t ret=init;
-  const unsigned char* end = (const unsigned char*)packet+len;
-  const unsigned char* pos = (const unsigned char*)packet+12;
+  }
+  // C++ 17 does not have std::u8string_view
+  std::basic_string_view<uint8_t> name(packet + sizeof(dnsheader), packet_len - sizeof(dnsheader));
+  std::basic_string_view<uint8_t>::size_type len = 0;
 
-  unsigned char labellen;
-  while((labellen=*pos++) && pos < end) { 
-    if(pos + labellen + 1 > end) // include length field  in hash
-      return 0;
-    ret=burtleCI(pos, labellen+1, ret);
-    pos += labellen;
+  while (len < name.length()) {
+    uint8_t labellen = name[len++];
+    if (labellen == 0) {
+      // len is name.length() at max as it was < before the increment
+      return burtleCI(name.data(), len, init);
+    }
+    len += labellen;
   }
-  return ret;
+  // We've encountered a label that is too long
+  return init;
 }
 
index 764ea46a810ee6a1ecb3a9076a1e7919b78a0d99..a23ad5ba157992e248c5fbe6520e00e92185eeca 100644 (file)
@@ -236,7 +236,7 @@ inline uint16_t * getFlagsFromDNSHeader(struct dnsheader * dh)
 
 extern time_t s_starttime;
 
-uint32_t hashQuestion(const char* packet, uint16_t len, uint32_t init);
+uint32_t hashQuestion(const uint8_t* packet, uint16_t len, uint32_t init);
 
 struct TSIGTriplet
 {
index 2070bd054a8ed0bb43f828e1e1fb950825f37fe2..977867b69e1fb9d36cba547e9856ba2190503a3f 100644 (file)
@@ -4086,7 +4086,7 @@ void distributeAsyncFunction(const string& packet, const pipefunc_t& func)
     _exit(1);
   }
 
-  unsigned int hash = hashQuestion(packet.c_str(), packet.length(), g_disthashseed);
+  unsigned int hash = hashQuestion(reinterpret_cast<const uint8_t*>(packet.data()), packet.length(), g_disthashseed);
   unsigned int target = selectWorker(hash);
 
   ThreadMSG* tmsg = new ThreadMSG();
index af1b2571c16987b0403751e965c360110e67af7c..6c88a8baaf5592f750e66fde5ebed31c0da9b9b2 100644 (file)
@@ -390,33 +390,53 @@ BOOST_AUTO_TEST_CASE(test_hashContainer) {
 BOOST_AUTO_TEST_CASE(test_QuestionHash) {
   vector<unsigned char> packet;
   reportBasicTypes();
-  DNSPacketWriter dpw1(packet, DNSName("www.ds9a.nl."), QType::AAAA);
-  
-  auto hash1=hashQuestion((char*)&packet[0], packet.size(), 0);
-  DNSPacketWriter dpw2(packet, DNSName("wWw.Ds9A.nL."), QType::AAAA);
-  auto hash2=hashQuestion((char*)&packet[0], packet.size(), 0);
+
+  // A return init case
+  BOOST_CHECK_EQUAL(hashQuestion(packet.data(), sizeof(dnsheader), 0xffU), 0xffU);
+
+  // We subtract 4 from the packet sizes since DNSPacketWriter adds a type and a class
+  // W eexcpect the hash of the root to bne unequal to the burtle init value
+  DNSPacketWriter dpw0(packet, DNSName("."), QType::AAAA);
+  BOOST_CHECK(hashQuestion(packet.data(), packet.size() - 4, 0xffU) != 0xffU);
+
+  // A truncated buffer
+  DNSPacketWriter dpw1(packet, DNSName("."), QType::AAAA);
+  BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 5, 0xffU), 0xffU);
+
+  DNSPacketWriter dpw2(packet, DNSName("www.ds9a.nl."), QType::AAAA);
+  // Let's make an invalid name by overwriting the length of the second label just outside
+  packet[sizeof(dnsheader) + 4] = 8;
+  BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 4, 0xffU), 0xffU);
+
+  DNSPacketWriter dpw3(packet, DNSName("www.ds9a.nl."), QType::AAAA);
+  // Let's make an invalid name by overwriting the length of the second label way outside
+  packet[sizeof(dnsheader) + 4] = 0xff;
+  BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 4, 0xffU), 0xffU);
+
+  DNSPacketWriter dpw4(packet, DNSName("www.ds9a.nl."), QType::AAAA);
+  auto hash1 = hashQuestion(&packet[0], packet.size() - 4, 0);
+  DNSPacketWriter dpw5(packet, DNSName("wWw.Ds9A.nL."), QType::AAAA);
+  auto hash2 = hashQuestion(&packet[0], packet.size() - 4, 0);
   BOOST_CHECK_EQUAL(hash1, hash2);
+
   vector<uint32_t> counts(1500);
-  for(unsigned int n=0; n < 100000; ++n) {
+  for(unsigned int n = 0; n < 100000; ++n) {
     packet.clear();
-    DNSPacketWriter dpw3(packet, DNSName(std::to_string(n)+"."+std::to_string(n*2)+"."), QType::AAAA);
-    counts[hashQuestion((char*)&packet[0], packet.size(), 0) % counts.size()]++;
+    DNSPacketWriter dpw(packet, DNSName(std::to_string(n) + "." + std::to_string(n*2) + "."), QType::AAAA);
+    counts[hashQuestion(&packet[0], packet.size() - 4, 0) % counts.size()]++;
   }
-  
+
   double sum = std::accumulate(std::begin(counts), std::end(counts), 0.0);
   double m =  sum / counts.size();
-  
+
   double accum = 0.0;
   std::for_each (std::begin(counts), std::end(counts), [&](const double d) {
       accum += (d - m) * (d - m);
   });
-      
+
   double stdev = sqrt(accum / (counts.size()-1));
-  BOOST_CHECK(stdev < 10);      
+  BOOST_CHECK(stdev < 10);
 }
-  
 
 BOOST_AUTO_TEST_CASE(test_packetParse) {
   vector<unsigned char> packet;