From: Remi Gacogne Date: Tue, 24 Dec 2024 15:59:24 +0000 (+0100) Subject: dnsdist: Use the rules factory to unit test them X-Git-Tag: dnsdist-2.0.0-alpha1~160^2~28 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=64e3851e213c23f9f1a870d8272f22e17fe98d67;p=thirdparty%2Fpdns.git dnsdist: Use the rules factory to unit test them --- diff --git a/pdns/dnsdistdist/test-dnsdistrules_cc.cc b/pdns/dnsdistdist/test-dnsdistrules_cc.cc index 8c91e3699e..692b50dbfb 100644 --- a/pdns/dnsdistdist/test-dnsdistrules_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistrules_cc.cc @@ -8,6 +8,7 @@ #include #include +#include "dnsdist-rules.hh" #include "dnsdist-rules-factory.hh" void checkParameterBound(const std::string& parameter, uint64_t value, size_t max) @@ -17,6 +18,74 @@ void checkParameterBound(const std::string& parameter, uint64_t value, size_t ma } } +struct RuleParameter +{ + std::string name; + std::variant value; +}; + +template +ParameterType getRequiredRuleParameter(const std::string& ruleName, std::vector& parameters, const std::string& parameterName) +{ + for (auto paramIt = parameters.begin(); paramIt != parameters.end(); ) { + if (paramIt->name != parameterName) { + ++paramIt; + continue; + } + auto value = std::get(paramIt->value); + parameters.erase(paramIt); + return value; + } + + throw std::runtime_error("Missing required parameter '" + parameterName + "' for selector '" + ruleName + "'"); +} + +template +ParameterType getOptionalRuleParameter(const std::string& ruleName, std::vector& parameters, const std::string& parameterName, ParameterType defaultValue) +{ + for (auto paramIt = parameters.begin(); paramIt != parameters.end(); ) { + if (paramIt->name != parameterName) { + ++paramIt; + continue; + } + auto value = std::get(paramIt->value); + parameters.erase(paramIt); + return value; + } + + return defaultValue; +} + +class TestMaxQPSIPRule : public DNSRule +{ +public: + TestMaxQPSIPRule(const std::string& ruleName, std::vector& parameters): + d_qps(getRequiredRuleParameter(ruleName, parameters, "qps")), + d_burst(getOptionalRuleParameter(ruleName, parameters, "burst", d_qps)), + d_ipv4trunc(getOptionalRuleParameter(ruleName, parameters, "ipv4-truncation", 32)) + { + } + + bool matches(const DNSQuestion* dnsQuestion) const override + { + return true; + } + + string toString() const override + { + return ""; + } +private: + unsigned int d_qps; + unsigned int d_burst; + unsigned int d_ipv4trunc; +}; + +static std::shared_ptr buildSelector(const std::string& type, std::vector& parameters) +{ + return std::make_shared(type, parameters); +} + static DNSQuestion getDQ(const DNSName* providedName = nullptr) { static const DNSName qname("powerdns.com."); @@ -42,7 +111,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) { unsigned int expiration = 300; unsigned int cleanupDelay = 60; unsigned int scanFraction = 10; - MaxQPSIPRule rule(maxQPS, maxBurst, 32, 64, expiration, cleanupDelay, scanFraction); + auto rule = dnsdist::selectors::getMaxQPSIPSelector(maxQPS, 32, 64, maxBurst, expiration, cleanupDelay, scanFraction, 1); InternalQueryState ids; ids.qname = DNSName("powerdns.com."); @@ -62,35 +131,35 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) { for (size_t idx = 0; idx < maxQPS; idx++) { /* let's use different source ports, it shouldn't matter */ ids.origRemote = ComboAddress("192.0.2.1:" + std::to_string(idx)); - BOOST_CHECK_EQUAL(rule.matches(&dq), false); - BOOST_CHECK_EQUAL(rule.getEntriesCount(), 1U); + BOOST_CHECK_EQUAL(rule->matches(&dq), false); + BOOST_CHECK_EQUAL(rule->getEntriesCount(), 1U); } /* maxQPS + 1, we should be blocked */ - BOOST_CHECK_EQUAL(rule.matches(&dq), true); - BOOST_CHECK_EQUAL(rule.getEntriesCount(), 1U); + BOOST_CHECK_EQUAL(rule->matches(&dq), true); + BOOST_CHECK_EQUAL(rule->getEntriesCount(), 1U); /* remove all entries that have not been updated since 'now' + 1, so all of them */ expiredTime.tv_sec += 1; - rule.cleanup(expiredTime); + rule->cleanup(expiredTime); /* we should have been cleaned up */ - BOOST_CHECK_EQUAL(rule.getEntriesCount(), 0U); + BOOST_CHECK_EQUAL(rule->getEntriesCount(), 0U); struct timespec beginInsertionTime; gettime(&beginInsertionTime); /* we should not be blocked anymore */ - BOOST_CHECK_EQUAL(rule.matches(&dq), false); + BOOST_CHECK_EQUAL(rule->matches(&dq), false); /* and we be back */ - BOOST_CHECK_EQUAL(rule.getEntriesCount(), 1U); + BOOST_CHECK_EQUAL(rule->getEntriesCount(), 1U); /* Let's insert a lot of different sources now */ for (size_t idxByte3 = 0; idxByte3 < 256; idxByte3++) { for (size_t idxByte4 = 0; idxByte4 < 256; idxByte4++) { ids.origRemote = ComboAddress("10.0." + std::to_string(idxByte3) + "." + std::to_string(idxByte4)); - BOOST_CHECK_EQUAL(rule.matches(&dq), false); + BOOST_CHECK_EQUAL(rule->matches(&dq), false); } } struct timespec endInsertionTime; @@ -98,32 +167,32 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) { /* don't forget the existing entry */ size_t total = 1 + 256 * 256; - BOOST_CHECK_EQUAL(rule.getEntriesCount(), total); + BOOST_CHECK_EQUAL(rule->getEntriesCount(), total); /* make sure all entries are still valid */ struct timespec notExpiredTime = beginInsertionTime; notExpiredTime.tv_sec -= 1; size_t scanned = 0; - auto removed = rule.cleanup(notExpiredTime, &scanned); + auto removed = rule->cleanup(notExpiredTime, &scanned); BOOST_CHECK_EQUAL(removed, 0U); /* the first entry should still have been valid, we should not have scanned more */ - BOOST_CHECK_EQUAL(scanned, rule.getNumberOfShards()); - BOOST_CHECK_EQUAL(rule.getEntriesCount(), total); + BOOST_CHECK_EQUAL(scanned, rule->getNumberOfShards()); + BOOST_CHECK_EQUAL(rule->getEntriesCount(), total); /* make sure all entries are _not_ valid anymore */ expiredTime = endInsertionTime; expiredTime.tv_sec += 1; - removed = rule.cleanup(expiredTime, &scanned); - BOOST_CHECK_EQUAL(removed, (total / scanFraction) + 1 + rule.getNumberOfShards()); + removed = rule->cleanup(expiredTime, &scanned); + BOOST_CHECK_EQUAL(removed, (total / scanFraction) + 1 + rule->getNumberOfShards()); /* we should not have scanned more than scanFraction */ BOOST_CHECK_EQUAL(scanned, removed); - BOOST_CHECK_EQUAL(rule.getEntriesCount(), total - removed); + BOOST_CHECK_EQUAL(rule->getEntriesCount(), total - removed); - rule.clear(); - BOOST_CHECK_EQUAL(rule.getEntriesCount(), 0U); - removed = rule.cleanup(expiredTime, &scanned); + rule->clear(); + BOOST_CHECK_EQUAL(rule->getEntriesCount(), 0U); + removed = rule->cleanup(expiredTime, &scanned); BOOST_CHECK_EQUAL(removed, 0U); BOOST_CHECK_EQUAL(scanned, 0U); } @@ -223,6 +292,12 @@ BOOST_AUTO_TEST_CASE(test_payloadSizeRule) { } BOOST_CHECK_THROW(PayloadSizeRule("invalid", 42U), std::runtime_error); + + std::vector parameters{ + RuleParameter{ "qps", 5U }, + RuleParameter{ "ipv4-truncation", 24U }, + }; + auto got = buildSelector("TestMaxQPSIPRule", parameters); } BOOST_AUTO_TEST_SUITE_END()