From: Remi Gacogne Date: Fri, 25 Jul 2025 14:12:39 +0000 (+0200) Subject: dnsdist: Speed up response content matching X-Git-Tag: dnsdist-2.0.1~15^2~5 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0f5831d0348ef3599a09c7894a4005c01e7bc05f;p=thirdparty%2Fpdns.git dnsdist: Speed up response content matching This commit introduces a new method to compare a `DNSName` against a view of raw, wire-format bytes, skipping the allocation and copy that is usually required to get a second `DNSName` object to compare against. This signifitcantly reduces the amount of time matching a DNS response received from a backend against the content we expect to find. Signed-off-by: Remi Gacogne (cherry picked from commit 67eb73850f3141c44963d95ef815fe6a0586d2a8) Signed-off-by: Remi Gacogne --- diff --git a/pdns/dnsdistdist/dnsdist.cc b/pdns/dnsdistdist/dnsdist.cc index 56ced819b..f1a3e6dc5 100644 --- a/pdns/dnsdistdist/dnsdist.cc +++ b/pdns/dnsdistdist/dnsdist.cc @@ -285,12 +285,22 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, return false; } - uint16_t rqtype{}; - uint16_t rqclass{}; - DNSName rqname; try { + uint16_t rqtype{}; + uint16_t rqclass{}; + if (response.size() < (sizeof(dnsheader) + qname.wirelength() + sizeof(rqtype) + sizeof(rqclass))) { + return false; + } + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - rqname = DNSName(reinterpret_cast(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass); + const std::string_view packetView(reinterpret_cast(response.data() + sizeof(dnsheader)), response.size() - sizeof(dnsheader)); + if (qname.matches(packetView)) { + size_t pos = sizeof(dnsheader) + qname.wirelength(); + rqtype = response.at(pos) * 256 + response.at(pos + 1); + rqclass = response.at(pos + 2) * 256 + response.at(pos + 3); + return rqtype == qtype && rqclass == qclass; + } + return false; } catch (const std::exception& e) { if (remote && !response.empty() && static_cast(response.size()) > sizeof(dnsheader)) { @@ -302,8 +312,6 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, } return false; } - - return rqtype == qtype && rqclass == qclass && rqname == qname; } static void restoreFlags(struct dnsheader* dnsHeader, uint16_t origFlags) diff --git a/pdns/dnsname.cc b/pdns/dnsname.cc index 13c5cdb19..5eb6dd8bf 100644 --- a/pdns/dnsname.cc +++ b/pdns/dnsname.cc @@ -732,6 +732,23 @@ bool DNSName::RawLabelsVisitor::empty() const return d_position == 0; } +bool DNSName::matches(const std::string_view& wire_uncompressed) const +{ + if (wire_uncompressed.empty() != empty() || wire_uncompressed.size() < d_storage.size()) { + return false; + } + + const auto* us = d_storage.cbegin(); + const auto* p = wire_uncompressed.cbegin(); + for (; us != d_storage.cend() && p != wire_uncompressed.cend(); ++us, ++p) { + if (dns_tolower(*p) != dns_tolower(*us)) { + return false; + } + } + + return us == d_storage.cend(); +} + #if defined(PDNS_AUTH) // [ std::ostream & operator<<(std::ostream &ostr, const ZoneName& zone) { diff --git a/pdns/dnsname.hh b/pdns/dnsname.hh index 630e9be4d..8d7a9571e 100644 --- a/pdns/dnsname.hh +++ b/pdns/dnsname.hh @@ -112,6 +112,7 @@ public: bool isPartOf(const DNSName& rhs) const; //!< Are we part of the rhs name? Note that name.isPartOf(name). inline bool operator==(const DNSName& rhs) const; //!< DNS-native comparison (case insensitive) - empty compares to empty bool operator!=(const DNSName& other) const { return !(*this == other); } + bool matches(const std::string_view& wire_uncompressed) const; // DNS-native (case insensitive) comparison against raw data in wire format std::string toString(const std::string& separator=".", const bool trailing=true) const; //!< Our human-friendly, escaped, representation void toString(std::string& output, const std::string& separator=".", const bool trailing=true) const; diff --git a/pdns/test-dnsname_cc.cc b/pdns/test-dnsname_cc.cc index a478e59ba..3e7cab573 100644 --- a/pdns/test-dnsname_cc.cc +++ b/pdns/test-dnsname_cc.cc @@ -1032,6 +1032,33 @@ BOOST_AUTO_TEST_CASE(test_getcommonlabels) { BOOST_CHECK_EQUAL(name5.getCommonLabels(name1), DNSName()); } +BOOST_AUTO_TEST_CASE(test_raw_data_comparison) { + const DNSName aroot("a.root-servers.net"); + PacketBuffer query; + GenericDNSPacketWriter packetWriter(query, aroot, QType::A, QClass::IN, 0); + + { + const std::string_view raw(reinterpret_cast(query.data()) + sizeof(dnsheader), query.size() - sizeof(dnsheader)); + BOOST_CHECK(aroot.matches(raw)); + + DNSName differentCase("A.RooT-Servers.NET"); + BOOST_CHECK(differentCase.matches(raw)); + + const DNSName broot("b.root-servers.net"); + BOOST_CHECK(!(broot.matches(raw))); + + /* last character differs */ + const DNSName notaroot("a.root-servers.nes"); + BOOST_CHECK(!(notaroot.matches(raw))); + } + + { + /* too short */ + const std::string_view raw(reinterpret_cast(query.data() + sizeof(dnsheader)), aroot.wirelength() - 1); + BOOST_CHECK(!(aroot.matches(raw))); + } +} + #if defined(PDNS_AUTH) BOOST_AUTO_TEST_CASE(test_variantnames) { ZoneName zone1("..variant");