]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Basic code to fill ns speed table from a dump
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Thu, 3 Jul 2025 09:05:22 +0000 (11:05 +0200)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Wed, 9 Jul 2025 09:17:19 +0000 (11:17 +0200)
Signed-off-by: Otto Moerbeek <otto.moerbeek@open-xchange.com>
pdns/recursordist/lua-recursor4.cc
pdns/recursordist/syncres.cc
pdns/recursordist/syncres.hh

index a1e03ffcacda6dcf6ec82bb6069c759f0ee5c6b3..a91ca83fd0102071770c1388f653540cc8cc2211 100644 (file)
@@ -529,6 +529,10 @@ void RecursorLua4::postPrepareContext() // NOLINT(readability-function-cognitive
     return std::tuple<std::string, size_t>{ret, number};
   });
 
+  d_lw->writeFunction("putIntoNSSpeedTable", [](const string& data) {
+    return SyncRes::putIntoNSSpeedTable(data);
+  });
+
   if (!d_include_path.empty()) {
     includePath(d_include_path);
   }
index 7669e75a64092906d6cab97b9bf5782c5c027e8c..9b09876602fcb6a4fff04e3c63f71b8be7a52076 100644 (file)
@@ -201,6 +201,11 @@ public:
     }
   }
 
+  void insert(const ComboAddress& address, float val, int last)
+  {
+    d_collection.insert(std::make_pair(address, DecayingEwma{val, last}));
+  }
+  
   // d_collection is the modifyable part of the record, we index on DNSName and timeval, and DNSName never changes
   mutable std::map<ComboAddress, DecayingEwma> d_collection;
   DNSName d_name;
@@ -303,6 +308,120 @@ public:
     return count;
   }
 
+
+  template <typename T>
+  bool putEntry(T& message)
+  {
+    DecayingEwmaCollection entry{{}};
+    while (message.next()) {
+      switch (message.tag()) {
+      case PBNSSpeedEntry::required_bytes_name:
+        entry.d_name = DNSName(message.get_bytes());
+        break;
+      case PBNSSpeedEntry::required_int64_lastgets:
+        entry.d_lastget.tv_sec = message.get_int64();
+        break;
+      case PBNSSpeedEntry::required_int64_lastgetus:
+        entry.d_lastget.tv_usec = message.get_int64();
+        break;
+      case PBNSSpeedEntry::repeated_message_map: {
+        protozero::pbf_message<PBNSSpeedMap> map = message.get_message();
+        ComboAddress address;
+        float val{};
+        int last{};
+        while (map.next()) {
+          switch (map.tag()) {
+          case PBNSSpeedMap::required_bytes_address:
+            decodeComboAddress(map, address);
+            break;
+          case PBNSSpeedMap::required_float_val:
+            val = map.get_float();
+            break;
+          case PBNSSpeedMap::required_int32_last:
+            last = map.get_int32();
+            break;
+          }
+        }
+        entry.insert(address, val, last);
+        break;
+      }
+      }
+    }
+    return insert(std::move(entry)).second;
+  }
+
+  size_t putPB(const std::string& pbuf)
+  {
+    auto log = g_slog->withName("syncres")->withValues("size", Logging::Loggable(pbuf.size()));
+    log->info(Logr::Debug, "Processing nsspeed dump");
+
+    protozero::pbf_message<PBNSSpeedDump> full(pbuf);
+    size_t count = 0;
+    size_t inserted = 0;
+    try {
+      bool protocolVersionSeen = false;
+      bool typeSeen = false;
+      while (full.next()) {
+        switch (full.tag()) {
+        case PBNSSpeedDump::required_string_version: {
+          auto version = full.get_string();
+          log = log->withValues("version", Logging::Loggable(version));
+          break;
+        }
+        case PBNSSpeedDump::required_string_identity: {
+          auto identity = full.get_string();
+          log = log->withValues("identity", Logging::Loggable(identity));
+          break;
+        }
+        case PBNSSpeedDump::required_uint64_protocolVersion: {
+          auto protocolVersion = full.get_uint64();
+          log = log->withValues("protocolVersion", Logging::Loggable(protocolVersion));
+          if (protocolVersion != 1) {
+            throw std::runtime_error("Protocol version mismatch");
+          }
+          protocolVersionSeen = true;
+          break;
+        }
+        case PBNSSpeedDump::required_int64_time: {
+          auto time = full.get_int64();
+          log = log->withValues("time", Logging::Loggable(time));
+          break;
+        }
+        case PBNSSpeedDump::required_string_type: {
+          auto type = full.get_string();
+          if (type != "PBNSSpeedDump") {
+            throw std::runtime_error("Data type mismatch");
+          }
+          typeSeen = true;
+          break;
+        }
+        case PBNSSpeedDump::repeated_message_nsspeedEntry: {
+          if (!protocolVersionSeen || !typeSeen) {
+            throw std::runtime_error("Required field missing");
+          }
+          protozero::pbf_message<PBNSSpeedEntry> message = full.get_message();
+          if (putEntry(message)) {
+            ++inserted;
+          }
+          ++count;
+          break;
+        }
+        }
+      }
+      log->info(Logr::Info, "Processed nsspeed dump", "processed", Logging::Loggable(count), "inserted", Logging::Loggable(inserted));
+      return inserted;
+    }
+    catch (const std::runtime_error& e) {
+      log->error(Logr::Error, e.what(), "Runtime exception processing cache dump");
+    }
+    catch (const std::exception& e) {
+      log->error(Logr::Error, e.what(), "Exception processing cache dump");
+    }
+    catch (...) {
+      log->error(Logr::Error, "Other exception processing cache dump");
+    }
+    return 0;
+  }
 };
 
 static LockGuarded<nsspeeds_t> s_nsSpeeds;
@@ -313,6 +432,12 @@ size_t SyncRes::getNSSpeedTable(std::string& ret)
   return copy.getPB(ret);
 }
 
+size_t SyncRes::putIntoNSSpeedTable(const std::string& ret)
+{
+  auto lock = s_nsSpeeds.lock();
+  return lock->putPB(ret);
+}
+
 class Throttle
 {
 public:
index 614741abe6b208e7bd870ea4351a53c23d65aca4..0a661a71f5bd01840f95793e5d4e292da89915a2 100644 (file)
@@ -171,6 +171,7 @@ public:
   static uint64_t doDumpDoTProbeMap(int fileDesc);
 
   static size_t getNSSpeedTable(std::string& ret);
+  static size_t putIntoNSSpeedTable(const std::string& ret);
 
   static int getRootNS(struct timeval now, asyncresolve_t asyncCallback, unsigned int depth, Logr::log_t);
   static void addDontQuery(const std::string& mask)