]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Implement SuffixMatchTree::getBestMatch() to get the name that matched
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 15 Jun 2022 12:15:41 +0000 (14:15 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 15 Jun 2022 13:35:56 +0000 (15:35 +0200)
pdns/dnsname.hh
pdns/test-dnsname_cc.cc

index d0b655703a4096ff22f9ff0f981048ebb089e7b6..0a1f24505129ecb4fcd1a58b5470d67da6496b52 100644 (file)
@@ -22,6 +22,7 @@
 #pragma once
 #include <array>
 #include <cstring>
+#include <optional>
 #include <string>
 #include <vector>
 #include <set>
@@ -294,6 +295,8 @@ inline DNSName operator+(const DNSName& lhs, const DNSName& rhs)
   return ret;
 }
 
+extern const DNSName g_rootdnsname, g_wildcarddnsname;
+
 template<typename T>
 struct SuffixMatchTree
 {
@@ -447,24 +450,60 @@ struct SuffixMatchTree
     child->remove(labels);
   }
 
-  T* lookup(const DNSName& name)  const
+  T* lookup(const DNSName& name) const
+  {
+    auto bestNode = getBestNode(name);
+    if (bestNode) {
+      return &bestNode->d_value;
+    }
+    return nullptr;
+  }
+
+  std::optional<DNSName> getBestMatch(const DNSName& name) const
+  {
+    if (children.empty()) { // speed up empty set
+      return endNode ? std::optional<DNSName>(g_rootdnsname) : std::nullopt;
+    }
+
+    auto visitor = name.getRawLabelsVisitor();
+    return getBestMatch(visitor);
+  }
+
+  // Returns all end-nodes, fully qualified (not as separate labels)
+  std::vector<DNSName> getNodes() const {
+    std::vector<DNSName> ret;
+    if (endNode) {
+      ret.push_back(DNSName(d_name));
+    }
+    for (const auto& child : children) {
+      auto nodes = child.getNodes();
+      ret.reserve(ret.size() + nodes.size());
+      for (const auto &node: nodes) {
+        ret.push_back(node + DNSName(d_name));
+      }
+    }
+    return ret;
+  }
+
+private:
+  const SuffixMatchTree* getBestNode(const DNSName& name)  const
   {
     if (children.empty()) { // speed up empty set
       if (endNode) {
-        return &d_value;
+        return this;
       }
       return nullptr;
     }
 
     auto visitor = name.getRawLabelsVisitor();
-    return lookup(visitor);
+    return getBestNode(visitor);
   }
 
-  T* lookup(DNSName::RawLabelsVisitor& visitor) const
+  const SuffixMatchTree* getBestNode(DNSName::RawLabelsVisitor& visitor) const
   {
     if (visitor.empty()) { // optimization
       if (endNode) {
-        return &d_value;
+        return this;
       }
       return nullptr;
     }
@@ -472,33 +511,45 @@ struct SuffixMatchTree
     const LightKey lk{visitor.back()};
     auto child = children.find(lk);
     if (child == children.end()) {
-      if(endNode) {
-        return &d_value;
+      if (endNode) {
+        return this;
       }
       return nullptr;
     }
     visitor.pop_back();
-    auto result = child->lookup(visitor);
+    auto result = child->getBestNode(visitor);
     if (result) {
       return result;
     }
-    return endNode ? &d_value : nullptr;
+    return endNode ? this : nullptr;
   }
 
-  // Returns all end-nodes, fully qualified (not as separate labels)
-  std::vector<DNSName> getNodes() const {
-    std::vector<DNSName> ret;
-    if (endNode) {
-      ret.push_back(DNSName(d_name));
+  std::optional<DNSName> getBestMatch(DNSName::RawLabelsVisitor& visitor) const
+  {
+    if (visitor.empty()) { // optimization
+      if (endNode) {
+        return std::optional<DNSName>(d_name);
+      }
+      return std::nullopt;
     }
-    for (const auto& child : children) {
-      auto nodes = child.getNodes();
-      ret.reserve(ret.size() + nodes.size());
-      for (const auto &node: nodes) {
-        ret.push_back(node + DNSName(d_name));
+
+    const LightKey lk{visitor.back()};
+    auto child = children.find(lk);
+    if (child == children.end()) {
+      if (endNode) {
+        return std::optional<DNSName>(d_name);
       }
+      return std::nullopt;
     }
-    return ret;
+    visitor.pop_back();
+    auto result = child->getBestMatch(visitor);
+    if (result) {
+      if (!d_name.empty()) {
+        result->appendRawLabel(d_name);
+      }
+      return result;
+    }
+    return endNode ? std::optional<DNSName>(d_name) : std::nullopt;
   }
 };
 
@@ -555,6 +606,11 @@ struct SuffixMatchNode
       return d_tree.lookup(dnsname) != nullptr;
     }
 
+    std::optional<DNSName> getBestMatch(const DNSName& name) const
+    {
+      return d_tree.getBestMatch(name);
+    }
+
     std::string toString() const
     {
       std::string ret;
@@ -599,8 +655,6 @@ bool DNSName::operator==(const DNSName& rhs) const
   return true;
 }
 
-extern const DNSName g_rootdnsname, g_wildcarddnsname;
-
 struct DNSNameSet: public std::unordered_set<DNSName> {
     std::string toString() const {
         std::ostringstream oss;
index a874649e2cfee752223d866f0ba27f8206fd1895..0001d1ee04f6fe4618f8ace6cd1317ff6de68fa6 100644 (file)
@@ -523,14 +523,19 @@ BOOST_AUTO_TEST_CASE(test_suffixmatch) {
 
   smn.add(DNSName("news.bbc.co.uk."));
   BOOST_CHECK(smn.check(DNSName("news.bbc.co.uk.")));
+  BOOST_CHECK(smn.getBestMatch(DNSName("news.bbc.co.uk")) == DNSName("news.bbc.co.uk."));
   BOOST_CHECK(smn.check(DNSName("www.news.bbc.co.uk.")));
+  BOOST_CHECK(smn.getBestMatch(DNSName("www.news.bbc.co.uk")) == DNSName("news.bbc.co.uk."));
   BOOST_CHECK(smn.check(DNSName("www.www.www.www.www.news.bbc.co.uk.")));
   BOOST_CHECK(!smn.check(DNSName("images.bbc.co.uk.")));
+  BOOST_CHECK(smn.getBestMatch(DNSName("images.bbc.co.uk")) == std::nullopt);
 
   BOOST_CHECK(!smn.check(DNSName("www.news.gov.uk.")));
+  BOOST_CHECK(smn.getBestMatch(DNSName("www.news.gov.uk")) == std::nullopt);
 
   smn.add(g_rootdnsname); // block the root
   BOOST_CHECK(smn.check(DNSName("a.root-servers.net.")));
+  BOOST_CHECK(smn.getBestMatch(DNSName("a.root-servers.net.")) == g_rootdnsname);
 
   DNSName examplenet("example.net.");
   DNSName net("net.");