]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Merge pull request #6364 from rgacogne/dnsdist-macro-sonar
authorRemi Gacogne <rgacogne@users.noreply.github.com>
Tue, 20 Mar 2018 15:12:56 +0000 (16:12 +0100)
committerGitHub <noreply@github.com>
Tue, 20 Mar 2018 15:12:56 +0000 (16:12 +0100)
dnsdist: Fix 'unreachable code' warnings reported by SonarCloud

16 files changed:
pdns/dnsdist-carbon.cc
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-inspection.cc
pdns/dnsdist-lua-vars.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-lua.hh
pdns/dnsdist.hh
pdns/dnsdistdist/docs/reference/constants.rst
pdns/dnsdistdist/docs/rules-actions.rst
pdns/dnsrecords.hh
pdns/statnode.cc
pdns/statnode.hh
regression-tests.dnsdist/.gitignore
regression-tests.dnsdist/test_Advanced.py
regression-tests.dnsdist/test_Carbon.py
regression-tests.dnsdist/test_Responses.py

index caae0bb941d49adfce8ba8b263a95b63cbe45fd8..6ef88767f041a083603ce516e29504984a04b895 100644 (file)
@@ -112,6 +112,7 @@ try
           const string base = "dnsdist." + hostname + ".main.pools." + poolName + ".";
           const std::shared_ptr<ServerPool> pool = entry.second;
           str<<base<<"servers" << " " << pool->servers.size() << " " << now << "\r\n";
+          str<<base<<"servers-up" << " " << pool->countServersUp() << " " << now << "\r\n";
           if (pool->packetCache != nullptr) {
             const auto& cache = pool->packetCache;
             str<<base<<"cache-size" << " " << cache->getMaxEntries() << " " << now << "\r\n";
index 9a9339f495fb8c1c4d0359ab6eb9aeb072837af2..fa937924ac5c851e380a8c59355a965806f783d0 100644 (file)
@@ -29,6 +29,7 @@
 #include "ednsoptions.hh"
 #include "fstrm_logger.hh"
 #include "remote_logger.hh"
+#include "boost/optional/optional_io.hpp"
 
 class DropAction : public DNSAction
 {
@@ -311,8 +312,15 @@ DNSAction::Action LuaAction::operator()(DNSQuestion* dq, string* ruleresult) con
   std::lock_guard<std::mutex> lock(g_luamutex);
   try {
     auto ret = d_func(dq);
-    if(ruleresult)
-      *ruleresult=std::get<1>(ret);
+    if (ruleresult) {
+      if (boost::optional<string> rule = std::get<1>(ret)) {
+        *ruleresult = *rule;
+      }
+      else {
+        // default to empty string
+        ruleresult->clear();
+      }
+    }
     return (Action)std::get<0>(ret);
   } catch (std::exception &e) {
     warnlog("LuaAction failed inside lua, returning ServFail: %s", e.what());
@@ -327,8 +335,15 @@ DNSResponseAction::Action LuaResponseAction::operator()(DNSResponse* dr, string*
   std::lock_guard<std::mutex> lock(g_luamutex);
   try {
     auto ret = d_func(dr);
-    if(ruleresult)
-      *ruleresult=std::get<1>(ret);
+    if(ruleresult) {
+      if (boost::optional<string> rule = std::get<1>(ret)) {
+        *ruleresult = *rule;
+      }
+      else {
+        // default to empty string
+        ruleresult->clear();
+      }
+    }
     return (Action)std::get<0>(ret);
   } catch (std::exception &e) {
     warnlog("LuaResponseAction failed inside lua, returning ServFail: %s", e.what());
index 2e21865054b3d0f38f00149a9aa1678f9b292f06..d3b14e2c6d1af01f98f5310e800554c09f213a24 100644 (file)
 
 #include "statnode.hh"
 
-static std::unordered_map<int, vector<boost::variant<string,double>>> getGenResponses(unsigned int top, boost::optional<int> labels, std::function<bool(const Rings::Response&)> pred)
+static std::unordered_map<unsigned int, vector<boost::variant<string,double>>> getGenResponses(unsigned int top, boost::optional<int> labels, std::function<bool(const Rings::Response&)> pred)
 {
   setLuaNoSideEffect();
-  map<DNSName, int> counts;
+  map<DNSName, unsigned int> counts;
   unsigned int total=0;
   {
     std::lock_guard<std::mutex> lock(g_rings.respMutex);
@@ -53,7 +53,7 @@ static std::unordered_map<int, vector<boost::variant<string,double>>> getGenResp
     }
   }
   //      cout<<"Looked at "<<total<<" responses, "<<counts.size()<<" different ones"<<endl;
-  vector<pair<int, DNSName>> rcounts;
+  vector<pair<unsigned int, DNSName>> rcounts;
   rcounts.reserve(counts.size());
   for(const auto& c : counts)
     rcounts.push_back(make_pair(c.second, c.first.makeLowerCase()));
@@ -63,7 +63,7 @@ static std::unordered_map<int, vector<boost::variant<string,double>>> getGenResp
          return b.first < a.first;
        });
 
-  std::unordered_map<int, vector<boost::variant<string,double>>> ret;
+  std::unordered_map<unsigned int, vector<boost::variant<string,double>>> ret;
   unsigned int count=1, rest=0;
   for(const auto& rc : rcounts) {
     if(count==top+1)
@@ -75,19 +75,20 @@ static std::unordered_map<int, vector<boost::variant<string,double>>> getGenResp
   return ret;
 }
 
-static map<ComboAddress,int> filterScore(const map<ComboAddress, unsigned int,ComboAddress::addressOnlyLessThan >& counts,
-                                 double delta, int rate)
-{
-  std::multimap<unsigned int,ComboAddress> score;
-  for(const auto& e : counts)
-    score.insert({e.second, e.first});
+typedef std::unordered_map<ComboAddress, unsigned int, ComboAddress::addressOnlyHash, ComboAddress::addressOnlyEqual> counts_t;
 
-  map<ComboAddress,int> ret;
+static counts_t filterScore(const counts_t& counts,
+                        double delta, unsigned int rate)
+{
+  counts_t ret;
 
   double lim = delta*rate;
-  for(auto s = score.crbegin(); s != score.crend() && s->first > lim; ++s) {
-    ret[s->second]=s->first;
+  for(const auto& c : counts) {
+    if (c.second > lim) {
+      ret[c.first] = c.second;
+    }
   }
+
   return ret;
 }
 
@@ -103,23 +104,23 @@ static void statNodeRespRing(statvisitor_t visitor, unsigned int seconds)
     cutoff.tv_sec -= seconds;
   }
 
-  std::lock_guard<std::mutex> lock(g_rings.respMutex);
-
   StatNode root;
-  for(const auto& c : g_rings.respRing) {
-    if (now < c.when)
-      continue;
+  {
+    std::lock_guard<std::mutex> lock(g_rings.respMutex);
+    for(const auto& c : g_rings.respRing) {
+      if (now < c.when)
+        continue;
 
-    if (seconds && c.when < cutoff)
-      continue;
+      if (seconds && c.when < cutoff)
+        continue;
 
-    root.submit(c.name, c.dh.rcode, c.requestor);
+      root.submit(c.name, c.dh.rcode, boost::none);
+    }
   }
-  StatNode::Stat node;
 
-  root.visit([&visitor](const StatNode* node_, const StatNode::Stat& self, const StatNode::Stat& children) {
+  StatNode::Stat node;
+  root.visit([visitor](const StatNode* node_, const StatNode::Stat& self, const StatNode::Stat& children) {
       visitor(*node_, self, children);},  node);
-
 }
 
 static vector<pair<unsigned int, std::unordered_map<string,string> > > getRespRing(boost::optional<int> rcode)
@@ -141,8 +142,7 @@ static vector<pair<unsigned int, std::unordered_map<string,string> > > getRespRi
   return ret;
 }
 
-typedef   map<ComboAddress, unsigned int,ComboAddress::addressOnlyLessThan > counts_t;
-static map<ComboAddress,int> exceedRespGen(int rate, int seconds, std::function<void(counts_t&, const Rings::Response&)> T)
+static counts_t exceedRespGen(unsigned int rate, int seconds, std::function<void(counts_t&, const Rings::Response&)> T)
 {
   counts_t counts;
   struct timespec cutoff, mintime, now;
@@ -150,22 +150,26 @@ static map<ComboAddress,int> exceedRespGen(int rate, int seconds, std::function<
   cutoff = mintime = now;
   cutoff.tv_sec -= seconds;
 
-  std::lock_guard<std::mutex> lock(g_rings.respMutex);
-  for(const auto& c : g_rings.respRing) {
-    if(seconds && c.when < cutoff)
-      continue;
-    if(now < c.when)
-      continue;
-
-    T(counts, c);
-    if(c.when < mintime)
-      mintime = c.when;
+  {
+    std::lock_guard<std::mutex> lock(g_rings.respMutex);
+    counts.reserve(g_rings.respRing.size());
+    for(const auto& c : g_rings.respRing) {
+      if(seconds && c.when < cutoff)
+        continue;
+      if(now < c.when)
+        continue;
+
+      T(counts, c);
+      if(c.when < mintime)
+        mintime = c.when;
+    }
   }
+
   double delta = seconds ? seconds : DiffTime(now, mintime);
   return filterScore(counts, delta, rate);
 }
 
-static map<ComboAddress,int> exceedQueryGen(int rate, int seconds, std::function<void(counts_t&, const Rings::Query&)> T)
+static counts_t exceedQueryGen(unsigned int rate, int seconds, std::function<void(counts_t&, const Rings::Query&)> T)
 {
   counts_t counts;
   struct timespec cutoff, mintime, now;
@@ -173,22 +177,26 @@ static map<ComboAddress,int> exceedQueryGen(int rate, int seconds, std::function
   cutoff = mintime = now;
   cutoff.tv_sec -= seconds;
 
-  ReadLock rl(&g_rings.queryLock);
-  for(const auto& c : g_rings.queryRing) {
-    if(seconds && c.when < cutoff)
-      continue;
-    if(now < c.when)
-      continue;
-    T(counts, c);
-    if(c.when < mintime)
-      mintime = c.when;
+  {
+    ReadLock rl(&g_rings.queryLock);
+    counts.reserve(g_rings.queryRing.size());
+    for(const auto& c : g_rings.queryRing) {
+      if(seconds && c.when < cutoff)
+        continue;
+      if(now < c.when)
+        continue;
+      T(counts, c);
+      if(c.when < mintime)
+        mintime = c.when;
+    }
   }
+
   double delta = seconds ? seconds : DiffTime(now, mintime);
   return filterScore(counts, delta, rate);
 }
 
 
-static map<ComboAddress,int> exceedRCode(int rate, int seconds, int rcode)
+static counts_t exceedRCode(unsigned int rate, int seconds, int rcode)
 {
   return exceedRespGen(rate, seconds, [rcode](counts_t& counts, const Rings::Response& r)
                   {
@@ -197,7 +205,7 @@ static map<ComboAddress,int> exceedRCode(int rate, int seconds, int rcode)
                   });
 }
 
-static map<ComboAddress,int> exceedRespByterate(int rate, int seconds)
+static counts_t exceedRespByterate(unsigned int rate, int seconds)
 {
   return exceedRespGen(rate, seconds, [](counts_t& counts, const Rings::Response& r)
                   {
@@ -210,7 +218,7 @@ void setupLuaInspection()
   g_lua.writeFunction("topClients", [](boost::optional<unsigned int> top_) {
       setLuaNoSideEffect();
       auto top = top_.get_value_or(10);
-      map<ComboAddress, int,ComboAddress::addressOnlyLessThan > counts;
+      map<ComboAddress, unsigned int,ComboAddress::addressOnlyLessThan > counts;
       unsigned int total=0;
       {
         ReadLock rl(&g_rings.queryLock);
@@ -219,7 +227,7 @@ void setupLuaInspection()
           total++;
         }
       }
-      vector<pair<int, ComboAddress>> rcounts;
+      vector<pair<unsigned int, ComboAddress>> rcounts;
       rcounts.reserve(counts.size());
       for(const auto& c : counts)
        rcounts.push_back(make_pair(c.second, c.first));
@@ -241,7 +249,7 @@ void setupLuaInspection()
 
   g_lua.writeFunction("getTopQueries", [](unsigned int top, boost::optional<int> labels) {
       setLuaNoSideEffect();
-      map<DNSName, int> counts;
+      map<DNSName, unsigned int> counts;
       unsigned int total=0;
       if(!labels) {
        ReadLock rl(&g_rings.queryLock);
@@ -260,7 +268,7 @@ void setupLuaInspection()
        }
       }
       // cout<<"Looked at "<<total<<" queries, "<<counts.size()<<" different ones"<<endl;
-      vector<pair<int, DNSName>> rcounts;
+      vector<pair<unsigned int, DNSName>> rcounts;
       rcounts.reserve(counts.size());
       for(const auto& c : counts)
        rcounts.push_back(make_pair(c.second, c.first.makeLowerCase()));
@@ -270,7 +278,7 @@ void setupLuaInspection()
             return b.first < a.first;
           });
 
-      std::unordered_map<int, vector<boost::variant<string,double>>> ret;
+      std::unordered_map<unsigned int, vector<boost::variant<string,double>>> ret;
       unsigned int count=1, rest=0;
       for(const auto& rc : rcounts) {
        if(count==top+1)
index 8c9ec6fce3e28872f48a156b5a4aab410d4aa682..75722662685996182b3d60e9e694ae56e4f289c3 100644 (file)
@@ -23,7 +23,7 @@
 
 void setupLuaVars()
 {
-    g_lua.writeVariable("DNSAction", std::unordered_map<string,int>{
+  g_lua.writeVariable("DNSAction", std::unordered_map<string,int>{
       {"Drop", (int)DNSAction::Action::Drop},
       {"Nxdomain", (int)DNSAction::Action::Nxdomain},
       {"Refused", (int)DNSAction::Action::Refused},
@@ -40,6 +40,7 @@ void setupLuaVars()
   g_lua.writeVariable("DNSResponseAction", std::unordered_map<string,int>{
       {"Allow",        (int)DNSResponseAction::Action::Allow        },
       {"Delay",        (int)DNSResponseAction::Action::Delay        },
+      {"Drop",         (int)DNSResponseAction::Action::Drop         },
       {"HeaderModify", (int)DNSResponseAction::Action::HeaderModify },
       {"ServFail",     (int)DNSResponseAction::Action::ServFail     },
       {"None",         (int)DNSResponseAction::Action::None         }
index d8a8c916f2e1e16637922fcba8aa2960452ef3f6..97f589fa11958a359994f55e9805fa7c89afff9c 100644 (file)
@@ -821,7 +821,10 @@ void setupLuaConfig(bool client)
     });
 
   g_lua.writeFunction("addDynBlocks",
-                      [](const map<ComboAddress,int>& m, const std::string& msg, boost::optional<int> seconds, boost::optional<DNSAction::Action> action) {
+                      [](const std::unordered_map<ComboAddress,unsigned int, ComboAddress::addressOnlyHash, ComboAddress::addressOnlyEqual>& m, const std::string& msg, boost::optional<int> seconds, boost::optional<DNSAction::Action> action) {
+                           if (m.empty()) {
+                             return;
+                           }
                            setLuaSideEffect();
                           auto slow = g_dynblockNMG.getCopy();
                           struct timespec until, now;
@@ -852,6 +855,9 @@ void setupLuaConfig(bool client)
 
   g_lua.writeFunction("addDynBlockSMT",
                       [](const vector<pair<unsigned int, string> >&names, const std::string& msg, boost::optional<int> seconds, boost::optional<DNSAction::Action> action) {
+                           if (names.empty()) {
+                             return;
+                           }
                            setLuaSideEffect();
                           auto slow = g_dynblockSMT.getCopy();
                           struct timespec until, now;
index 5c7d73c1daafcdd2cf095ced12957cdc4e46d80b..19c5f39655502bfaa2faaa161b767416aa2c1a72 100644 (file)
@@ -24,7 +24,7 @@
 class LuaAction : public DNSAction
 {
 public:
-  typedef std::function<std::tuple<int, string>(DNSQuestion* dq)> func_t;
+  typedef std::function<std::tuple<int, boost::optional<string> >(DNSQuestion* dq)> func_t;
   LuaAction(LuaAction::func_t func) : d_func(func)
   {}
   Action operator()(DNSQuestion* dq, string* ruleresult) const override;
@@ -39,7 +39,7 @@ private:
 class LuaResponseAction : public DNSResponseAction
 {
 public:
-  typedef std::function<std::tuple<int, string>(DNSResponse* dr)> func_t;
+  typedef std::function<std::tuple<int, boost::optional<string> >(DNSResponse* dr)> func_t;
   LuaResponseAction(LuaResponseAction::func_t func) : d_func(func)
   {}
   Action operator()(DNSResponse* dr, string* ruleresult) const override;
index cfc92475ff9018312e73a59fcb8f41b6953adf24..2c76ca263fe752646242e0827f985e9808f1bcaa 100644 (file)
@@ -90,12 +90,12 @@ struct DNSResponse : DNSQuestion
     DNSQuestion(name, type, class_, lc, rem, header, bufferSize, responseLen, isTcp, queryTime_) { }
 };
 
-/* so what could you do: 
-   drop, 
-   fake up nxdomain, 
-   provide actual answer, 
-   allow & and stop processing, 
-   continue processing, 
+/* so what could you do:
+   drop,
+   fake up nxdomain,
+   provide actual answer,
+   allow & and stop processing,
+   continue processing,
    modify header:    (servfail|refused|notimp), set TC=1,
    send to pool */
 
@@ -172,7 +172,7 @@ struct DNSDistStats
   stat_t cacheHits{0};
   stat_t cacheMisses{0};
   stat_t latency0_1{0}, latency1_10{0}, latency10_50{0}, latency50_100{0}, latency100_1000{0}, latencySlow{0};
-  
+
   double latencyAvg100{0}, latencyAvg1000{0}, latencyAvg10000{0}, latencyAvg1000000{0};
   typedef std::function<uint64_t(const std::string&)> statfunction_t;
   typedef boost::variant<stat_t*, double*, statfunction_t> entry_t;
@@ -187,7 +187,7 @@ struct DNSDistStats
     {"rule-servfail", &ruleServFail},
     {"self-answered", &selfAnswered},
     {"downstream-timeouts", &downstreamTimeouts},
-    {"downstream-send-errors", &downstreamSendErrors}, 
+    {"downstream-send-errors", &downstreamSendErrors},
     {"trunc-failures", &truncFail},
     {"no-policy", &noPolicy},
     {"latency0-1", &latency0_1},
@@ -211,7 +211,7 @@ struct DNSDistStats
     {"cpu-user-msec", getCPUTimeUser},
     {"cpu-sys-msec", getCPUTimeSystem},
     {"fd-usage", getOpenFileDescriptors},
-    {"dyn-blocked", &dynBlocked}, 
+    {"dyn-blocked", &dynBlocked},
     {"dyn-block-nmg-size", [](const std::string&) { return g_dynblockNMG.getLocal()->size(); }}
   };
 };
@@ -229,21 +229,21 @@ struct StopWatch
   struct timespec d_start{0,0};
   bool d_needRealTime{false};
 
-  void start() {  
+  void start() {
     if(gettime(&d_start, d_needRealTime) < 0)
       unixDie("Getting timestamp");
-    
+
   }
 
   void set(const struct timespec& from) {
     d_start = from;
   }
-  
+
   double udiff() const {
     struct timespec now;
     if(gettime(&now, d_needRealTime) < 0)
       unixDie("Getting timestamp");
-    
+
     return 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0;
   }
 
@@ -251,7 +251,7 @@ struct StopWatch
     struct timespec now;
     if(gettime(&now, d_needRealTime) < 0)
       unixDie("Getting timestamp");
-    
+
     auto ret= 1000000.0*(now.tv_sec - d_start.tv_sec) + (now.tv_nsec - d_start.tv_nsec)/1000.0;
     d_start = now;
     return ret;
@@ -291,7 +291,7 @@ public:
     if(d_passthrough)
       return true;
     auto delta = d_prev.udiffAndSet();
-  
+
     d_tokens += 1.0*d_rate * (delta/1000000.0);
 
     if(d_tokens > d_burst)
@@ -306,7 +306,7 @@ public:
     else
       d_blocked++;
 
-    return ret; 
+    return ret;
   }
 private:
   bool d_passthrough{true};
@@ -397,7 +397,7 @@ struct Rings {
 
   std::unordered_map<int, vector<boost::variant<string,double> > > getTopBandwidth(unsigned int numentries);
   size_t numDistinctRequestors();
-  void setCapacity(size_t newCapacity) 
+  void setCapacity(size_t newCapacity)
   {
     {
       WriteLock wl(&queryLock);
@@ -646,6 +646,16 @@ struct ServerPool
   NumberedVector<shared_ptr<DownstreamState>> servers;
   std::shared_ptr<DNSDistPacketCache> packetCache{nullptr};
   std::shared_ptr<ServerPolicy> policy{nullptr};
+
+  size_t countServersUp() const {
+    size_t upFound = 0;
+    for (const auto& server : servers) {
+      if (std::get<1>(server)->isUp() ) {
+        upFound++;
+      };
+    };
+    return upFound;
+  };
 };
 using pools_t=map<std::string,std::shared_ptr<ServerPool>>;
 void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy);
index 38b994d8cc99a684e08d679b29d0a8755e83c55d..176bd11180a7c4ced661af65cbb5e9c6565304e8 100644 (file)
@@ -73,7 +73,7 @@ DNS Section
 DNSAction
 ---------
 
-These constants represent an Action that can be returned from the functions invoked by :func:`addLuaAction` and :func:`addLuaResponseAction`.
+These constants represent an Action that can be returned from the functions invoked by :func:`addLuaAction`.
 
  * ``DNSAction.Allow``: let the query pass, skipping other rules
  * ``DNSAction.Delay``: delay the response for the specified milliseconds (UDP-only), continue to the next rule
@@ -83,4 +83,21 @@ These constants represent an Action that can be returned from the functions invo
  * ``DNSAction.Nxdomain``: return a response with a NXDomain rcode
  * ``DNSAction.Pool``: use the specified pool to forward this query
  * ``DNSAction.Refused``: return a response with a Refused rcode
+ * ``DNSAction.ServFail``: return a response with a ServFail rcode
  * ``DNSAction.Spoof``: spoof the response using the supplied IPv4 (A), IPv6 (AAAA) or string (CNAME) value
+ * ``DNSAction.Truncate``: truncate the response
+
+
+.. _DNSResponseAction:
+
+DNSResponseAction
+-----------------
+
+These constants represent an Action that can be returned from the functions invoked by :func:`addLuaResponseAction`.
+
+ * ``DNSResponseAction.Allow``: let the response pass, skipping other rules
+ * ``DNSResponseAction.Delay``: delay the response for the specified milliseconds (UDP-only), continue to the next rule
+ * ``DNSResponseAction.Drop``: drop the response
+ * ``DNSResponseAction.HeaderModify``: indicate that the query has been turned into a response
+ * ``DNSResponseAction.None``: continue to the next rule
+ * ``DNSResponseAction.ServFail``: return a response with a ServFail rcode
index 33da71a9e349328943b172532c4bb54cc36a80e8..878645b0d60d6ef334d89af946dba2ebea523104 100644 (file)
@@ -150,7 +150,9 @@ Rule Generators
 
   Invoke a Lua function that accepts a :class:`DNSQuestion`.
   This function works similar to using :func:`LuaAction`.
-  The ``function`` should return a :ref:`DNSAction`. If the Lua code fails, ServFail is returned.
+  The ``function`` should return both a :ref:`DNSAction` and its argument `rule`. The `rule` is used as an argument
+  of the following :ref:`DNSAction`: `DNSAction.Spoof`, `DNSAction.Pool` and `DNSAction.Delay`. As of version `1.3.0`, you can
+  omit the argument. For earlier releases, simply return an empty string. If the Lua code fails, ServFail is returned.
 
   :param DNSRule: match queries based on this rule
   :param string function: the name of a Lua function
@@ -160,6 +162,20 @@ Rule Generators
 
   * ``uuid``: string - UUID to assign to the new rule. By default a random UUID is generated for each rule.
 
+  ::
+
+    function luarule(dq)
+      if(dq.qtype==dnsdist.NAPTR)
+      then
+        return DNSAction.Pool, "abuse" -- send to abuse pool
+      else
+        return DNSAction.None, ""      -- no action
+        -- return DNSAction.None       -- as of dnsdist version 1.3.0
+      end
+    end
+
+    addLuaAction(AllRule(), luarule)
+
 .. function:: addLuaResponseAction(DNSrule, function [, options])
 
   .. versionchanged:: 1.3.0
@@ -167,7 +183,9 @@ Rule Generators
 
   Invoke a Lua function that accepts a :class:`DNSResponse`.
   This function works similar to using :func:`LuaResponseAction`.
-  The ``function`` should return a :ref:`DNSResponseAction`. If the Lua code fails, ServFail is returned.
+  The ``function`` should return both a :ref:`DNSResponseAction` and its argument `rule`. The `rule` is used as an argument
+  of the `DNSResponseAction.Delay`. As of version `1.3.0`, you can omit the argument (see :func:`addLuaAction`). For earlier
+  releases, simply return an empty string. If the Lua code fails, ServFail is returned.
 
   :param DNSRule: match queries based on this rule
   :param string function: the name of a Lua function
index fe19b55c7a6e0a345a882a176ddc4c0df0c29777..42204280bb687e0bf87be3ac97e61b2824bb946b 100644 (file)
@@ -46,7 +46,7 @@ class NAPTRRecordContent : public DNSRecordContent
 public:
   NAPTRRecordContent(uint16_t order, uint16_t preference, string flags, string services, string regexp, DNSName replacement);
 
-  includeboilerplate(NAPTR);
+  includeboilerplate(NAPTR)
   template<class Convertor> void xfrRecordContent(Convertor& conv);
 private:
   uint16_t d_order, d_preference;
@@ -60,7 +60,7 @@ class ARecordContent : public DNSRecordContent
 public:
   explicit ARecordContent(const ComboAddress& ca);
   explicit ARecordContent(uint32_t ip);
-  includeboilerplate(A);
+  includeboilerplate(A)
   void doRecordCheck(const DNSRecord& dr);
   ComboAddress getCA(int port=0) const;
   bool operator==(const DNSRecordContent& rhs) const override
@@ -78,7 +78,7 @@ class AAAARecordContent : public DNSRecordContent
 public:
   AAAARecordContent(std::string &val);
   explicit AAAARecordContent(const ComboAddress& ca);
-  includeboilerplate(AAAA);
+  includeboilerplate(AAAA)
   ComboAddress getCA(int port=0) const;
   bool operator==(const DNSRecordContent& rhs) const override
   {
index d33836438328623c7dbc635be86d1759e0eb2ac4..2d40f582ad868636cc486acfc9d9398a2abd4d83 100644 (file)
@@ -43,7 +43,6 @@ void  StatNode::visit(visitor_t visitor, Stat &newstat, unsigned int depth) cons
   
   Stat selfstat(childstat);
 
-
   for(const children_t::value_type& child :  children) {
     child.second.visit(visitor, childstat, depth+8);
   }
@@ -54,18 +53,15 @@ void  StatNode::visit(visitor_t visitor, Stat &newstat, unsigned int depth) cons
 }
 
 
-void StatNode::submit(const DNSName& domain, int rcode, const ComboAddress& remote)
+void StatNode::submit(const DNSName& domain, int rcode, boost::optional<const ComboAddress&> remote)
 {
   //  cerr<<"FIRST submit called on '"<<domain<<"'"<<endl;
-  vector<string> tmp = domain.getRawLabels();
+  std::vector<string> tmp = domain.getRawLabels();
   if(tmp.empty())
     return;
 
-  deque<string> parts;
-  for(auto const i : tmp) {
-    parts.push_back(i);
-  }
-  children[parts.back()].submit(parts, "", rcode, remote, 1);
+  auto last = tmp.end() - 1;
+  children[*last].submit(last, tmp.begin(), "", rcode, remote, 1);
 }
 
 /* www.powerdns.com. -> 
@@ -75,24 +71,22 @@ void StatNode::submit(const DNSName& domain, int rcode, const ComboAddress& remo
    www.powerdns.com. 
 */
 
-void StatNode::submit(deque<string>& labels, const std::string& domain, int rcode, const ComboAddress& remote, unsigned int count)
+void StatNode::submit(std::vector<string>::const_iterator end, std::vector<string>::const_iterator begin, const std::string& domain, int rcode, boost::optional<const ComboAddress&> remote, unsigned int count)
 {
-  if(labels.empty())
-    return;
   //  cerr<<"Submit called for domain='"<<domain<<"': ";
   //  for(const std::string& n :  labels) 
   //    cerr<<n<<".";
   //  cerr<<endl;
   if(name.empty()) {
 
-    name=labels.back();
+    name=*end;
     //    cerr<<"Set short name to '"<<name<<"'"<<endl;
   }
   else {
     //    cerr<<"Short name was already set to '"<<name<<"'"<<endl;
   }
 
-  if(labels.size()==1) {
+  if(end == begin) {
     if (fullname.empty()) {
       fullname=name+"."+domain;
       labelsCount = count;
@@ -107,7 +101,10 @@ void StatNode::submit(deque<string>& labels, const std::string& domain, int rcod
       s.servfails++;
     else if(rcode==3)
       s.nxdomains++;
-    s.remotes[remote]++;
+
+    if (remote) {
+      s.remotes[*remote]++;
+    }
   }
   else {
     if (fullname.empty()) {
@@ -115,8 +112,8 @@ void StatNode::submit(deque<string>& labels, const std::string& domain, int rcod
       labelsCount = count;
     }
     //    cerr<<"Not yet end, set our fullname to '"<<fullname<<"', recursing"<<endl;
-    labels.pop_back();
-    children[labels.back()].submit(labels, fullname, rcode, remote, count+1);
+    end--;
+    children[*end].submit(end, begin, fullname, rcode, remote, count+1);
   }
 }
 
index 26759fd2ce0ceee650b189897850c3049ab13f26..39c20c0fc986302c559376842c25e90d9b845a86 100644 (file)
@@ -29,10 +29,10 @@ class StatNode
 {
 public:
 
-  struct Stat 
+  struct Stat
   {
     Stat() : queries(0), noerrors(0), nxdomains(0), servfails(0), drops(0){}
-    int queries, noerrors, nxdomains, servfails, drops;
+    uint64_t queries, noerrors, nxdomains, servfails, drops;
 
     Stat& operator+=(const Stat& rhs) {
       queries+=rhs.queries;
@@ -41,8 +41,8 @@ public:
       servfails+=rhs.servfails;
       drops+=rhs.drops;
 
-      for(const remotes_t::value_type& rem :  rhs.remotes) {
-       remotes[rem.first]+=rem.second;
+      for(const remotes_t::value_type& rem : rhs.remotes) {
+        remotes[rem.first]+=rem.second;
       }
       return *this;
     }
@@ -55,13 +55,14 @@ public:
   std::string fullname;
   unsigned int labelsCount{0};
 
-  void submit(const DNSName& domain, int rcode, const ComboAddress& remote);
-  void submit(std::deque<std::string>& labels, const std::string& domain, int rcode, const ComboAddress& remote, unsigned int count);
+  void submit(const DNSName& domain, int rcode, boost::optional<const ComboAddress&> remote);
 
   Stat print(unsigned int depth=0, Stat newstat=Stat(), bool silent=false) const;
   typedef boost::function<void(const StatNode*, const Stat& selfstat, const Stat& childstat)> visitor_t;
   void visit(visitor_t visitor, Stat& newstat, unsigned int depth=0) const;
   typedef std::map<std::string,StatNode, CIStringCompare> children_t;
   children_t children;
-  
+
+private:
+  void submit(std::vector<string>::const_iterator end, std::vector<string>::const_iterator begin, const std::string& domain, int rcode, boost::optional<const ComboAddress&> remote, unsigned int count);
 };
index 6fcb424ebee0a04272fa3db8ecab3d58087a0c4e..8502da2174c4eb02ad840e10549a5ef2ca7993a0 100644 (file)
@@ -9,3 +9,10 @@ DNSCryptResolver*
 dnsdist.log
 /*_pb2.py
 /__pycache__/
+ca.key
+ca.pem
+ca.srl
+server.chain
+server.csr
+server.key
+server.pem
index 6f2c3f838dea264f2c5799bb4b392264016aecb6..c4ab74103d029f4146bc9fd8a6c76776d0f19760 100644 (file)
@@ -1279,6 +1279,42 @@ class TestAdvancedLuaRefused(DNSDistTest):
         refusedResponse.id = receivedResponse.id
         self.assertEquals(receivedResponse, refusedResponse)
 
+class TestAdvancedLuaActionReturnSyntax(DNSDistTest):
+
+    _config_template = """
+    function refuse(dq)
+        return DNSAction.Refused
+    end
+    addAction(AllRule(), LuaAction(refuse))
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testRefusedWithEmptyRule(self):
+        """
+        Advanced: Short syntax for LuaAction return values
+        """
+        name = 'short.refused.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '::1')
+        response.answer.append(rrset)
+        refusedResponse = dns.message.make_response(query)
+        refusedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertTrue(receivedResponse)
+        refusedResponse.id = receivedResponse.id
+        self.assertEquals(receivedResponse, refusedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertTrue(receivedResponse)
+        refusedResponse.id = receivedResponse.id
+        self.assertEquals(receivedResponse, refusedResponse)
+
 class TestAdvancedLuaTruncated(DNSDistTest):
 
     _config_template = """
index a06e90aba6c5ffebf938e4af022d7d729cbac2c7..5aa9c72f6a6effd41740485c4611573a84440700 100644 (file)
@@ -15,8 +15,15 @@ class TestCarbon(DNSDistTest):
     _carbonQueue2 = Queue()
     _carbonInterval = 2
     _carbonCounters = {}
-    _config_params = ['_carbonServer1Port', '_carbonServer1Name', '_carbonInterval', '_carbonServer2Port', '_carbonServer2Name', '_carbonInterval']
+    _config_params = ['_carbonServer1Port', '_carbonServer1Name', '_carbonInterval',
+                      '_carbonServer2Port', '_carbonServer2Name', '_carbonInterval']
     _config_template = """
+    s = newServer{address="127.0.0.1:5353"}
+    s:setDown()
+    s = newServer{address="127.0.0.1:5354"}
+    s:setUp()
+    s = newServer{address="127.0.0.1:5355"}
+    s:setUp()
     carbonServer("127.0.0.1:%s", "%s", %s)
     carbonServer("127.0.0.1:%s", "%s", %s)
     """
@@ -105,3 +112,60 @@ class TestCarbon(DNSDistTest):
         for key in self._carbonCounters:
             value = self._carbonCounters[key]
             self.assertTrue(value >= 1)
+
+    def testCarbonServerUp(self):
+        # wait for the carbon data to be sent
+        time.sleep(self._carbonInterval + 1)
+
+        # first server
+        self.assertFalse(self._carbonQueue1.empty())
+        data1 = self._carbonQueue1.get(False)
+        # second server
+        self.assertFalse(self._carbonQueue2.empty())
+        data2 = self._carbonQueue2.get(False)
+        after = time.time()
+
+        # check the first carbon server got both servers and
+        # servers-up metrics and that they are the same as
+        # configured in the class definition
+        self.assertTrue(data1)
+        self.assertTrue(len(data1.splitlines()) > 1)
+        expectedStart = b"dnsdist.%s.main.pools._default_.servers" % self._carbonServer1Name.encode('UTF-8')
+        for line in data1.splitlines():
+            if expectedStart in line:
+                parts = line.split(b' ')
+                if 'servers-up' in line:
+                    self.assertEquals(len(parts), 3)
+                    self.assertTrue(parts[1].isdigit())
+                    self.assertEquals(int(parts[1]), 2)
+                    self.assertTrue(parts[2].isdigit())
+                    self.assertTrue(int(parts[2]) <= int(after))
+                else:
+                    self.assertEquals(len(parts), 3)
+                    self.assertTrue(parts[1].isdigit())
+                    self.assertEquals(int(parts[1]), 3)
+                    self.assertTrue(parts[2].isdigit())
+                    self.assertTrue(int(parts[2]) <= int(after))
+
+        # check the second carbon server got both servers and
+        # servers-up metrics and that they are the same as
+        # configured in the class definition and the same as
+        # the first carbon server
+        self.assertTrue(data2)
+        self.assertTrue(len(data2.splitlines()) > 1)
+        expectedStart = b"dnsdist.%s.main.pools._default_.servers" % self._carbonServer2Name.encode('UTF-8')
+        for line in data2.splitlines():
+            if expectedStart in line:
+                parts = line.split(b' ')
+                if 'servers-up' in line:
+                    self.assertEquals(len(parts), 3)
+                    self.assertTrue(parts[1].isdigit())
+                    self.assertEquals(int(parts[1]), 2)
+                    self.assertTrue(parts[2].isdigit())
+                    self.assertTrue(int(parts[2]) <= int(after))
+                else:
+                    self.assertEquals(len(parts), 3)
+                    self.assertTrue(parts[1].isdigit())
+                    self.assertEquals(int(parts[1]), 3)
+                    self.assertTrue(parts[2].isdigit())
+                    self.assertTrue(int(parts[2]) <= int(after))
index c3df9450749588ff9627b6ef14fa348cbf1e33b3..494143c6d657d8420805b574bdc31b2e50c30ef3 100644 (file)
@@ -249,3 +249,60 @@ class TestResponseRuleEditTTL(DNSDistTest):
         self.assertEquals(response, receivedResponse)
         self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
         self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
+
+class TestResponseLuaActionReturnSyntax(DNSDistTest):
+
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+    function customDelay(dr)
+      return DNSResponseAction.Delay, "1000"
+    end
+    function customDrop(dr)
+      return DNSResponseAction.Drop
+    end
+    addResponseAction("drop.responses.tests.powerdns.com.", LuaResponseAction(customDrop))
+    addResponseAction(RCodeRule(dnsdist.NXDOMAIN), LuaResponseAction(customDelay))
+    """
+
+    def testResponseActionDelayed(self):
+        """
+        Responses: Delayed via LuaResponseAction
+
+        Send an A query to "delayed.responses.tests.powerdns.com.",
+        check that the response delay is longer than 1000 ms
+        for a NXDomain response over UDP, shorter for a NoError one.
+        """
+        name = 'delayed.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        # NX over UDP
+        response.set_rcode(dns.rcode.NXDOMAIN)
+        begin = datetime.now()
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        end = datetime.now()
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        self.assertTrue((end - begin) > timedelta(0, 1))
+
+    def testDropped(self):
+        """
+        Responses: Dropped via user defined LuaResponseAction
+
+        Send an A query to "drop.responses.tests.powerdns.com.",
+        check that the response (not the query) is dropped.
+        """
+        name = 'drop.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(receivedResponse, None)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(receivedResponse, None)