From: Otto Moerbeek Date: Thu, 26 Aug 2021 13:49:55 +0000 (+0200) Subject: Rewrite of hashQuestion that avoids muliple burtle calls and use safer idiom X-Git-Tag: rec-4.6.0-beta1~45^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9ce75916d4c343e40ba780b4ef389eaa61178438;p=thirdparty%2Fpdns.git Rewrite of hashQuestion that avoids muliple burtle calls and use safer idiom --- diff --git a/pdns/dns.cc b/pdns/dns.cc index 03dfff715d..01a63d8013 100644 --- a/pdns/dns.cc +++ b/pdns/dns.cc @@ -80,22 +80,24 @@ std::string Opcode::to_s(uint8_t opcode) { } // goal is to hash based purely on the question name, and turn error into 'default' -uint32_t hashQuestion(const char* packet, uint16_t len, uint32_t init) +uint32_t hashQuestion(const uint8_t* packet, uint16_t packet_len, uint32_t init) { - if(len < 12) + if (packet_len < sizeof(dnsheader)) { return init; - - uint32_t ret=init; - const unsigned char* end = (const unsigned char*)packet+len; - const unsigned char* pos = (const unsigned char*)packet+12; + } + // C++ 17 does not have std::u8string_view + std::basic_string_view name(packet + sizeof(dnsheader), packet_len - sizeof(dnsheader)); + std::basic_string_view::size_type len = 0; - unsigned char labellen; - while((labellen=*pos++) && pos < end) { - if(pos + labellen + 1 > end) // include length field in hash - return 0; - ret=burtleCI(pos, labellen+1, ret); - pos += labellen; + while (len < name.length()) { + uint8_t labellen = name[len++]; + if (labellen == 0) { + // len is name.length() at max as it was < before the increment + return burtleCI(name.data(), len, init); + } + len += labellen; } - return ret; + // We've encountered a label that is too long + return init; } diff --git a/pdns/dns.hh b/pdns/dns.hh index 764ea46a81..a23ad5ba15 100644 --- a/pdns/dns.hh +++ b/pdns/dns.hh @@ -236,7 +236,7 @@ inline uint16_t * getFlagsFromDNSHeader(struct dnsheader * dh) extern time_t s_starttime; -uint32_t hashQuestion(const char* packet, uint16_t len, uint32_t init); +uint32_t hashQuestion(const uint8_t* packet, uint16_t len, uint32_t init); struct TSIGTriplet { diff --git a/pdns/pdns_recursor.cc b/pdns/pdns_recursor.cc index 2070bd054a..977867b69e 100644 --- a/pdns/pdns_recursor.cc +++ b/pdns/pdns_recursor.cc @@ -4086,7 +4086,7 @@ void distributeAsyncFunction(const string& packet, const pipefunc_t& func) _exit(1); } - unsigned int hash = hashQuestion(packet.c_str(), packet.length(), g_disthashseed); + unsigned int hash = hashQuestion(reinterpret_cast(packet.data()), packet.length(), g_disthashseed); unsigned int target = selectWorker(hash); ThreadMSG* tmsg = new ThreadMSG(); diff --git a/pdns/test-dnsname_cc.cc b/pdns/test-dnsname_cc.cc index af1b2571c1..6c88a8baaf 100644 --- a/pdns/test-dnsname_cc.cc +++ b/pdns/test-dnsname_cc.cc @@ -390,33 +390,53 @@ BOOST_AUTO_TEST_CASE(test_hashContainer) { BOOST_AUTO_TEST_CASE(test_QuestionHash) { vector packet; reportBasicTypes(); - DNSPacketWriter dpw1(packet, DNSName("www.ds9a.nl."), QType::AAAA); - - auto hash1=hashQuestion((char*)&packet[0], packet.size(), 0); - DNSPacketWriter dpw2(packet, DNSName("wWw.Ds9A.nL."), QType::AAAA); - auto hash2=hashQuestion((char*)&packet[0], packet.size(), 0); + + // A return init case + BOOST_CHECK_EQUAL(hashQuestion(packet.data(), sizeof(dnsheader), 0xffU), 0xffU); + + // We subtract 4 from the packet sizes since DNSPacketWriter adds a type and a class + // W eexcpect the hash of the root to bne unequal to the burtle init value + DNSPacketWriter dpw0(packet, DNSName("."), QType::AAAA); + BOOST_CHECK(hashQuestion(packet.data(), packet.size() - 4, 0xffU) != 0xffU); + + // A truncated buffer + DNSPacketWriter dpw1(packet, DNSName("."), QType::AAAA); + BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 5, 0xffU), 0xffU); + + DNSPacketWriter dpw2(packet, DNSName("www.ds9a.nl."), QType::AAAA); + // Let's make an invalid name by overwriting the length of the second label just outside + packet[sizeof(dnsheader) + 4] = 8; + BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 4, 0xffU), 0xffU); + + DNSPacketWriter dpw3(packet, DNSName("www.ds9a.nl."), QType::AAAA); + // Let's make an invalid name by overwriting the length of the second label way outside + packet[sizeof(dnsheader) + 4] = 0xff; + BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 4, 0xffU), 0xffU); + + DNSPacketWriter dpw4(packet, DNSName("www.ds9a.nl."), QType::AAAA); + auto hash1 = hashQuestion(&packet[0], packet.size() - 4, 0); + DNSPacketWriter dpw5(packet, DNSName("wWw.Ds9A.nL."), QType::AAAA); + auto hash2 = hashQuestion(&packet[0], packet.size() - 4, 0); BOOST_CHECK_EQUAL(hash1, hash2); - + vector counts(1500); - - for(unsigned int n=0; n < 100000; ++n) { + for(unsigned int n = 0; n < 100000; ++n) { packet.clear(); - DNSPacketWriter dpw3(packet, DNSName(std::to_string(n)+"."+std::to_string(n*2)+"."), QType::AAAA); - counts[hashQuestion((char*)&packet[0], packet.size(), 0) % counts.size()]++; + DNSPacketWriter dpw(packet, DNSName(std::to_string(n) + "." + std::to_string(n*2) + "."), QType::AAAA); + counts[hashQuestion(&packet[0], packet.size() - 4, 0) % counts.size()]++; } - + double sum = std::accumulate(std::begin(counts), std::end(counts), 0.0); double m = sum / counts.size(); - + double accum = 0.0; std::for_each (std::begin(counts), std::end(counts), [&](const double d) { accum += (d - m) * (d - m); }); - + double stdev = sqrt(accum / (counts.size()-1)); - BOOST_CHECK(stdev < 10); + BOOST_CHECK(stdev < 10); } - BOOST_AUTO_TEST_CASE(test_packetParse) { vector packet;