]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
rec: store id in info object, makes a few methods less error prone
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Fri, 14 Feb 2025 13:09:45 +0000 (14:09 +0100)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Mon, 24 Feb 2025 14:32:46 +0000 (15:32 +0100)
pdns/recursordist/lua-recursor4.cc
pdns/recursordist/pdns_recursor.cc
pdns/recursordist/rec-main.cc
pdns/recursordist/rec-main.hh

index 4f16b6368428518d86e341898abbf0869ea0a4d9..b5f83881c5d07f199a1f364f187cd36548a9c0c9 100644 (file)
@@ -449,7 +449,7 @@ void RecursorLua4::postPrepareContext() // NOLINT(readability-function-cognitive
     });
 
   d_lw->writeFunction("getRecursorThreadId", []() {
-    return RecThreadInfo::id();
+    return RecThreadInfo::thread_local_id();
   });
 
   d_lw->writeFunction("sendCustomSNMPTrap", [](const std::string& str) {
index a59ee8ab437c5a6ad52c66946df111cf6723858f..47ab1e096917d3a5382b7e041a0e4ba388ef5d99 100644 (file)
@@ -1185,7 +1185,7 @@ void startDoResolve(void* arg) // NOLINT(readability-function-cognitive-complexi
 
     if (!g_quiet || tracedQuery) {
       if (!g_slogStructured) {
-        g_log << Logger::Warning << RecThreadInfo::id() << " [" << g_multiTasker->getTid() << "/" << g_multiTasker->numProcesses() << "] " << (comboWriter->d_tcp ? "TCP " : "") << "question for '" << comboWriter->d_mdp.d_qname << "|"
+        g_log << Logger::Warning << RecThreadInfo::thread_local_id() << " [" << g_multiTasker->getTid() << "/" << g_multiTasker->numProcesses() << "] " << (comboWriter->d_tcp ? "TCP " : "") << "question for '" << comboWriter->d_mdp.d_qname << "|"
               << QType(comboWriter->d_mdp.d_qtype) << "' from " << comboWriter->getRemote();
         if (!comboWriter->d_ednssubnet.getSource().empty()) {
           g_log << " (ecs " << comboWriter->d_ednssubnet.getSource().toString() << ")";
@@ -1850,7 +1850,7 @@ void startDoResolve(void* arg) // NOLINT(readability-function-cognitive-complexi
       pbMessage.setDeviceName(dnsQuestion.deviceName);
       pbMessage.setToPort(comboWriter->d_destination.getPort());
       pbMessage.addPolicyTags(comboWriter->d_gettagPolicyTags);
-      pbMessage.setWorkerId(RecThreadInfo::id());
+      pbMessage.setWorkerId(RecThreadInfo::thread_local_id());
       pbMessage.setPacketCacheHit(false);
       pbMessage.setOutgoingQueries(resolver.d_outqueries);
       for (const auto& metaValue : dnsQuestion.meta) {
@@ -1885,7 +1885,7 @@ void startDoResolve(void* arg) // NOLINT(readability-function-cognitive-complexi
     uint64_t spentUsec = uSec(resolver.getNow() - comboWriter->d_now);
     if (!g_quiet) {
       if (!g_slogStructured) {
-        g_log << Logger::Error << RecThreadInfo::id() << " [" << g_multiTasker->getTid() << "/" << g_multiTasker->numProcesses() << "] answer to " << (comboWriter->d_mdp.d_header.rd ? "" : "non-rd ") << "question '" << comboWriter->d_mdp.d_qname << "|" << DNSRecordContent::NumberToType(comboWriter->d_mdp.d_qtype);
+        g_log << Logger::Error << RecThreadInfo::thread_local_id() << " [" << g_multiTasker->getTid() << "/" << g_multiTasker->numProcesses() << "] answer to " << (comboWriter->d_mdp.d_header.rd ? "" : "non-rd ") << "question '" << comboWriter->d_mdp.d_qname << "|" << DNSRecordContent::NumberToType(comboWriter->d_mdp.d_qtype);
         g_log << "': " << ntohs(packetWriter.getHeader()->ancount) << " answers, " << ntohs(packetWriter.getHeader()->arcount) << " additional, took " << resolver.d_outqueries << " packets, " << resolver.d_totUsec / 1000.0 << " netw ms, " << static_cast<double>(spentUsec) / 1000.0 << " tot ms, " << resolver.d_throttledqueries << " throttled, " << resolver.d_timeouts << " timeouts, " << resolver.d_tcpoutqueries << "/" << resolver.d_dotoutqueries << " tcp/dot connections, rcode=" << res;
 
         if (!shouldNotValidate && resolver.isDNSSECValidationRequested()) {
index 0e2a66a55f71dbc1d929ff925a80d2eea3130467..4e0ad3825da3eb63a3ab796ab6893e52d376ee19 100644 (file)
@@ -131,8 +131,7 @@ std::shared_ptr<Logr::Logger> g_slogudpout;
 static deferredAdd_t s_deferredUDPadds;
 static deferredAdd_t s_deferredTCPadds;
 
-/* first we have the handler thread, t_id == 0 (some other
-   helper threads like SNMP might have t_id == 0 as well)
+/* first we have the handler thread, t_id == 0 (thread not created as a RecursorThread have t_id = NOT_INITED)
    then the distributor threads if any
    and finally the workers */
 std::vector<RecThreadInfo> RecThreadInfo::s_threadInfos;
@@ -144,7 +143,7 @@ bool RecThreadInfo::s_weDistributeQueries; // if true, 1 or more threads listen
 unsigned int RecThreadInfo::s_numDistributorThreads;
 unsigned int RecThreadInfo::s_numUDPWorkerThreads;
 unsigned int RecThreadInfo::s_numTCPWorkerThreads;
-thread_local unsigned int RecThreadInfo::t_id;
+thread_local unsigned int RecThreadInfo::t_id{RecThreadInfo::TID_NOT_INITED};
 
 pdns::RateLimitedLog g_rateLimitedLogger;
 
@@ -358,7 +357,8 @@ int RecThreadInfo::runThreads(Logr::log_t log)
       serveRustWeb();
     }
     for (auto& tInfo : RecThreadInfo::infos()) {
-      if (tInfo.getName() == "web+stat") { // XXX testing for isHandler() does not work as expected!
+      // who handles the handler? the caller!
+      if (tInfo.isHandler()) {
         continue;
       }
       tInfo.thread.join();
@@ -562,7 +562,7 @@ void protobufLogQuery(LocalStateHolder<LuaConfigItems>& luaconfsLocal, const boo
   msg.setRequestorId(requestorId);
   msg.setDeviceId(deviceId);
   msg.setDeviceName(deviceName);
-  msg.setWorkerId(RecThreadInfo::id());
+  msg.setWorkerId(RecThreadInfo::thread_local_id());
   // For queries, packetCacheHit and outgoingQueries are not relevant
 
   if (!policyTags.empty()) {
@@ -646,7 +646,7 @@ void protobufLogResponse(const struct dnsheader* header, LocalStateHolder<LuaCon
   pbMessage.setDeviceId(deviceId);
   pbMessage.setDeviceName(deviceName);
   pbMessage.setToPort(destination.getPort());
-  pbMessage.setWorkerId(RecThreadInfo::id());
+  pbMessage.setWorkerId(RecThreadInfo::thread_local_id());
   // this method is only used for PC cache hits
   pbMessage.setPacketCacheHit(true);
   // we do not set outgoingQueries, it is not relevant for PC cache hits
@@ -1116,7 +1116,7 @@ static void loggerSDBackend(const Logging::Entry& entry)
   }
   // Thread id filled in by backend, since the SL code does not know about RecursorThreads
   // We use the Recursor thread, other threads get id 0. May need to revisit.
-  appendKeyAndVal("TID", std::to_string(RecThreadInfo::id()));
+  appendKeyAndVal("TID", std::to_string(RecThreadInfo::thread_local_id()));
 
   vector<iovec> iov;
   iov.reserve(strings.size());
@@ -1144,7 +1144,7 @@ static void loggerJSONBackend(const Logging::Entry& entry)
     {"level", std::to_string(entry.level)},
     // Thread id filled in by backend, since the SL code does not know about RecursorThreads
     // We use the Recursor thread, other threads get id 0. May need to revisit.
-    {"tid", std::to_string(RecThreadInfo::id())},
+    {"tid", std::to_string(RecThreadInfo::thread_local_id())},
     {"ts", Logging::toTimestampStringMilli(entry.d_timestamp, timebuf)},
   };
 
@@ -1197,7 +1197,7 @@ static void loggerBackend(const Logging::Entry& entry)
   }
   // Thread id filled in by backend, since the SL code does not know about RecursorThreads
   // We use the Recursor thread, other threads get id 0. May need to revisit.
-  buf << " tid=" << std::quoted(std::to_string(RecThreadInfo::id()));
+  buf << " tid=" << std::quoted(std::to_string(RecThreadInfo::thread_local_id()));
   std::array<char, 64> timebuf{};
   buf << " ts=" << std::quoted(Logging::toTimestampStringMilli(entry.d_timestamp, timebuf));
   for (auto const& value : entry.values) {
@@ -1491,11 +1491,11 @@ void parseACLs()
 
 void broadcastFunction(const pipefunc_t& func)
 {
-  /* This function might be called by the worker with t_id 0 during startup
+  /* This function might be called by the worker with t_id not inited during startup
      for the initialization of ACLs and domain maps. After that it should only
      be called by the handler. */
 
-  if (RecThreadInfo::infos().empty() && RecThreadInfo::id() == 0) {
+  if (RecThreadInfo::infos().empty() && !RecThreadInfo::is_thread_inited()) {
     /* the handler and  distributors will call themselves below, but
        during startup we get called while g_threadInfos has not been
        populated yet to update the ACL or domain maps, so we need to
@@ -1506,7 +1506,7 @@ void broadcastFunction(const pipefunc_t& func)
 
   unsigned int thread = 0;
   for (const auto& threadInfo : RecThreadInfo::infos()) {
-    if (thread++ == RecThreadInfo::id()) {
+    if (thread++ == RecThreadInfo::thread_local_id()) {
       func(); // don't write to ourselves!
       continue;
     }
@@ -1576,16 +1576,15 @@ static RemoteLoggerStats_t& operator+=(RemoteLoggerStats_t& lhs, const RemoteLog
 template <class T>
 T broadcastAccFunction(const std::function<T*()>& func)
 {
-  if (!RecThreadInfo::self().isHandler()) {
-    SLOG(g_log << Logger::Error << "broadcastAccFunction has been called by a worker (" << RecThreadInfo::id() << ")" << endl,
-         g_slog->withName("runtime")->info(Logr::Critical, "broadcastAccFunction has been called by a worker")); // tid will be added
+  if (RecThreadInfo::thread_local_id() != 0) {
+    g_slog->withName("runtime")->info(Logr::Critical, "broadcastAccFunction has been called by a worker"); // tid will be added
     _exit(1);
   }
 
   unsigned int thread = 0;
   T ret = T();
   for (const auto& threadInfo : RecThreadInfo::infos()) {
-    if (thread++ == RecThreadInfo::id()) {
+    if (thread++ == RecThreadInfo::thread_local_id()) {
       continue;
     }
 
@@ -1892,7 +1891,7 @@ static unsigned int initDistribution(Logr::log_t log)
   g_reusePort = ::arg().mustDo("reuseport");
 #endif
 
-  RecThreadInfo::infos().resize(RecThreadInfo::numRecursorThreads());
+  RecThreadInfo::resize(RecThreadInfo::numRecursorThreads());
 
   if (g_reusePort) {
     unsigned int threadNum = 1;
@@ -3357,14 +3356,14 @@ static RecursorControlChannel::Answer* doReloadLuaScript()
       t_pdl->loadFile(fname);
     }
     catch (std::runtime_error& ex) {
-      string msg = std::to_string(RecThreadInfo::id()) + " Retaining current script, could not read '" + fname + "': " + ex.what();
+      string msg = std::to_string(RecThreadInfo::thread_local_id()) + " Retaining current script, could not read '" + fname + "': " + ex.what();
       SLOG(g_log << Logger::Error << msg << endl,
            log->error(Logr::Error, ex.what(), "Retaining current script, could not read new script"));
       return new RecursorControlChannel::Answer{1, msg + "\n"};
     }
   }
   catch (std::exception& e) {
-    SLOG(g_log << Logger::Error << RecThreadInfo::id() << " Retaining current script, error from '" << fname << "': " << e.what() << endl,
+    SLOG(g_log << Logger::Error << RecThreadInfo::thread_local_id() << " Retaining current script, error from '" << fname << "': " << e.what() << endl,
          log->error(Logr::Error, e.what(), "Retaining current script, error in new script"));
     return new RecursorControlChannel::Answer{1, string("retaining current script, error from '" + fname + "': " + e.what() + "\n")};
   }
index d18afcb3aeb4c21575405c39f4d046efcd785016..acc929ce747ad50eeddb62b75f927b2811d9bd48 100644 (file)
@@ -332,8 +332,7 @@ static bool sendResponseOverTCP(const std::unique_ptr<DNSComboWriter>& dc, const
 
 // For communicating with our threads effectively readonly after
 // startup.
-// First we have the handler thread, t_id == 0 (some other helper
-// threads like SNMP might have t_id == 0 as well) then the
+// First we have the handler thread, t_id == 0  then the
 // distributor threads if any and finally the workers
 struct RecThreadInfo
 {
@@ -350,12 +349,16 @@ struct RecThreadInfo
 public:
   static RecThreadInfo& self()
   {
-    return s_threadInfos.at(t_id);
+    auto& info = s_threadInfos.at(t_id);
+    assert(info.d_myid == t_id); // internal consistency check
+    return info;
   }
 
   static RecThreadInfo& info(unsigned int index)
   {
-    return s_threadInfos.at(index);
+    auto& info = s_threadInfos.at(index);
+    assert(info.d_myid == index);
+    return info;
   }
 
   static vector<RecThreadInfo>& infos()
@@ -365,17 +368,11 @@ public:
 
   [[nodiscard]] bool isDistributor() const
   {
-    if (t_id == 0) {
-      return false;
-    }
     return s_weDistributeQueries && listener;
   }
 
   [[nodiscard]] bool isHandler() const
   {
-    if (t_id == 0) {
-      return true;
-    }
     return handler;
   }
 
@@ -427,11 +424,24 @@ public:
     taskThread = true;
   }
 
-  static unsigned int id()
+  static unsigned int thread_local_id()
   {
+    if (t_id == TID_NOT_INITED) {
+      return 0; // backward compatibility
+    }
     return t_id;
   }
 
+  static bool is_thread_inited()
+  {
+    return t_id != TID_NOT_INITED;
+  }
+
+  [[nodiscard]] unsigned int id() const
+  {
+    return d_myid;
+  }
+
   static void setThreadId(unsigned int arg)
   {
     t_id = arg;
@@ -550,6 +560,15 @@ public:
     info(0).thread.join();
   }
 
+  static void resize(size_t size)
+  {
+    s_threadInfos.resize(size);
+    for (unsigned int i = 0; i < size; i++) {
+      s_threadInfos.at(i).d_myid = i;
+    }
+  }
+  static constexpr unsigned int TID_NOT_INITED = std::numeric_limits<unsigned int>::max();
+
 private:
   // FD corresponding to TCP sockets this thread is listening on.
   // These FDs are also in deferredAdds when we have one socket per
@@ -569,6 +588,7 @@ private:
   std::string name;
   std::thread thread;
   int exitCode{0};
+  unsigned int d_myid{TID_NOT_INITED}; // should always equal to the thread_local tid;
 
   // handle the web server, carbon, statistics and the control channel
   bool handler{false};