From: Remi Gacogne Date: Wed, 15 Jun 2022 12:15:41 +0000 (+0200) Subject: Implement SuffixMatchTree::getBestMatch() to get the name that matched X-Git-Tag: auth-4.8.0-alpha0~47^2~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=e11f6994d8b2e9715c0a002137e11aba944baf7a;p=thirdparty%2Fpdns.git Implement SuffixMatchTree::getBestMatch() to get the name that matched --- diff --git a/pdns/dnsname.hh b/pdns/dnsname.hh index d0b655703a..0a1f245051 100644 --- a/pdns/dnsname.hh +++ b/pdns/dnsname.hh @@ -22,6 +22,7 @@ #pragma once #include #include +#include #include #include #include @@ -294,6 +295,8 @@ inline DNSName operator+(const DNSName& lhs, const DNSName& rhs) return ret; } +extern const DNSName g_rootdnsname, g_wildcarddnsname; + template struct SuffixMatchTree { @@ -447,24 +450,60 @@ struct SuffixMatchTree child->remove(labels); } - T* lookup(const DNSName& name) const + T* lookup(const DNSName& name) const + { + auto bestNode = getBestNode(name); + if (bestNode) { + return &bestNode->d_value; + } + return nullptr; + } + + std::optional getBestMatch(const DNSName& name) const + { + if (children.empty()) { // speed up empty set + return endNode ? std::optional(g_rootdnsname) : std::nullopt; + } + + auto visitor = name.getRawLabelsVisitor(); + return getBestMatch(visitor); + } + + // Returns all end-nodes, fully qualified (not as separate labels) + std::vector getNodes() const { + std::vector ret; + if (endNode) { + ret.push_back(DNSName(d_name)); + } + for (const auto& child : children) { + auto nodes = child.getNodes(); + ret.reserve(ret.size() + nodes.size()); + for (const auto &node: nodes) { + ret.push_back(node + DNSName(d_name)); + } + } + return ret; + } + +private: + const SuffixMatchTree* getBestNode(const DNSName& name) const { if (children.empty()) { // speed up empty set if (endNode) { - return &d_value; + return this; } return nullptr; } auto visitor = name.getRawLabelsVisitor(); - return lookup(visitor); + return getBestNode(visitor); } - T* lookup(DNSName::RawLabelsVisitor& visitor) const + const SuffixMatchTree* getBestNode(DNSName::RawLabelsVisitor& visitor) const { if (visitor.empty()) { // optimization if (endNode) { - return &d_value; + return this; } return nullptr; } @@ -472,33 +511,45 @@ struct SuffixMatchTree const LightKey lk{visitor.back()}; auto child = children.find(lk); if (child == children.end()) { - if(endNode) { - return &d_value; + if (endNode) { + return this; } return nullptr; } visitor.pop_back(); - auto result = child->lookup(visitor); + auto result = child->getBestNode(visitor); if (result) { return result; } - return endNode ? &d_value : nullptr; + return endNode ? this : nullptr; } - // Returns all end-nodes, fully qualified (not as separate labels) - std::vector getNodes() const { - std::vector ret; - if (endNode) { - ret.push_back(DNSName(d_name)); + std::optional getBestMatch(DNSName::RawLabelsVisitor& visitor) const + { + if (visitor.empty()) { // optimization + if (endNode) { + return std::optional(d_name); + } + return std::nullopt; } - for (const auto& child : children) { - auto nodes = child.getNodes(); - ret.reserve(ret.size() + nodes.size()); - for (const auto &node: nodes) { - ret.push_back(node + DNSName(d_name)); + + const LightKey lk{visitor.back()}; + auto child = children.find(lk); + if (child == children.end()) { + if (endNode) { + return std::optional(d_name); } + return std::nullopt; } - return ret; + visitor.pop_back(); + auto result = child->getBestMatch(visitor); + if (result) { + if (!d_name.empty()) { + result->appendRawLabel(d_name); + } + return result; + } + return endNode ? std::optional(d_name) : std::nullopt; } }; @@ -555,6 +606,11 @@ struct SuffixMatchNode return d_tree.lookup(dnsname) != nullptr; } + std::optional getBestMatch(const DNSName& name) const + { + return d_tree.getBestMatch(name); + } + std::string toString() const { std::string ret; @@ -599,8 +655,6 @@ bool DNSName::operator==(const DNSName& rhs) const return true; } -extern const DNSName g_rootdnsname, g_wildcarddnsname; - struct DNSNameSet: public std::unordered_set { std::string toString() const { std::ostringstream oss; diff --git a/pdns/test-dnsname_cc.cc b/pdns/test-dnsname_cc.cc index a874649e2c..0001d1ee04 100644 --- a/pdns/test-dnsname_cc.cc +++ b/pdns/test-dnsname_cc.cc @@ -523,14 +523,19 @@ BOOST_AUTO_TEST_CASE(test_suffixmatch) { smn.add(DNSName("news.bbc.co.uk.")); BOOST_CHECK(smn.check(DNSName("news.bbc.co.uk."))); + BOOST_CHECK(smn.getBestMatch(DNSName("news.bbc.co.uk")) == DNSName("news.bbc.co.uk.")); BOOST_CHECK(smn.check(DNSName("www.news.bbc.co.uk."))); + BOOST_CHECK(smn.getBestMatch(DNSName("www.news.bbc.co.uk")) == DNSName("news.bbc.co.uk.")); BOOST_CHECK(smn.check(DNSName("www.www.www.www.www.news.bbc.co.uk."))); BOOST_CHECK(!smn.check(DNSName("images.bbc.co.uk."))); + BOOST_CHECK(smn.getBestMatch(DNSName("images.bbc.co.uk")) == std::nullopt); BOOST_CHECK(!smn.check(DNSName("www.news.gov.uk."))); + BOOST_CHECK(smn.getBestMatch(DNSName("www.news.gov.uk")) == std::nullopt); smn.add(g_rootdnsname); // block the root BOOST_CHECK(smn.check(DNSName("a.root-servers.net."))); + BOOST_CHECK(smn.getBestMatch(DNSName("a.root-servers.net.")) == g_rootdnsname); DNSName examplenet("example.net."); DNSName net("net.");