]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
RPZ notify
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Wed, 10 Jan 2024 08:55:11 +0000 (09:55 +0100)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Wed, 10 Jan 2024 12:08:07 +0000 (13:08 +0100)
pdns/recursordist/pdns_recursor.cc
pdns/recursordist/rpzloader.cc
pdns/recursordist/rpzloader.hh

index b10234786eeff96972c4ed0efd29896da1846c6b..ad6c83c70cabdfa1b4518fd74d4045e10fcf09e1 100644 (file)
@@ -2318,7 +2318,7 @@ static string* doProcessUDPQuestion(const std::string& question, const ComboAddr
       SLOG(g_log << Logger::Notice << RecThreadInfo::id() << " got NOTIFY for " << qname.toLogString() << " from " << source.toStringWithPort() << (source != fromaddr ? " (via " + fromaddr.toStringWithPort() + ")" : "") << endl,
            g_slogudpin->info(Logr::Notice, "Got NOTIFY", "source", Logging::Loggable(source), "remote", Logging::Loggable(fromaddr), "qname", Logging::Loggable(qname)));
     }
-
+    notifyRPZTracker(qname);
     requestWipeCaches(qname);
 
     // the operation will now be treated as a Query, generating
index 45c3576fae932e9366b9de4023c53f7e40fa3af8..02be125662888154cf8950b5f22586b2b54ef9d5 100644 (file)
@@ -386,7 +386,17 @@ static bool dumpZoneToDisk(Logr::log_t logger, const DNSName& zoneName, const st
   return true;
 }
 
-static void preloadRPZFIle(RPZTrackerParams& params, const DNSName& zoneName, std::shared_ptr<DNSFilterEngine::Zone>& oldZone, uint32_t& refresh, const string& polName, Logr::log_t logger)
+// A struct that holds the condition var and related stuff to allow notifies to be sent to the tread owning
+// the struct.
+struct RPZWaiter {
+  RPZWaiter(std::thread::id arg) : id(arg) {}
+  std::thread::id id;
+  std::mutex mutex;
+  std::condition_variable condVar;
+  std::atomic<bool> stop{false};
+};
+
+static void preloadRPZFIle(RPZTrackerParams& params, const DNSName& zoneName, std::shared_ptr<DNSFilterEngine::Zone>& oldZone, uint32_t& refresh, const string& polName, RPZWaiter& rpzwaiter, Logr::log_t logger)
 {
   while (!params.soaRecordContent) {
     /* if we received an empty sr, the zone was not really preloaded */
@@ -426,23 +436,29 @@ static void preloadRPZFIle(RPZTrackerParams& params, const DNSName& zoneName, st
     // Release newZone before (long) sleep to reduce memory usage
     newZone = nullptr;
     if (!params.soaRecordContent) {
-      sleep(refresh);
+      std::unique_lock lock(rpzwaiter.mutex);
+      rpzwaiter.condVar.wait_for(lock, std::chrono::seconds(refresh),
+                              [&stop = rpzwaiter.stop] { return stop.load(); });
     }
+    rpzwaiter.stop = false;
   }
 }
 
-static void RPZTrackerIteration(RPZTrackerParams& params, const DNSName& zoneName, std::shared_ptr<DNSFilterEngine::Zone>& oldZone, uint32_t& refresh, const string& polName, bool skipRefreshDelay, uint64_t configGeneration, Logr::log_t logger)
+static bool RPZTrackerIteration(RPZTrackerParams& params, const DNSName& zoneName, std::shared_ptr<DNSFilterEngine::Zone>& oldZone, uint32_t& refresh, const string& polName, bool& skipRefreshDelay, uint64_t configGeneration, RPZWaiter& rpzwaiter, Logr::log_t logger)
 {
   // Don't hold on to oldZone, it well be re-assigned after sleep in the try block
   oldZone = nullptr;
-  DNSRecord dr;
-  dr.setContent(params.soaRecordContent);
+  DNSRecord dnsRecord;
+  dnsRecord.setContent(params.soaRecordContent);
 
   if (skipRefreshDelay) {
     skipRefreshDelay = false;
   }
   else {
-    sleep(refresh);
+    std::unique_lock lock(rpzwaiter.mutex);
+    rpzwaiter.condVar.wait_for(lock, std::chrono::seconds(refresh),
+                            [&stop = rpzwaiter.stop] { return stop.load(); });
+    rpzwaiter.stop = false;
   }
   auto luaconfsLocal = g_luaconfs.getLocal();
 
@@ -452,12 +468,12 @@ static void RPZTrackerIteration(RPZTrackerParams& params, const DNSName& zoneNam
     */
     SLOG(g_log << Logger::Info << "A more recent configuration has been found, stopping the existing RPZ update thread for " << zoneName << endl,
          logger->info(Logr::Info, "A more recent configuration has been found, stopping the existing RPZ update thread"));
-    return;
+    return false;
   }
 
   vector<pair<vector<DNSRecord>, vector<DNSRecord>>> deltas;
   for (const auto& primary : params.primaries) {
-    auto soa = getRR<SOARecordContent>(dr);
+    auto soa = getRR<SOARecordContent>(dnsRecord);
     auto serial = soa ? soa->d_st.serial : 0;
     SLOG(g_log << Logger::Info << "Getting IXFR deltas for " << zoneName << " from " << primary.toStringWithPort() << ", our serial: " << serial << endl,
          logger->info(Logr::Info, "Getting IXFR deltas", "address", Logging::Loggable(primary), "ourserial", Logging::Loggable(serial)));
@@ -468,7 +484,7 @@ static void RPZTrackerIteration(RPZTrackerParams& params, const DNSName& zoneNam
     }
 
     try {
-      deltas = getIXFRDeltas(primary, zoneName, dr, params.xfrTimeout, true, params.tsigtriplet, &local, params.maxReceivedBytes);
+      deltas = getIXFRDeltas(primary, zoneName, dnsRecord, params.xfrTimeout, true, params.tsigtriplet, &local, params.maxReceivedBytes);
 
       /* no need to try another primary */
       break;
@@ -482,7 +498,7 @@ static void RPZTrackerIteration(RPZTrackerParams& params, const DNSName& zoneNam
   }
 
   if (deltas.empty()) {
-    return;
+    return true;
   }
 
   try {
@@ -492,13 +508,13 @@ static void RPZTrackerIteration(RPZTrackerParams& params, const DNSName& zoneNam
     if (luaconfsLocal->generation != configGeneration) {
       SLOG(g_log << Logger::Info << "A more recent configuration has been found, stopping the existing RPZ update thread for " << zoneName << endl,
            logger->info(Logr::Info, "A more recent configuration has been found, stopping the existing RPZ update thread"));
-      return;
+      return false;
     }
     oldZone = luaconfsLocal->dfe.getZone(params.zoneIdx);
     if (!oldZone || oldZone->getDomain() != zoneName) {
       SLOG(g_log << Logger::Info << "This policy is no more, stopping the existing RPZ update thread for " << zoneName << endl,
            logger->info(Logr::Info, "This policy is no more, stopping the existing RPZ update thread"));
-      return;
+      return false;
     }
     /* we need to make a _full copy_ of the zone we are going to work on */
     std::shared_ptr<DNSFilterEngine::Zone> newZone = std::make_shared<DNSFilterEngine::Zone>(*oldZone);
@@ -578,7 +594,7 @@ static void RPZTrackerIteration(RPZTrackerParams& params, const DNSName& zoneNam
     if (luaconfsLocal->generation != configGeneration) {
       SLOG(g_log << Logger::Info << "A more recent configuration has been found, stopping the existing RPZ update thread for " << zoneName << endl,
            logger->info(Logr::Info, "A more recent configuration has been found, stopping the existing RPZ update thread"));
-      return;
+      return false;
     }
     g_luaconfs.modify([zoneIdx = params.zoneIdx, &newZone](LuaConfigItems& lci) {
       lci.dfe.setZone(zoneIdx, newZone);
@@ -597,6 +613,29 @@ static void RPZTrackerIteration(RPZTrackerParams& params, const DNSName& zoneNam
     SLOG(g_log << Logger::Error << "Error while applying the update received over XFR for " << zoneName << ", skipping the update: " << e.reason << endl,
          logger->error(Logr::Error, e.reason, "Exception while applying the update received over XFR, skipping", "exception", Logging::Loggable("PDNSException")));
   }
+  return true;
+}
+
+// As there can be multiple threads doing updates (due to config reloads), we use a multimap.
+// The value contains the actual thread id that owns the struct.
+
+static LockGuarded<std::multimap<DNSName, RPZWaiter&>> condVars;
+
+// Notify all threads trakcing the RPZ name
+bool notifyRPZTracker(const DNSName& name)
+{
+  auto lock = condVars.lock();
+  auto [start, end] = lock->equal_range(name);
+  if (start == end) {
+    // Did not find any thread tracking that RPZ name
+    return false;
+  }
+  while (start != end) {
+    start->second.stop = true;
+    start->second.condVar.notify_one();
+    ++start;
+  }
+  return true;
 }
 
 void RPZIXFRTracker(RPZTrackerParams params, uint64_t configGeneration)
@@ -604,6 +643,7 @@ void RPZIXFRTracker(RPZTrackerParams params, uint64_t configGeneration)
   setThreadName("rec/rpzixfr");
   bool isPreloaded = params.soaRecordContent != nullptr;
   auto logger = g_slog->withName("rpz");
+  RPZWaiter waiter(std::this_thread::get_id());
 
   /* we can _never_ modify this zone directly, we need to do a full copy then replace the existing zone */
   std::shared_ptr<DNSFilterEngine::Zone> oldZone = g_luaconfs.getLocal()->dfe.getZone(params.zoneIdx);
@@ -621,11 +661,27 @@ void RPZIXFRTracker(RPZTrackerParams params, uint64_t configGeneration)
   // Now that we know the name, set it in the logger
   logger = logger->withValues("zone", Logging::Loggable(zoneName));
 
-  preloadRPZFIle(params, zoneName, oldZone, refresh, polName, logger);
+  {
+    auto lock = condVars.lock();
+    lock->emplace(zoneName, waiter);
+  }
+  preloadRPZFIle(params, zoneName, oldZone, refresh, polName, waiter, logger);
 
   bool skipRefreshDelay = isPreloaded;
 
-  for (;;) {
-    RPZTrackerIteration(params, zoneName, oldZone, refresh, polName, skipRefreshDelay, configGeneration, logger);
+  while (RPZTrackerIteration(params, zoneName, oldZone, refresh, polName, skipRefreshDelay, configGeneration, waiter, logger)) {
+    // empty
+  }
+
+  // Zap our (and only our) RPZWaiter struct out of the multimap
+  auto lock = condVars.lock();
+  auto [start, end] = lock->equal_range(zoneName);
+  while (start != end) {
+    if (start->second.id == std::this_thread::get_id()) {
+      start = lock->erase(start);
+    }
+    else {
+      ++start;
+    }
   }
 }
index c8811ee0b3d9f90fa91b64430ffe266197bc430d..99fd749945edc599600537e4ff19e783e63df7e7 100644 (file)
@@ -44,6 +44,7 @@ struct RPZTrackerParams {
 
 std::shared_ptr<const SOARecordContent> loadRPZFromFile(const std::string& fname, std::shared_ptr<DNSFilterEngine::Zone> zone, const boost::optional<DNSFilterEngine::Policy>& defpol, bool defpolOverrideLocal, uint32_t maxTTL);
 void RPZIXFRTracker(RPZTrackerParams params, uint64_t configGeneration);
+bool notifyRPZTracker(const DNSName& name);
 
 struct rpzStats
 {