]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Shared throttle map
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Fri, 29 Apr 2022 15:00:23 +0000 (17:00 +0200)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Fri, 29 Apr 2022 15:00:23 +0000 (17:00 +0200)
pdns/rec_channel_rec.cc
pdns/recursordist/rec-main.cc
pdns/syncres.cc
pdns/syncres.hh

index 845d78a46396c257f5920dd6de8de53eb16b3a2e..1f76636a02391f0f32acfcd213b4efd692a59bc1 100644 (file)
@@ -966,16 +966,6 @@ static string doCurrentQueries()
   return broadcastAccFunction<string>(pleaseGetCurrentQueries);
 }
 
-uint64_t* pleaseGetThrottleSize()
-{
-  return new uint64_t(SyncRes::getThrottledServersSize());
-}
-
-static uint64_t getThrottleSize()
-{
-  return broadcastAccFunction<uint64_t>(pleaseGetThrottleSize);
-}
-
 static uint64_t getNegCacheSize()
 {
   return g_negCache->size();
@@ -1240,7 +1230,7 @@ static void registerAllStats1()
   addGetStat("max-mthread-stack", &g_stats.maxMThreadStackUsage);
 
   addGetStat("negcache-entries", getNegCacheSize);
-  addGetStat("throttle-entries", getThrottleSize);
+  addGetStat("throttle-entries", SyncRes::getThrottledServersSize);
 
   addGetStat("nsspeeds-entries", SyncRes::getNSSpeedsSize);
   addGetStat("failed-host-entries", SyncRes::getFailedServersSize);
@@ -2002,7 +1992,7 @@ RecursorControlChannel::Answer RecursorControlParser::getAnswer(int s, const str
     return doDumpRPZ(s, begin, end);
   }
   if (cmd == "dump-throttlemap") {
-    return doDumpToFile(s, pleaseDumpThrottleMap, cmd);
+    return doDumpToFile(s, pleaseDumpThrottleMap, cmd, false);
   }
   if (cmd == "dump-non-resolving") {
     return doDumpToFile(s, pleaseDumpNonResolvingNS, cmd, false);
index 4e8fd3a9c333696f6d1b1a5ae60eb922695df727..8aa412df3d2382bfd7e3198b047cfac80f420f5c 100644 (file)
@@ -890,7 +890,7 @@ static void doStats(void)
     g_log << Logger::Notice << "stats: cache contended/acquired " << rc_stats.first << '/' << rc_stats.second << " = " << r << '%' << endl;
 
     g_log << Logger::Notice << "stats: throttle map: "
-          << broadcastAccFunction<uint64_t>(pleaseGetThrottleSize) << ", ns speeds: "
+          << SyncRes::getThrottledServersSize() << ", ns speeds: "
           << SyncRes::getNSSpeedsSize() << ", failed ns: "
           << SyncRes::getFailedServersSize() << ", ednsmap: "
           << broadcastAccFunction<uint64_t>(pleaseGetEDNSStatusesSize) << ", non-resolving: "
index deb87f6c22c75056fbe916d8b3f9859148e15465..d8773c1ff4fc45d3c8c38de70e8b0f385f06c37d 100644 (file)
@@ -231,6 +231,79 @@ public:
 
 static LockGuarded <nsspeeds_t> s_nsSpeeds;
 
+template<class Thing> class Throttle : public boost::noncopyable
+{
+public:
+
+  struct entry_t
+  {
+    Thing thing;
+    time_t ttd;
+    mutable unsigned int count;
+  };
+  typedef multi_index_container<entry_t,
+                                indexed_by<
+                                  ordered_unique<tag<Thing>, member<entry_t, Thing, &entry_t::thing>>,
+                                  ordered_non_unique<tag<time_t>, member<entry_t, time_t, &entry_t::ttd>>
+                                  >> cont_t;
+
+  bool shouldThrottle(time_t now, const Thing &t)
+  {
+    auto i = d_cont.find(t);
+    if (i == d_cont.end()) {
+      return false;
+    }
+    if (now > i->ttd || i->count == 0) {
+      d_cont.erase(i);
+      return false;
+    }
+    i->count--;
+
+    return true; // still listed, still blocked
+  }
+
+  void throttle(time_t now, const Thing &t, time_t ttl, unsigned int count)
+  {
+    auto i = d_cont.find(t);
+    time_t ttd = now + ttl;
+    if (i == d_cont.end()) {
+      entry_t e = { t, ttd, count };
+      d_cont.insert(e);
+    } else if (ttd > i->ttd || count > i->count) {
+      ttd = std::max(i->ttd, ttd);
+      count = std::max(i->count, count);
+      auto &ind = d_cont.template get<Thing>();
+      ind.modify(i, [ttd,count](entry_t &e) { e.ttd = ttd; e.count = count; });
+    }
+  }
+
+  size_t size() const
+  {
+    return d_cont.size();
+  }
+
+  cont_t getThrottleMap() const
+  {
+    return d_cont;
+  }
+
+  void clear()
+  {
+    d_cont.clear();
+  }
+
+  void prune() {
+    time_t now = time(nullptr);
+    auto &ind = d_cont.template get<time_t>();
+    ind.erase(ind.begin(), ind.upper_bound(now));
+  }
+
+private:
+  cont_t d_cont;
+};
+
+static LockGuarded<Throttle<std::tuple<ComboAddress,DNSName,QType>>> s_throttle;
+
 struct SavedParentEntry
 {
   SavedParentEntry(const DNSName& name, map<DNSName, vector<ComboAddress>>&& nsAddresses, time_t ttd)
@@ -1036,6 +1109,36 @@ uint64_t SyncRes::doDumpNSSpeeds(int fd)
   return count;
 }
 
+uint64_t SyncRes::getThrottledServersSize()
+{
+  return s_throttle.lock()->size();
+}
+
+void SyncRes::pruneThrottledServers()
+{
+  s_throttle.lock()->prune();
+}
+
+void SyncRes::clearThrottle()
+{
+  s_throttle.lock()->clear();
+}
+
+bool SyncRes::isThrottled(time_t now, const ComboAddress& server, const DNSName& target, uint16_t qtype)
+{
+  return s_throttle.lock()->shouldThrottle(now, std::make_tuple(server, target, qtype));
+}
+
+bool SyncRes::isThrottled(time_t now, const ComboAddress& server)
+{
+  return s_throttle.lock()->shouldThrottle(now, std::make_tuple(server, g_rootdnsname, 0));
+}
+
+void SyncRes::doThrottle(time_t now, const ComboAddress& server, time_t duration, unsigned int tries)
+{
+  s_throttle.lock()->throttle(now, std::make_tuple(server, g_rootdnsname, 0), duration, tries);
+}
+
 uint64_t SyncRes::doDumpThrottleMap(int fd)
 {
   int newfd = dup(fd);
@@ -1051,13 +1154,14 @@ uint64_t SyncRes::doDumpThrottleMap(int fd)
   fprintf(fp.get(), "; remote IP\tqname\tqtype\tcount\tttd\n");
   uint64_t count=0;
 
-  const auto& throttleMap = t_sstorage.throttle.getThrottleMap();
+  // Get a copy to avoid holding the lock while doing I/O
+  const auto throttleMap = s_throttle.lock()->getThrottleMap();
   for(const auto& i : throttleMap)
   {
     count++;
     char tmp[26];
     // remote IP, dns name, qtype, count, ttd
-    fprintf(fp.get(), "%s\t%s\t%d\t%u\t%s\n", std::get<0>(i.thing).toString().c_str(), std::get<1>(i.thing).toLogString().c_str(), std::get<2>(i.thing), i.count, timestamp(i.ttd, tmp, sizeof(tmp)));
+    fprintf(fp.get(), "%s\t%s\t%s\t%u\t%s\n", std::get<0>(i.thing).toString().c_str(), std::get<1>(i.thing).toLogString().c_str(), std::get<2>(i.thing).toString().c_str(), i.count, timestamp(i.ttd, tmp, sizeof(tmp)));
   }
 
   return count;
@@ -3126,12 +3230,12 @@ vector<ComboAddress> SyncRes::retrieveAddressesForNS(const std::string& prefix,
 
 bool SyncRes::throttledOrBlocked(const std::string& prefix, const ComboAddress& remoteIP, const DNSName& qname, const QType qtype, bool pierceDontQuery)
 {
-  if(t_sstorage.throttle.shouldThrottle(d_now.tv_sec, std::make_tuple(remoteIP, g_rootdnsname, 0))) {
+  if (s_throttle.lock()->shouldThrottle(d_now.tv_sec, std::make_tuple(remoteIP, g_rootdnsname, 0))) {
     LOG(prefix<<qname<<": server throttled "<<endl);
     s_throttledqueries++; d_throttledqueries++;
     return true;
   }
-  else if(t_sstorage.throttle.shouldThrottle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()))) {
+  else if (s_throttle.lock()->shouldThrottle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()))) {
     LOG(prefix<<qname<<": query throttled "<<remoteIP.toString()<<", "<<qname<<"; "<<qtype<<endl);
     s_throttledqueries++; d_throttledqueries++;
     return true;
@@ -4950,15 +5054,15 @@ bool SyncRes::doResolveAtThisIP(const std::string& prefix, const DNSName& qname,
       if (s_serverdownmaxfails > 0 && (auth != g_rootdnsname) && s_fails.lock()->incr(remoteIP, d_now) >= s_serverdownmaxfails) {
         LOG(prefix<<qname<<": Max fails reached resolving on "<< remoteIP.toString() <<". Going full throttle for "<< s_serverdownthrottletime <<" seconds" <<endl);
         // mark server as down
-        t_sstorage.throttle.throttle(d_now.tv_sec, std::make_tuple(remoteIP, g_rootdnsname, 0), s_serverdownthrottletime, 10000);
+        s_throttle.lock()->throttle(d_now.tv_sec, std::make_tuple(remoteIP, g_rootdnsname, 0), s_serverdownthrottletime, 10000);
       }
       else if (resolveret == LWResult::Result::PermanentError) {
         // unreachable, 1 minute or 100 queries
-        t_sstorage.throttle.throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 60, 100);
+        s_throttle.lock()->throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 60, 100);
       }
       else {
         // timeout, 10 seconds or 5 queries
-        t_sstorage.throttle.throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 10, 5);
+        s_throttle.lock()->throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 10, 5);
       }
     }
 
@@ -4974,10 +5078,10 @@ bool SyncRes::doResolveAtThisIP(const std::string& prefix, const DNSName& qname,
 
       if (doTCP) {
         // we can be more heavy-handed over TCP
-        t_sstorage.throttle.throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 60, 10);
+        s_throttle.lock()->throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 60, 10);
       }
       else {
-        t_sstorage.throttle.throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 10, 2);
+        s_throttle.lock()->throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 10, 2);
       }
     }
     return false;
@@ -4994,7 +5098,7 @@ bool SyncRes::doResolveAtThisIP(const std::string& prefix, const DNSName& qname,
           s_nsSpeeds.lock()->find_or_enter(nsName.empty()? DNSName(remoteIP.toStringWithPort()) : nsName, d_now).submit(remoteIP, 1000000, d_now); // 1 sec
         }
         else {
-          t_sstorage.throttle.throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 60, 3);
+          s_throttle.lock()->throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 60, 3);
         }
       }
       return false;
@@ -5013,7 +5117,7 @@ bool SyncRes::doResolveAtThisIP(const std::string& prefix, const DNSName& qname,
       LOG(prefix<<qname<<": truncated bit set, over TCP?"<<endl);
       if (!dontThrottle) {
         /* let's treat that as a ServFail answer from this server */
-        t_sstorage.throttle.throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 60, 3);
+        s_throttle.lock()->throttle(d_now.tv_sec, std::make_tuple(remoteIP, qname, qtype.getCode()), 60, 3);
       }
       return false;
     }
@@ -5425,7 +5529,7 @@ int SyncRes::doResolveAt(NsSet &nameservers, DNSName auth, bool flawedNSSet, con
             break;
           }
           /* was lame */
-          t_sstorage.throttle.throttle(d_now.tv_sec, std::make_tuple(*remoteIP, qname, qtype.getCode()), 60, 100);
+          s_throttle.lock()->throttle(d_now.tv_sec, std::make_tuple(*remoteIP, qname, qtype.getCode()), 60, 100);
         }
 
         if (gotNewServers) {
index 11afeb1ef8900eb577e20603d396b6ba42fac26e..db215783b9aa3bcf907033c67d9d31b767ff8b9a 100644 (file)
@@ -79,76 +79,6 @@ typedef std::unordered_map<
   >
 > NsSet;
 
-template<class Thing> class Throttle : public boost::noncopyable
-{
-public:
-
-  struct entry_t
-  {
-    Thing thing;
-    time_t ttd;
-    mutable unsigned int count;
-  };
-  typedef multi_index_container<entry_t,
-                                indexed_by<
-                                  ordered_unique<tag<Thing>, member<entry_t, Thing, &entry_t::thing>>,
-                                  ordered_non_unique<tag<time_t>, member<entry_t, time_t, &entry_t::ttd>>
-                                  >> cont_t;
-
-  bool shouldThrottle(time_t now, const Thing &t)
-  {
-    auto i = d_cont.find(t);
-    if (i == d_cont.end()) {
-      return false;
-    }
-    if (now > i->ttd || i->count == 0) {
-      d_cont.erase(i);
-      return false;
-    }
-    i->count--;
-
-    return true; // still listed, still blocked
-  }
-
-  void throttle(time_t now, const Thing &t, time_t ttl, unsigned int count)
-  {
-    auto i = d_cont.find(t);
-    time_t ttd = now + ttl;
-    if (i == d_cont.end()) {
-      entry_t e = { t, ttd, count };
-      d_cont.insert(e);
-    } else if (ttd > i->ttd || count > i->count) {
-      ttd = std::max(i->ttd, ttd);
-      count = std::max(i->count, count);
-      auto &ind = d_cont.template get<Thing>();
-      ind.modify(i, [ttd,count](entry_t &e) { e.ttd = ttd; e.count = count; });
-    }
-  }
-
-  size_t size() const
-  {
-    return d_cont.size();
-  }
-
-  const cont_t &getThrottleMap() const
-  {
-    return d_cont;
-  }
-
-  void clear()
-  {
-    d_cont.clear();
-  }
-
-  void prune() {
-    time_t now = time(nullptr);
-    auto &ind = d_cont.template get<time_t>();
-    ind.erase(ind.begin(), ind.upper_bound(now));
-  }
-
-private:
-  cont_t d_cont;
-};
 
 extern std::unique_ptr<NegCache> g_negCache;
 
@@ -206,7 +136,6 @@ public:
   };
 
   typedef std::unordered_map<DNSName, AuthDomain> domainmap_t;
-  typedef Throttle<std::tuple<ComboAddress,DNSName,uint16_t> > throttle_t;
 
   struct EDNSStatus {
     EDNSStatus(const ComboAddress &arg) : address(arg) {}
@@ -238,7 +167,6 @@ public:
   };
 
   struct ThreadLocalStorage {
-    throttle_t throttle;
     ednsstatus_t ednsstatus;
     std::shared_ptr<domainmap_t> domainmap;
   };
@@ -327,30 +255,13 @@ public:
   {
     t_sstorage.ednsstatus.prune(cutoff);
   }
-  static uint64_t getThrottledServersSize()
-  {
-    return t_sstorage.throttle.size();
-  }
-  static void pruneThrottledServers()
-  {
-    t_sstorage.throttle.prune();
-  }
-  static void clearThrottle()
-  {
-    t_sstorage.throttle.clear();
-  }
-  static bool isThrottled(time_t now, const ComboAddress& server, const DNSName& target, uint16_t qtype)
-  {
-    return t_sstorage.throttle.shouldThrottle(now, std::make_tuple(server, target, qtype));
-  }
-  static bool isThrottled(time_t now, const ComboAddress& server)
-  {
-    return t_sstorage.throttle.shouldThrottle(now, std::make_tuple(server, g_rootdnsname, 0));
-  }
-  static void doThrottle(time_t now, const ComboAddress& server, time_t duration, unsigned int tries)
-  {
-    t_sstorage.throttle.throttle(now, std::make_tuple(server, g_rootdnsname, 0), duration, tries);
-  }
+
+  static uint64_t getThrottledServersSize();
+  static void pruneThrottledServers();
+  static void clearThrottle();
+  static bool isThrottled(time_t now, const ComboAddress& server, const DNSName& target, uint16_t qtype);
+  static bool isThrottled(time_t now, const ComboAddress& server);
+  static void doThrottle(time_t now, const ComboAddress& server, time_t duration, unsigned int tries);
 
   static uint64_t getFailedServersSize();
   static void clearFailedServers();