]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - pdns/dnsdist-lua-actions.cc
Merge pull request #10501 from rgacogne/ddist-per-thread-lua-ffi
[thirdparty/pdns.git] / pdns / dnsdist-lua-actions.cc
index 6536f5133f4897e904de4be1aa604ce3e873361e..453b5c06fc6ef01c319b63763b54c15ed2d96fca 100644 (file)
@@ -489,6 +489,74 @@ private:
   func_t d_func;
 };
 
+class LuaFFIPerThreadAction: public DNSAction
+{
+public:
+  typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+  LuaFFIPerThreadAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+  {
+  }
+
+  DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+  {
+    try {
+      auto& state = t_perThreadStates[d_functionID];
+      if (!state.d_initialized) {
+        setupLuaFFIPerThreadContext(state.d_luaContext);
+        /* mark the state as initialized first so if there is a syntax error
+           we only try to execute the code once */
+        state.d_initialized = true;
+        state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+      }
+
+      if (!state.d_func) {
+        /* the function was not properly initialized */
+        return DNSAction::Action::None;
+      }
+
+      dnsdist_ffi_dnsquestion_t dqffi(dq);
+      auto ret = state.d_func(&dqffi);
+      if (ruleresult) {
+        if (dqffi.result) {
+          *ruleresult = *dqffi.result;
+        }
+        else {
+          // default to empty string
+          ruleresult->clear();
+        }
+      }
+      return static_cast<DNSAction::Action>(ret);
+    }
+    catch (const std::exception &e) {
+      warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what());
+    }
+    catch (...) {
+      warnlog("LuaFFIPerthreadAction failed inside Lua, returning ServFail: [unknown exception]");
+    }
+    return DNSAction::Action::ServFail;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI per-thread script";
+  }
+
+private:
+  struct PerThreadState
+  {
+    LuaContext d_luaContext;
+    func_t d_func;
+    bool d_initialized{false};
+  };
+  static std::atomic<uint64_t> s_functionsCounter;
+  static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+  const std::string d_functionCode;
+  const uint64_t d_functionID;
+};
+
+std::atomic<uint64_t> LuaFFIPerThreadAction::s_functionsCounter = 0;
+thread_local std::map<uint64_t, LuaFFIPerThreadAction::PerThreadState> LuaFFIPerThreadAction::t_perThreadStates;
 
 class LuaFFIResponseAction: public DNSResponseAction
 {
@@ -537,6 +605,81 @@ private:
   func_t d_func;
 };
 
+class LuaFFIPerThreadResponseAction: public DNSResponseAction
+{
+public:
+  typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+  LuaFFIPerThreadResponseAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+  {
+  }
+
+  DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+  {
+    DNSQuestion* dq = dynamic_cast<DNSQuestion*>(dr);
+    if (dq == nullptr) {
+      return DNSResponseAction::Action::ServFail;
+    }
+
+    try {
+      auto& state = t_perThreadStates[d_functionID];
+      if (!state.d_initialized) {
+        setupLuaFFIPerThreadContext(state.d_luaContext);
+        /* mark the state as initialized first so if there is a syntax error
+           we only try to execute the code once */
+        state.d_initialized = true;
+        state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+      }
+
+      if (!state.d_func) {
+        /* the function was not properly initialized */
+        return DNSResponseAction::Action::None;
+      }
+
+      dnsdist_ffi_dnsquestion_t dqffi(dq);
+      auto ret = state.d_func(&dqffi);
+      if (ruleresult) {
+        if (dqffi.result) {
+          *ruleresult = *dqffi.result;
+        }
+        else {
+          // default to empty string
+          ruleresult->clear();
+        }
+      }
+      return static_cast<DNSResponseAction::Action>(ret);
+    }
+    catch (const std::exception &e) {
+      warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what());
+    }
+    catch (...) {
+      warnlog("LuaFFIPerthreadResponseAction failed inside Lua, returning ServFail: [unknown exception]");
+    }
+    return DNSResponseAction::Action::ServFail;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI per-thread script";
+  }
+
+private:
+  struct PerThreadState
+  {
+    LuaContext d_luaContext;
+    func_t d_func;
+    bool d_initialized{false};
+  };
+
+  static std::atomic<uint64_t> s_functionsCounter;
+  static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+  const std::string d_functionCode;
+  const uint64_t d_functionID;
+};
+
+std::atomic<uint64_t> LuaFFIPerThreadResponseAction::s_functionsCounter = 0;
+thread_local std::map<uint64_t, LuaFFIPerThreadResponseAction::PerThreadState> LuaFFIPerThreadResponseAction::t_perThreadStates;
+
 thread_local std::default_random_engine SpoofAction::t_randomEngine;
 
 DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresult) const
@@ -1762,6 +1905,11 @@ void setupLuaActions(LuaContext& luaCtx)
       return std::shared_ptr<DNSAction>(new LuaFFIAction(func));
     });
 
+  luaCtx.writeFunction("LuaFFIPerThreadAction", [](std::string code) {
+      setLuaSideEffect();
+      return std::shared_ptr<DNSAction>(new LuaFFIPerThreadAction(code));
+    });
+
   luaCtx.writeFunction("SetNoRecurseAction", []() {
       return std::shared_ptr<DNSAction>(new SetNoRecurseAction);
     });
@@ -1928,6 +2076,11 @@ void setupLuaActions(LuaContext& luaCtx)
       return std::shared_ptr<DNSResponseAction>(new LuaFFIResponseAction(func));
     });
 
+  luaCtx.writeFunction("LuaFFIPerThreadResponseAction", [](std::string code) {
+      setLuaSideEffect();
+      return std::shared_ptr<DNSResponseAction>(new LuaFFIPerThreadResponseAction(code));
+    });
+
   luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
       if (logger) {
         // avoids potentially-evaluated-expression warning with clang.