]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdist-lua-bindings.cc
Merge pull request #11526 from Y7n05h/master
[thirdparty/pdns.git] / pdns / dnsdist-lua-bindings.cc
index 6f64c341ceabbb449ae987038a9806d69a778722..74e73294006eeeecba71cec0c2492328cdcc88b4 100644 (file)
@@ -19,6 +19,7 @@
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
+#include "bpf-filter.hh"
 #include "config.h"
 #include "dnsdist.hh"
 #include "dnsdist-lua.hh"
@@ -457,6 +458,8 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
       convertParamsToConfig("ipv4", BPFFilter::MapType::IPv4);
       convertParamsToConfig("ipv6", BPFFilter::MapType::IPv6);
       convertParamsToConfig("qnames", BPFFilter::MapType::QNames);
+      convertParamsToConfig("cidr4", BPFFilter::MapType::CIDR4);
+      convertParamsToConfig("cidr6", BPFFilter::MapType::CIDR6);
 
       BPFFilter::MapFormat format = BPFFilter::MapFormat::Legacy;
       bool external = false;
@@ -498,7 +501,26 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
         }
       }
     });
-
+  luaCtx.registerFunction<void (std::shared_ptr<BPFFilter>::*)(const string& range, uint32_t action, boost::optional<bool> force)>("addRangeRule", [](std::shared_ptr<BPFFilter> bpf, const string& range, uint32_t action, boost::optional<bool> force) {
+    if (!bpf) {
+      return;
+    }
+    BPFFilter::MatchAction match;
+    switch (action) {
+    case 0:
+      match = BPFFilter::MatchAction::Pass;
+      break;
+    case 1:
+      match = BPFFilter::MatchAction::Drop;
+      break;
+    case 2:
+      match = BPFFilter::MatchAction::Truncate;
+      break;
+    default:
+      throw std::runtime_error("Unsupported action for BPFFilter::block");
+    }
+    return bpf->addRangeRule(Netmask(range), force.value_or(false), match);
+  });
   luaCtx.registerFunction<void(std::shared_ptr<BPFFilter>::*)(const DNSName& qname, boost::optional<uint16_t> qtype, boost::optional<uint32_t> action)>("blockQName", [](std::shared_ptr<BPFFilter> bpf, const DNSName& qname, boost::optional<uint16_t> qtype, boost::optional<uint32_t> action) {
       if (bpf) {
         if (!action) {
@@ -530,7 +552,29 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
         return bpf->unblock(ca);
       }
     });
-
+  luaCtx.registerFunction<void (std::shared_ptr<BPFFilter>::*)(const string& range)>("rmRangeRule", [](std::shared_ptr<BPFFilter> bpf, const string& range) {
+    if (!bpf) {
+      return;
+    }
+    bpf->rmRangeRule(Netmask(range));
+  });
+  luaCtx.registerFunction<std::string (std::shared_ptr<BPFFilter>::*)() const>("lsRangeRule", [](const std::shared_ptr<BPFFilter> bpf) {
+    setLuaNoSideEffect();
+    std::string res;
+    if (!bpf) {
+      return res;
+    }
+    const auto rangeStat = bpf->getRangeRule();
+    for (const auto& value : rangeStat) {
+      if (value.first.isIPv4()) {
+        res += BPFFilter::toString(value.second.action) + "\t " + value.first.toString() + "\n";
+      }
+      else if (value.first.isIPv6()) {
+        res += BPFFilter::toString(value.second.action) + "\t[" + value.first.toString() + "]\n";
+      }
+    }
+    return res;
+  });
   luaCtx.registerFunction<void(std::shared_ptr<BPFFilter>::*)(const DNSName& qname, boost::optional<uint16_t> qtype)>("unblockQName", [](std::shared_ptr<BPFFilter> bpf, const DNSName& qname, boost::optional<uint16_t> qtype) {
       if (bpf) {
         return bpf->unblock(qname, qtype ? *qtype : 255);
@@ -550,6 +594,15 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
             res += "[" + value.first.toString() + "]: " + std::to_string(value.second) + "\n";
           }
         }
+        const auto rangeStat = bpf->getRangeRule();
+        for (const auto& value : rangeStat) {
+          if (value.first.isIPv4()) {
+            res += BPFFilter::toString(value.second.action) + "\t " + value.first.toString() + ": " + std::to_string(value.second.counter) + "\n";
+          }
+          else if (value.first.isIPv6()) {
+            res += BPFFilter::toString(value.second.action) + "\t[" + value.first.toString() + "]: " + std::to_string(value.second.counter) + "\n";
+          }
+        }
         auto qstats = bpf->getQNameStats();
         for (const auto& value : qstats) {
           res += std::get<0>(value).toString() + " " + std::to_string(std::get<1>(value)) + ": " + std::to_string(std::get<2>(value)) + "\n";