]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Use the rules factory to unit test them
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 24 Dec 2024 15:59:24 +0000 (16:59 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 16 Jan 2025 08:50:30 +0000 (09:50 +0100)
pdns/dnsdistdist/test-dnsdistrules_cc.cc

index 8c91e3699e187ddbe9276c1b9390b0f2f6168253..692b50dbfb6e1014a32d511063cbfcafb7b3798b 100644 (file)
@@ -8,6 +8,7 @@
 #include <thread>
 #include <boost/test/unit_test.hpp>
 
+#include "dnsdist-rules.hh"
 #include "dnsdist-rules-factory.hh"
 
 void checkParameterBound(const std::string& parameter, uint64_t value, size_t max)
@@ -17,6 +18,74 @@ void checkParameterBound(const std::string& parameter, uint64_t value, size_t ma
   }
 }
 
+struct RuleParameter
+{
+  std::string name;
+  std::variant<unsigned int, std::string> value;
+};
+
+template <typename ParameterType>
+ParameterType getRequiredRuleParameter(const std::string& ruleName, std::vector<RuleParameter>& parameters, const std::string& parameterName)
+{
+  for (auto paramIt = parameters.begin(); paramIt != parameters.end(); ) {
+    if (paramIt->name != parameterName) {
+      ++paramIt;
+      continue;
+    }
+    auto value = std::get<ParameterType>(paramIt->value);
+    parameters.erase(paramIt);
+    return value;
+  }
+
+  throw std::runtime_error("Missing required parameter '" + parameterName + "' for selector '" + ruleName + "'");
+}
+
+template <typename ParameterType>
+ParameterType getOptionalRuleParameter(const std::string& ruleName, std::vector<RuleParameter>& parameters, const std::string& parameterName, ParameterType defaultValue)
+{
+  for (auto paramIt = parameters.begin(); paramIt != parameters.end(); ) {
+    if (paramIt->name != parameterName) {
+      ++paramIt;
+      continue;
+    }
+    auto value = std::get<ParameterType>(paramIt->value);
+    parameters.erase(paramIt);
+    return value;
+  }
+
+  return defaultValue;
+}
+
+class TestMaxQPSIPRule : public DNSRule
+{
+public:
+  TestMaxQPSIPRule(const std::string& ruleName, std::vector<RuleParameter>& parameters):
+    d_qps(getRequiredRuleParameter<unsigned int>(ruleName, parameters, "qps")),
+    d_burst(getOptionalRuleParameter<unsigned int>(ruleName, parameters, "burst", d_qps)),
+    d_ipv4trunc(getOptionalRuleParameter<unsigned int>(ruleName, parameters, "ipv4-truncation", 32))
+  {
+  }
+
+  bool matches(const DNSQuestion* dnsQuestion) const override
+  {
+    return true;
+  }
+
+  string toString() const override
+  {
+    return "";
+  }
+private:
+  unsigned int d_qps;
+  unsigned int d_burst;
+  unsigned int d_ipv4trunc;
+};
+
+static std::shared_ptr<DNSRule> buildSelector(const std::string& type, std::vector<RuleParameter>& parameters)
+{
+  return std::make_shared<TestMaxQPSIPRule>(type, parameters);
+}
+
 static DNSQuestion getDQ(const DNSName* providedName = nullptr)
 {
   static const DNSName qname("powerdns.com.");
@@ -42,7 +111,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   unsigned int expiration = 300;
   unsigned int cleanupDelay = 60;
   unsigned int scanFraction = 10;
-  MaxQPSIPRule rule(maxQPS, maxBurst, 32, 64, expiration, cleanupDelay, scanFraction);
+  auto rule = dnsdist::selectors::getMaxQPSIPSelector(maxQPS, 32, 64, maxBurst, expiration, cleanupDelay, scanFraction, 1);
 
   InternalQueryState ids;
   ids.qname = DNSName("powerdns.com.");
@@ -62,35 +131,35 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
   for (size_t idx = 0; idx < maxQPS; idx++) {
     /* let's use different source ports, it shouldn't matter */
     ids.origRemote = ComboAddress("192.0.2.1:" + std::to_string(idx));
-    BOOST_CHECK_EQUAL(rule.matches(&dq), false);
-    BOOST_CHECK_EQUAL(rule.getEntriesCount(), 1U);
+    BOOST_CHECK_EQUAL(rule->matches(&dq), false);
+    BOOST_CHECK_EQUAL(rule->getEntriesCount(), 1U);
   }
 
   /* maxQPS + 1, we should be blocked */
-  BOOST_CHECK_EQUAL(rule.matches(&dq), true);
-  BOOST_CHECK_EQUAL(rule.getEntriesCount(), 1U);
+  BOOST_CHECK_EQUAL(rule->matches(&dq), true);
+  BOOST_CHECK_EQUAL(rule->getEntriesCount(), 1U);
 
   /* remove all entries that have not been updated since 'now' + 1,
      so all of them */
   expiredTime.tv_sec += 1;
-  rule.cleanup(expiredTime);
+  rule->cleanup(expiredTime);
 
   /* we should have been cleaned up */
-  BOOST_CHECK_EQUAL(rule.getEntriesCount(), 0U);
+  BOOST_CHECK_EQUAL(rule->getEntriesCount(), 0U);
 
   struct timespec beginInsertionTime;
   gettime(&beginInsertionTime);
   /* we should not be blocked anymore */
-  BOOST_CHECK_EQUAL(rule.matches(&dq), false);
+  BOOST_CHECK_EQUAL(rule->matches(&dq), false);
   /* and we be back */
-  BOOST_CHECK_EQUAL(rule.getEntriesCount(), 1U);
+  BOOST_CHECK_EQUAL(rule->getEntriesCount(), 1U);
 
 
   /* Let's insert a lot of different sources now */
   for (size_t idxByte3 = 0; idxByte3 < 256; idxByte3++) {
     for (size_t idxByte4 = 0; idxByte4 < 256; idxByte4++) {
       ids.origRemote = ComboAddress("10.0." + std::to_string(idxByte3) + "." + std::to_string(idxByte4));
-      BOOST_CHECK_EQUAL(rule.matches(&dq), false);
+      BOOST_CHECK_EQUAL(rule->matches(&dq), false);
     }
   }
   struct timespec endInsertionTime;
@@ -98,32 +167,32 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) {
 
   /* don't forget the existing entry */
   size_t total = 1 + 256 * 256;
-  BOOST_CHECK_EQUAL(rule.getEntriesCount(), total);
+  BOOST_CHECK_EQUAL(rule->getEntriesCount(), total);
 
   /* make sure all entries are still valid */
   struct timespec notExpiredTime = beginInsertionTime;
   notExpiredTime.tv_sec -= 1;
 
   size_t scanned = 0;
-  auto removed = rule.cleanup(notExpiredTime, &scanned);
+  auto removed = rule->cleanup(notExpiredTime, &scanned);
   BOOST_CHECK_EQUAL(removed, 0U);
   /* the first entry should still have been valid, we should not have scanned more */
-  BOOST_CHECK_EQUAL(scanned, rule.getNumberOfShards());
-  BOOST_CHECK_EQUAL(rule.getEntriesCount(), total);
+  BOOST_CHECK_EQUAL(scanned, rule->getNumberOfShards());
+  BOOST_CHECK_EQUAL(rule->getEntriesCount(), total);
 
   /* make sure all entries are _not_ valid anymore */
   expiredTime = endInsertionTime;
   expiredTime.tv_sec += 1;
 
-  removed = rule.cleanup(expiredTime, &scanned);
-  BOOST_CHECK_EQUAL(removed, (total / scanFraction) + 1 + rule.getNumberOfShards());
+  removed = rule->cleanup(expiredTime, &scanned);
+  BOOST_CHECK_EQUAL(removed, (total / scanFraction) + 1 + rule->getNumberOfShards());
   /* we should not have scanned more than scanFraction */
   BOOST_CHECK_EQUAL(scanned, removed);
-  BOOST_CHECK_EQUAL(rule.getEntriesCount(), total - removed);
+  BOOST_CHECK_EQUAL(rule->getEntriesCount(), total - removed);
 
-  rule.clear();
-  BOOST_CHECK_EQUAL(rule.getEntriesCount(), 0U);
-  removed = rule.cleanup(expiredTime, &scanned);
+  rule->clear();
+  BOOST_CHECK_EQUAL(rule->getEntriesCount(), 0U);
+  removed = rule->cleanup(expiredTime, &scanned);
   BOOST_CHECK_EQUAL(removed, 0U);
   BOOST_CHECK_EQUAL(scanned, 0U);
 }
@@ -223,6 +292,12 @@ BOOST_AUTO_TEST_CASE(test_payloadSizeRule) {
   }
 
   BOOST_CHECK_THROW(PayloadSizeRule("invalid", 42U), std::runtime_error);
+
+  std::vector<RuleParameter> parameters{
+    RuleParameter{ "qps", 5U },
+    RuleParameter{ "ipv4-truncation", 24U },
+  };
+  auto got = buildSelector("TestMaxQPSIPRule", parameters);
 }
 
 BOOST_AUTO_TEST_SUITE_END()