]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdist-dynblocks.hh
auth: switch circleci mssql image
[thirdparty/pdns.git] / pdns / dnsdist-dynblocks.hh
index 73d4af36d82a25e5d1fee16bdcccf24354704240..5d9923fb07efafbd5b188d6fd767b53ce6b4a8ce 100644 (file)
  */
 #pragma once
 
+#include <unordered_set>
+
 #include "dolog.hh"
 #include "dnsdist-rings.hh"
+#include "statnode.hh"
+
+#include "dnsdist-lua-inspection-ffi.hh"
+
+// dnsdist_ffi_stat_node_t is a lightuserdata
+template<>
+struct LuaContext::Pusher<dnsdist_ffi_stat_node_t*> {
+    static const int minSize = 1;
+    static const int maxSize = 1;
+
+    static PushedObject push(lua_State* state, dnsdist_ffi_stat_node_t* ptr) noexcept {
+        lua_pushlightuserdata(state, ptr);
+        return PushedObject{state, 1};
+    }
+};
+
+typedef std::function<bool(dnsdist_ffi_stat_node_t*)> dnsdist_ffi_stat_node_visitor_t;
+
+struct dnsdist_ffi_stat_node_t
+{
+  dnsdist_ffi_stat_node_t(const StatNode& node_, const StatNode::Stat& self_, const StatNode::Stat& children_): node(node_), self(self_), children(children_)
+  {
+  }
+
+  const StatNode& node;
+  const StatNode::Stat& self;
+  const StatNode::Stat& children;
+};
 
 class DynBlockRulesGroup
 {
@@ -119,7 +149,7 @@ private:
     bool d_enabled{false};
   };
 
-    typedef std::unordered_map<ComboAddress, Counts, ComboAddress::addressOnlyHash, ComboAddress::addressOnlyEqual> counts_t;
+  typedef std::unordered_map<ComboAddress, Counts, ComboAddress::addressOnlyHash, ComboAddress::addressOnlyEqual> counts_t;
 
 public:
   DynBlockRulesGroup()
@@ -149,6 +179,20 @@ public:
     entry = DynBlockRule(reason, blockDuration, rate, warningRate, seconds, action);
   }
 
+  typedef std::function<bool(const StatNode&, const StatNode::Stat&, const StatNode::Stat&)> smtVisitor_t;
+
+  void setSuffixMatchRule(unsigned int seconds, std::string reason, unsigned int blockDuration, DNSAction::Action action, smtVisitor_t visitor)
+  {
+    d_suffixMatchRule = DynBlockRule(reason, blockDuration, 0, 0, seconds, action);
+    d_smtVisitor = visitor;
+  }
+
+  void setSuffixMatchRuleFFI(unsigned int seconds, std::string reason, unsigned int blockDuration, DNSAction::Action action, dnsdist_ffi_stat_node_visitor_t visitor)
+  {
+    d_suffixMatchRule = DynBlockRule(reason, blockDuration, 0, 0, seconds, action);
+    d_smtVisitorFFI = visitor;
+  }
+
   void apply()
   {
     struct timespec now;
@@ -160,6 +204,7 @@ public:
   void apply(const struct timespec& now)
   {
     counts_t counts;
+    StatNode statNodeRoot;
 
     size_t entriesCount = 0;
     if (hasQueryRules()) {
@@ -171,9 +216,9 @@ public:
     counts.reserve(entriesCount);
 
     processQueryRules(counts, now);
-    processResponseRules(counts, now);
+    processResponseRules(counts, statNodeRoot, now);
 
-    if (counts.empty()) {
+    if (counts.empty() && statNodeRoot.empty()) {
       return;
     }
 
@@ -239,6 +284,38 @@ public:
     if (updated && blocks) {
       g_dynblockNMG.setState(*blocks);
     }
+
+    if (!statNodeRoot.empty()) {
+      StatNode::Stat node;
+      std::unordered_set<DNSName> namesToBlock;
+      statNodeRoot.visit([this,&namesToBlock](const StatNode* node_, const StatNode::Stat& self, const StatNode::Stat& children) {
+                           bool block = false;
+
+                           if (d_smtVisitorFFI) {
+                             dnsdist_ffi_stat_node_t tmp(*node_, self, children);
+                             block = d_smtVisitorFFI(&tmp);
+                           }
+                           else {
+                             block = d_smtVisitor(*node_, self, children);
+                           }
+
+                           if (block) {
+                             namesToBlock.insert(DNSName(node_->fullname));
+                           }
+                         },
+        node);
+
+      if (!namesToBlock.empty()) {
+        updated = false;
+        SuffixMatchTree<DynBlock> smtBlocks = g_dynblockSMT.getCopy();
+        for (const auto& name : namesToBlock) {
+          addOrRefreshBlockSMT(smtBlocks, now, name, d_suffixMatchRule, updated);
+        }
+        if (updated) {
+          g_dynblockSMT.setState(smtBlocks);
+        }
+      }
+    }
   }
 
   void excludeRange(const Netmask& range)
@@ -251,12 +328,18 @@ public:
     d_excludedSubnets.addMask(range, false);
   }
 
+  void excludeDomain(const DNSName& domain)
+  {
+    d_excludedDomains.add(domain);
+  }
+
   std::string toString() const
   {
     std::stringstream result;
 
     result << "Query rate rule: " << d_queryRateRule.toString() << std::endl;
     result << "Response rate rule: " << d_respRateRule.toString() << std::endl;
+    result << "SuffixMatch rule: " << d_suffixMatchRule.toString() << std::endl;
     result << "RCode rules: " << std::endl;
     for (const auto& rule : d_rcodeRules) {
       result << "- " << RCode::to_s(rule.first) << ": " << rule.second.toString() << std::endl;
@@ -266,6 +349,7 @@ public:
       result << "- " << QType(rule.first).getName() << ": " << rule.second.toString() << std::endl;
     }
     result << "Excluded Subnets: " << d_excludedSubnets.toString() << std::endl;
+    result << "Excluded Domains: " << d_excludedDomains.toString() << std::endl;
 
     return result.str();
   }
@@ -348,6 +432,44 @@ private:
     updated = true;
   }
 
+  void addOrRefreshBlockSMT(SuffixMatchTree<DynBlock>& blocks, const struct timespec& now, const DNSName& name, const DynBlockRule& rule, bool& updated)
+  {
+    if (d_excludedDomains.check(name)) {
+      /* do not add a block for excluded domains */
+      return;
+    }
+
+    struct timespec until = now;
+    until.tv_sec += rule.d_blockDuration;
+    unsigned int count = 0;
+    const auto& got = blocks.lookup(name);
+    bool expired = false;
+
+    if (got) {
+      if (until < got->until) {
+        // had a longer policy
+        return;
+      }
+
+      if (now < got->until) {
+        // only inherit count on fresh query we are extending
+        count = got->blocks;
+      }
+      else {
+        expired = true;
+      }
+    }
+
+    DynBlock db{rule.d_blockReason, until, name, rule.d_action};
+    db.blocks = count;
+
+    if (!d_beQuiet && (!got || expired)) {
+      warnlog("Inserting dynamic block for %s for %d seconds: %s", name, rule.d_blockDuration, rule.d_blockReason);
+    }
+    blocks.add(name, db);
+    updated = true;
+  }
+
   void addBlock(boost::optional<NetmaskTree<DynBlock> >& blocks, const struct timespec& now, const ComboAddress& requestor, const DynBlockRule& rule, bool& updated)
   {
     addOrRefreshBlock(blocks, now, requestor, rule, updated, false);
@@ -368,6 +490,11 @@ private:
     return d_respRateRule.isEnabled() || !d_rcodeRules.empty();
   }
 
+  bool hasSuffixMatchRules() const
+  {
+    return d_suffixMatchRule.isEnabled();
+  }
+
   bool hasRules() const
   {
     return hasQueryRules() || hasResponseRules();
@@ -410,15 +537,18 @@ private:
     }
   }
 
-  void processResponseRules(counts_t& counts, const struct timespec& now)
+  void processResponseRules(counts_t& counts, StatNode& root, const struct timespec& now)
   {
-    if (!hasResponseRules()) {
+    if (!hasResponseRules() && !hasSuffixMatchRules()) {
       return;
     }
 
     d_respRateRule.d_cutOff = d_respRateRule.d_minTime = now;
     d_respRateRule.d_cutOff.tv_sec -= d_respRateRule.d_seconds;
 
+    d_suffixMatchRule.d_cutOff = d_suffixMatchRule.d_minTime = now;
+    d_suffixMatchRule.d_cutOff.tv_sec -= d_suffixMatchRule.d_seconds;
+
     for (auto& rule : d_rcodeRules) {
       rule.second.d_cutOff = rule.second.d_minTime = now;
       rule.second.d_cutOff.tv_sec -= rule.second.d_seconds;
@@ -432,6 +562,7 @@ private:
         }
 
         bool respRateMatches = d_respRateRule.matches(c.when);
+        bool suffixMatchRuleMatches = d_suffixMatchRule.matches(c.when);
         bool rcodeRuleMatches = checkIfResponseCodeMatches(c);
 
         if (respRateMatches || rcodeRuleMatches) {
@@ -443,6 +574,10 @@ private:
             entry.d_rcodeCounts[c.dh.rcode]++;
           }
         }
+
+        if (suffixMatchRuleMatches) {
+          root.submit(c.name, c.dh.rcode, boost::none);
+        }
       }
     }
   }
@@ -451,6 +586,10 @@ private:
   std::map<uint16_t, DynBlockRule> d_qtypeRules;
   DynBlockRule d_queryRateRule;
   DynBlockRule d_respRateRule;
+  DynBlockRule d_suffixMatchRule;
   NetmaskGroup d_excludedSubnets;
+  SuffixMatchNode d_excludedDomains;
+  smtVisitor_t d_smtVisitor;
+  dnsdist_ffi_stat_node_visitor_t d_smtVisitorFFI;
   bool d_beQuiet{false};
 };