]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Expose selectors and actions to YAML-originated Lua contexts
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 4 Nov 2025 10:07:46 +0000 (11:07 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 13 Nov 2025 09:42:18 +0000 (10:42 +0100)
Making it possible to use selectors and actions in Lua code
declared in the YAML configuration (inline or not).

Signed-off-by: Remi Gacogne <remi.gacogne@powerdns.com>
pdns/dnsdistdist/dnsdist-lua-actions.cc
pdns/dnsdistdist/dnsdist-lua-rules.cc
pdns/dnsdistdist/dnsdist-lua.cc
pdns/dnsdistdist/dnsdist-lua.hh

index 506ad0bf9bbd7d87be213e13b3e0c286e9b0a290..3a15cd2fe24d8578c096d8db77d44deeb248a4a4 100644 (file)
 #include "remote_logger.hh"
 #include <stdexcept>
 
-template <typename ActionT, typename IdentifierT>
-static void addAction(IdentifierT identifier, const luadnsrule_t& var, const std::shared_ptr<ActionT>& action, boost::optional<luaruleparams_t>& params)
-{
-  setLuaSideEffect();
-
-  std::string name;
-  boost::uuids::uuid uuid{};
-  uint64_t creationOrder = 0;
-  parseRuleParams(params, uuid, name, creationOrder);
-  checkAllParametersConsumed("addAction", params);
-
-  auto rule = makeRule(var, "addAction");
-  dnsdist::configuration::updateRuntimeConfiguration([identifier, &rule, &action, &name, &uuid, creationOrder](dnsdist::configuration::RuntimeConfiguration& config) {
-    dnsdist::rules::add(config.d_ruleChains, identifier, std::move(rule), action, std::move(name), uuid, creationOrder);
-  });
-}
-
 using responseParams_t = std::unordered_map<std::string, boost::variant<bool, uint32_t>>;
 
 static dnsdist::ResponseConfig parseResponseConfig(boost::optional<responseParams_t>& vars)
@@ -92,39 +75,6 @@ void setupLuaActions(LuaContext& luaCtx)
     return std::make_shared<dnsdist::rules::RuleAction>(ruleaction);
   });
 
-  for (const auto& chain : dnsdist::rules::getRuleChainDescriptions()) {
-    auto fullName = std::string("add") + chain.prefix + std::string("Action");
-    luaCtx.writeFunction(fullName, [&fullName, &chain](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) {
-      if (era.type() != typeid(std::shared_ptr<DNSAction>)) {
-        throw std::runtime_error(fullName + "() can only be called with query-related actions, not response-related ones. Are you looking for addResponseAction()?");
-      }
-
-      addAction(chain.identifier, var, boost::get<std::shared_ptr<DNSAction>>(era), params);
-    });
-    fullName = std::string("get") + chain.prefix + std::string("Action");
-    luaCtx.writeFunction(fullName, [&chain](unsigned int num) {
-      setLuaNoSideEffect();
-      boost::optional<std::shared_ptr<DNSAction>> ret;
-      const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains;
-      const auto& ruleactions = dnsdist::rules::getRuleChain(chains, chain.identifier);
-      if (num < ruleactions.size()) {
-        ret = ruleactions[num].d_action;
-      }
-      return ret;
-    });
-  }
-
-  for (const auto& chain : dnsdist::rules::getResponseRuleChainDescriptions()) {
-    const auto fullName = std::string("add") + chain.prefix + std::string("ResponseAction");
-    luaCtx.writeFunction(fullName, [&fullName, &chain](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) {
-      if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) {
-        throw std::runtime_error(fullName + "() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?");
-      }
-
-      addAction(chain.identifier, var, boost::get<std::shared_ptr<DNSResponseAction>>(era), params);
-    });
-  }
-
   luaCtx.registerFunction<void (DNSAction::*)() const>("printStats", [](const DNSAction& action) {
     setLuaNoSideEffect();
     auto stats = action.getStats();
index a12eed23ea74fe6d509a1f74afeb9f94a47e02ab..0643311575a7d068667ce7e2c8f1bb199cd34482 100644 (file)
@@ -276,6 +276,23 @@ static std::vector<T> getTopRules(const std::vector<T>& rules, unsigned int top)
   return results;
 }
 
+template <typename ActionT, typename IdentifierT>
+static void addRule(IdentifierT identifier, const std::string& methodName, const luadnsrule_t& var, const std::shared_ptr<ActionT>& action, boost::optional<luaruleparams_t>& params)
+{
+  setLuaSideEffect();
+
+  std::string name;
+  boost::uuids::uuid uuid{};
+  uint64_t creationOrder = 0;
+  parseRuleParams(params, uuid, name, creationOrder);
+  checkAllParametersConsumed(methodName, params);
+
+  auto rule = makeRule(var, methodName);
+  dnsdist::configuration::updateRuntimeConfiguration([identifier, &rule, &action, &name, &uuid, creationOrder](dnsdist::configuration::RuntimeConfiguration& config) {
+    dnsdist::rules::add(config.d_ruleChains, identifier, std::move(rule), action, std::move(name), uuid, creationOrder);
+  });
+}
+
 template <typename T>
 static LuaArray<T> toLuaArray(std::vector<T>&& rules)
 {
@@ -350,25 +367,8 @@ std::optional<T> boostToStandardOptional(const boost::optional<T>& boostOpt)
 }
 }
 
-// NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
-void setupLuaRules(LuaContext& luaCtx)
+void setupLuaRuleChainsManagement(LuaContext& luaCtx)
 {
-  luaCtx.writeFunction("makeRule", [](const luadnsrule_t& var) -> std::shared_ptr<DNSRule> {
-    return makeRule(var, "makeRule");
-  });
-
-  luaCtx.registerFunction<string (std::shared_ptr<DNSRule>::*)() const>("toString", [](const std::shared_ptr<DNSRule>& rule) { return rule->toString(); });
-
-  luaCtx.registerFunction<uint64_t (std::shared_ptr<DNSRule>::*)() const>("getMatches", [](const std::shared_ptr<DNSRule>& rule) { return rule->d_matches.load(); });
-
-  luaCtx.registerFunction<std::shared_ptr<DNSRule> (dnsdist::rules::RuleAction::*)() const>("getSelector", [](const dnsdist::rules::RuleAction& rule) { return rule.d_rule; });
-
-  luaCtx.registerFunction<std::shared_ptr<DNSAction> (dnsdist::rules::RuleAction::*)() const>("getAction", [](const dnsdist::rules::RuleAction& rule) { return rule.d_action; });
-
-  luaCtx.registerFunction<std::shared_ptr<DNSRule> (dnsdist::rules::ResponseRuleAction::*)() const>("getSelector", [](const dnsdist::rules::ResponseRuleAction& rule) { return rule.d_rule; });
-
-  luaCtx.registerFunction<std::shared_ptr<DNSResponseAction> (dnsdist::rules::ResponseRuleAction::*)() const>("getAction", [](const dnsdist::rules::ResponseRuleAction& rule) { return rule.d_action; });
-
   for (const auto& chain : dnsdist::rules::getResponseRuleChainDescriptions()) {
     luaCtx.writeFunction("show" + chain.prefix + "ResponseRules", [&chain](boost::optional<ruleparams_t> vars) {
       showRules(chain.identifier, vars);
@@ -469,26 +469,38 @@ void setupLuaRules(LuaContext& luaCtx)
     });
   }
 
-  luaCtx.writeFunction("SuffixMatchNodeRule", qnameSuffixRule);
+  for (const auto& chain : dnsdist::rules::getRuleChainDescriptions()) {
+    auto fullName = std::string("add") + chain.prefix + std::string("Action");
+    luaCtx.writeFunction(fullName, [&fullName, &chain](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) {
+      if (era.type() != typeid(std::shared_ptr<DNSAction>)) {
+        throw std::runtime_error(fullName + "() can only be called with query-related actions, not response-related ones. Are you looking for addResponseAction()?");
+      }
 
-  luaCtx.writeFunction("NetmaskGroupRule", [](const boost::variant<const NetmaskGroup&, std::string, const LuaArray<std::string>> netmasks, boost::optional<bool> src, boost::optional<bool> quiet) {
-    if (netmasks.type() == typeid(string)) {
-      NetmaskGroup nmg;
-      nmg.addMask(*boost::get<std::string>(&netmasks));
-      return std::shared_ptr<DNSRule>(new NetmaskGroupRule(nmg, src ? *src : true, quiet ? *quiet : false));
-    }
+      addRule(chain.identifier, fullName, var, boost::get<std::shared_ptr<DNSAction>>(era), params);
+    });
+    fullName = std::string("get") + chain.prefix + std::string("Action");
+    luaCtx.writeFunction(fullName, [&chain](unsigned int num) {
+      setLuaNoSideEffect();
+      boost::optional<std::shared_ptr<DNSAction>> ret;
+      const auto& chains = dnsdist::configuration::getCurrentRuntimeConfiguration().d_ruleChains;
+      const auto& ruleactions = dnsdist::rules::getRuleChain(chains, chain.identifier);
+      if (num < ruleactions.size()) {
+        ret = ruleactions[num].d_action;
+      }
+      return ret;
+    });
+  }
 
-    if (netmasks.type() == typeid(LuaArray<std::string>)) {
-      NetmaskGroup nmg;
-      for (const auto& str : *boost::get<const LuaArray<std::string>>(&netmasks)) {
-        nmg.addMask(str.second);
+  for (const auto& chain : dnsdist::rules::getResponseRuleChainDescriptions()) {
+    const auto fullName = std::string("add") + chain.prefix + std::string("ResponseAction");
+    luaCtx.writeFunction(fullName, [&fullName, &chain](const luadnsrule_t& var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) {
+      if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) {
+        throw std::runtime_error(fullName + "() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?");
       }
-      return std::shared_ptr<DNSRule>(new NetmaskGroupRule(nmg, src ? *src : true, quiet ? *quiet : false));
-    }
 
-    const auto& nmg = *boost::get<const NetmaskGroup&>(&netmasks);
-    return std::shared_ptr<DNSRule>(new NetmaskGroupRule(nmg, src ? *src : true, quiet ? *quiet : false));
-  });
+      addRule(chain.identifier, fullName, var, boost::get<std::shared_ptr<DNSResponseAction>>(era), params);
+    });
+  }
 
   luaCtx.writeFunction("benchRule", [](const std::shared_ptr<DNSRule>& rule, boost::optional<unsigned int> times_, boost::optional<string> suffix_) {
     setLuaNoSideEffect();
@@ -531,6 +543,47 @@ void setupLuaRules(LuaContext& luaCtx)
     double udiff = swatch.udiff();
     g_outputBuffer = (boost::format("Had %d matches out of %d, %.1f qps, in %.1f us\n") % matches % times % (1000000 * (1.0 * times / udiff)) % udiff).str();
   });
+}
+
+// NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
+void setupLuaSelectors(LuaContext& luaCtx)
+{
+  luaCtx.writeFunction("makeRule", [](const luadnsrule_t& var) -> std::shared_ptr<DNSRule> {
+    return makeRule(var, "makeRule");
+  });
+
+  luaCtx.registerFunction<string (std::shared_ptr<DNSRule>::*)() const>("toString", [](const std::shared_ptr<DNSRule>& rule) { return rule->toString(); });
+
+  luaCtx.registerFunction<uint64_t (std::shared_ptr<DNSRule>::*)() const>("getMatches", [](const std::shared_ptr<DNSRule>& rule) { return rule->d_matches.load(); });
+
+  luaCtx.registerFunction<std::shared_ptr<DNSRule> (dnsdist::rules::RuleAction::*)() const>("getSelector", [](const dnsdist::rules::RuleAction& rule) { return rule.d_rule; });
+
+  luaCtx.registerFunction<std::shared_ptr<DNSAction> (dnsdist::rules::RuleAction::*)() const>("getAction", [](const dnsdist::rules::RuleAction& rule) { return rule.d_action; });
+
+  luaCtx.registerFunction<std::shared_ptr<DNSRule> (dnsdist::rules::ResponseRuleAction::*)() const>("getSelector", [](const dnsdist::rules::ResponseRuleAction& rule) { return rule.d_rule; });
+
+  luaCtx.registerFunction<std::shared_ptr<DNSResponseAction> (dnsdist::rules::ResponseRuleAction::*)() const>("getAction", [](const dnsdist::rules::ResponseRuleAction& rule) { return rule.d_action; });
+
+  luaCtx.writeFunction("SuffixMatchNodeRule", qnameSuffixRule);
+
+  luaCtx.writeFunction("NetmaskGroupRule", [](const boost::variant<const NetmaskGroup&, std::string, const LuaArray<std::string>> netmasks, boost::optional<bool> src, boost::optional<bool> quiet) {
+    if (netmasks.type() == typeid(string)) {
+      NetmaskGroup nmg;
+      nmg.addMask(*boost::get<std::string>(&netmasks));
+      return std::shared_ptr<DNSRule>(new NetmaskGroupRule(nmg, src ? *src : true, quiet ? *quiet : false));
+    }
+
+    if (netmasks.type() == typeid(LuaArray<std::string>)) {
+      NetmaskGroup nmg;
+      for (const auto& str : *boost::get<const LuaArray<std::string>>(&netmasks)) {
+        nmg.addMask(str.second);
+      }
+      return std::shared_ptr<DNSRule>(new NetmaskGroupRule(nmg, src ? *src : true, quiet ? *quiet : false));
+    }
+
+    const auto& nmg = *boost::get<const NetmaskGroup&>(&netmasks);
+    return std::shared_ptr<DNSRule>(new NetmaskGroupRule(nmg, src ? *src : true, quiet ? *quiet : false));
+  });
 
   luaCtx.writeFunction("QNameSuffixRule", qnameSuffixRule);
 
index ef54de73b06b82b021990dd529ec46bd601a9896..8ec44f2df8dd9e9acaf483105d7779317417f30e 100644 (file)
@@ -3213,6 +3213,8 @@ void setupLuaBindingsOnly(LuaContext& luaCtx, bool client, bool configCheck)
   setupLuaInspection(luaCtx);
   setupLuaVars(luaCtx);
   setupLuaWeb(luaCtx);
+  setupLuaActions(luaCtx);
+  setupLuaSelectors(luaCtx);
   dnsdist::configuration::yaml::addLuaBindingsForYAMLObjects(luaCtx);
 
 #ifdef LUAJIT_VERSION
@@ -3228,8 +3230,7 @@ void setupLuaConfigurationOptions(LuaContext& luaCtx, bool client, bool configCh
   }
 
   setupLuaConfig(luaCtx, client, configCheck);
-  setupLuaActions(luaCtx);
-  setupLuaRules(luaCtx);
+  setupLuaRuleChainsManagement(luaCtx);
   dnsdist::lua::hooks::setupLuaHooks(luaCtx);
 }
 
index 980078c533f99260748cdeecbfdf3ad40fda90b8..47ab0b258ee3f67010da7738441a29f1f5129dfc 100644 (file)
@@ -56,7 +56,8 @@ void setupLuaBindingsNetwork(LuaContext& luaCtx, bool client);
 void setupLuaBindingsPacketCache(LuaContext& luaCtx, bool client);
 void setupLuaBindingsProtoBuf(LuaContext& luaCtx, bool client, bool configCheck);
 void setupLuaBindingsRings(LuaContext& luaCtx, bool client);
-void setupLuaRules(LuaContext& luaCtx);
+void setupLuaRuleChainsManagement(LuaContext& luaCtx);
+void setupLuaSelectors(LuaContext& luaCtx);
 void setupLuaInspection(LuaContext& luaCtx);
 void setupLuaVars(LuaContext& luaCtx);
 void setupLuaWeb(LuaContext& luaCtx);