From: Otto Date: Fri, 22 Oct 2021 13:14:04 +0000 (+0200) Subject: Detect a malformed question early so we can drop it instead of letting the X-Git-Tag: auth-4.7.0-alpha1~14^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=538c11f9a55a3d88c907fcd07db57bda75dd956d;p=thirdparty%2Fpdns.git Detect a malformed question early so we can drop it instead of letting the worker do that --- diff --git a/pdns/dns.cc b/pdns/dns.cc index 01a63d8013..12689d9831 100644 --- a/pdns/dns.cc +++ b/pdns/dns.cc @@ -80,9 +80,10 @@ std::string Opcode::to_s(uint8_t opcode) { } // goal is to hash based purely on the question name, and turn error into 'default' -uint32_t hashQuestion(const uint8_t* packet, uint16_t packet_len, uint32_t init) +uint32_t hashQuestion(const uint8_t* packet, uint16_t packet_len, uint32_t init, bool& ok) { if (packet_len < sizeof(dnsheader)) { + ok = false; return init; } // C++ 17 does not have std::u8string_view @@ -92,12 +93,14 @@ uint32_t hashQuestion(const uint8_t* packet, uint16_t packet_len, uint32_t init) while (len < name.length()) { uint8_t labellen = name[len++]; if (labellen == 0) { + ok = true; // len is name.length() at max as it was < before the increment return burtleCI(name.data(), len, init); } len += labellen; } // We've encountered a label that is too long + ok = false; return init; } diff --git a/pdns/dns.hh b/pdns/dns.hh index a23ad5ba15..f23204114a 100644 --- a/pdns/dns.hh +++ b/pdns/dns.hh @@ -236,7 +236,7 @@ inline uint16_t * getFlagsFromDNSHeader(struct dnsheader * dh) extern time_t s_starttime; -uint32_t hashQuestion(const uint8_t* packet, uint16_t len, uint32_t init); +uint32_t hashQuestion(const uint8_t* packet, uint16_t len, uint32_t init, bool& ok); struct TSIGTriplet { diff --git a/pdns/pdns_recursor.cc b/pdns/pdns_recursor.cc index 9014af9ca7..0cf831af40 100644 --- a/pdns/pdns_recursor.cc +++ b/pdns/pdns_recursor.cc @@ -2396,7 +2396,13 @@ void distributeAsyncFunction(const string& packet, const pipefunc_t& func) _exit(1); } - unsigned int hash = hashQuestion(reinterpret_cast(packet.data()), packet.length(), g_disthashseed); + bool ok; + unsigned int hash = hashQuestion(reinterpret_cast(packet.data()), packet.length(), g_disthashseed, ok); + if (!ok) { + // hashQuestion does detect invalid names, so we might as well punt here instead of in the worker thread + g_stats.ignoredCount++; + throw MOADNSException("too-short (" + std::to_string(packet.length()) + ") or invalid name"); + } unsigned int target = selectWorker(hash); ThreadMSG* tmsg = new ThreadMSG(); diff --git a/pdns/test-dnsname_cc.cc b/pdns/test-dnsname_cc.cc index 98aed701bf..f1fca240c9 100644 --- a/pdns/test-dnsname_cc.cc +++ b/pdns/test-dnsname_cc.cc @@ -391,39 +391,48 @@ BOOST_AUTO_TEST_CASE(test_QuestionHash) { vector packet(sizeof(dnsheader)); reportBasicTypes(); + bool ok; // A return init case - BOOST_CHECK_EQUAL(hashQuestion(packet.data(), sizeof(dnsheader), 0xffU), 0xffU); + BOOST_CHECK_EQUAL(hashQuestion(packet.data(), sizeof(dnsheader), 0xffU, ok), 0xffU); + BOOST_CHECK(!ok); // We subtract 4 from the packet sizes since DNSPacketWriter adds a type and a class // We expect the hash of the root to be unequal to the burtle init value DNSPacketWriter dpw0(packet, DNSName("."), QType::AAAA); - BOOST_CHECK(hashQuestion(packet.data(), packet.size() - 4, 0xffU) != 0xffU); + BOOST_CHECK(hashQuestion(packet.data(), packet.size() - 4, 0xffU, ok) != 0xffU); + BOOST_CHECK(ok); // A truncated buffer should return the init value DNSPacketWriter dpw1(packet, DNSName("."), QType::AAAA); - BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 5, 0xffU), 0xffU); + BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 5, 0xffU, ok), 0xffU); + BOOST_CHECK(!ok); DNSPacketWriter dpw2(packet, DNSName("www.ds9a.nl."), QType::AAAA); // Let's make an invalid name by overwriting the length of the second label just outside the buffer packet[sizeof(dnsheader) + 4] = 8; - BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 4, 0xffU), 0xffU); + BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 4, 0xffU, ok), 0xffU); + BOOST_CHECK(!ok); DNSPacketWriter dpw3(packet, DNSName("www.ds9a.nl."), QType::AAAA); // Let's make an invalid name by overwriting the length of the second label way outside the buffer packet[sizeof(dnsheader) + 4] = 0xff; - BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 4, 0xffU), 0xffU); + BOOST_CHECK_EQUAL(hashQuestion(packet.data(), packet.size() - 4, 0xffU, ok), 0xffU); + BOOST_CHECK(!ok); DNSPacketWriter dpw4(packet, DNSName("www.ds9a.nl."), QType::AAAA); - auto hash1 = hashQuestion(&packet[0], packet.size() - 4, 0); + auto hash1 = hashQuestion(&packet[0], packet.size() - 4, 0, ok); + BOOST_CHECK(ok); DNSPacketWriter dpw5(packet, DNSName("wWw.Ds9A.nL."), QType::AAAA); - auto hash2 = hashQuestion(&packet[0], packet.size() - 4, 0); + auto hash2 = hashQuestion(&packet[0], packet.size() - 4, 0, ok); BOOST_CHECK_EQUAL(hash1, hash2); + BOOST_CHECK(ok); vector counts(1500); for(unsigned int n = 0; n < 100000; ++n) { packet.clear(); DNSPacketWriter dpw(packet, DNSName(std::to_string(n) + "." + std::to_string(n*2) + "."), QType::AAAA); - counts[hashQuestion(&packet[0], packet.size() - 4, 0) % counts.size()]++; + BOOST_CHECK(ok); + counts[hashQuestion(&packet[0], packet.size() - 4, 0, ok) % counts.size()]++; } double sum = std::accumulate(std::begin(counts), std::end(counts), 0.0);