]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Merge pull request #10501 from rgacogne/ddist-per-thread-lua-ffi
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 20 Jul 2021 08:55:50 +0000 (10:55 +0200)
committerGitHub <noreply@github.com>
Tue, 20 Jul 2021 08:55:50 +0000 (10:55 +0200)
dnsdist: Add support for Lua per-thread FFI rules and actions

1  2 
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/docs/rules-actions.rst
regression-tests.dnsdist/test_Advanced.py

index 6536f5133f4897e904de4be1aa604ce3e873361e,b067741c78d08fe2d5abe1addad689978ccf0534..453b5c06fc6ef01c319b63763b54c15ed2d96fca
@@@ -181,7 -181,7 +181,7 @@@ TeeAction::~TeeAction(
  
  DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult) const
  {
 -  if (dq->tcp) {
 +  if (dq->overTCP()) {
      d_tcpdrops++;
    }
    else {
@@@ -489,6 -489,74 +489,74 @@@ private
    func_t d_func;
  };
  
+ class LuaFFIPerThreadAction: public DNSAction
+ {
+ public:
+   typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+   LuaFFIPerThreadAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+   {
+   }
+   DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+   {
+     try {
+       auto& state = t_perThreadStates[d_functionID];
+       if (!state.d_initialized) {
+         setupLuaFFIPerThreadContext(state.d_luaContext);
+         /* mark the state as initialized first so if there is a syntax error
+            we only try to execute the code once */
+         state.d_initialized = true;
+         state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+       }
+       if (!state.d_func) {
+         /* the function was not properly initialized */
+         return DNSAction::Action::None;
+       }
+       dnsdist_ffi_dnsquestion_t dqffi(dq);
+       auto ret = state.d_func(&dqffi);
+       if (ruleresult) {
+         if (dqffi.result) {
+           *ruleresult = *dqffi.result;
+         }
+         else {
+           // default to empty string
+           ruleresult->clear();
+         }
+       }
+       return static_cast<DNSAction::Action>(ret);
+     }
+     catch (const std::exception &e) {
+       warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what());
+     }
+     catch (...) {
+       warnlog("LuaFFIPerthreadAction failed inside Lua, returning ServFail: [unknown exception]");
+     }
+     return DNSAction::Action::ServFail;
+   }
+   string toString() const override
+   {
+     return "Lua FFI per-thread script";
+   }
+ private:
+   struct PerThreadState
+   {
+     LuaContext d_luaContext;
+     func_t d_func;
+     bool d_initialized{false};
+   };
+   static std::atomic<uint64_t> s_functionsCounter;
+   static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+   const std::string d_functionCode;
+   const uint64_t d_functionID;
+ };
+ std::atomic<uint64_t> LuaFFIPerThreadAction::s_functionsCounter = 0;
+ thread_local std::map<uint64_t, LuaFFIPerThreadAction::PerThreadState> LuaFFIPerThreadAction::t_perThreadStates;
  
  class LuaFFIResponseAction: public DNSResponseAction
  {
@@@ -537,6 -605,81 +605,81 @@@ private
    func_t d_func;
  };
  
+ class LuaFFIPerThreadResponseAction: public DNSResponseAction
+ {
+ public:
+   typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+   LuaFFIPerThreadResponseAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+   {
+   }
+   DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+   {
+     DNSQuestion* dq = dynamic_cast<DNSQuestion*>(dr);
+     if (dq == nullptr) {
+       return DNSResponseAction::Action::ServFail;
+     }
+     try {
+       auto& state = t_perThreadStates[d_functionID];
+       if (!state.d_initialized) {
+         setupLuaFFIPerThreadContext(state.d_luaContext);
+         /* mark the state as initialized first so if there is a syntax error
+            we only try to execute the code once */
+         state.d_initialized = true;
+         state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+       }
+       if (!state.d_func) {
+         /* the function was not properly initialized */
+         return DNSResponseAction::Action::None;
+       }
+       dnsdist_ffi_dnsquestion_t dqffi(dq);
+       auto ret = state.d_func(&dqffi);
+       if (ruleresult) {
+         if (dqffi.result) {
+           *ruleresult = *dqffi.result;
+         }
+         else {
+           // default to empty string
+           ruleresult->clear();
+         }
+       }
+       return static_cast<DNSResponseAction::Action>(ret);
+     }
+     catch (const std::exception &e) {
+       warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what());
+     }
+     catch (...) {
+       warnlog("LuaFFIPerthreadResponseAction failed inside Lua, returning ServFail: [unknown exception]");
+     }
+     return DNSResponseAction::Action::ServFail;
+   }
+   string toString() const override
+   {
+     return "Lua FFI per-thread script";
+   }
+ private:
+   struct PerThreadState
+   {
+     LuaContext d_luaContext;
+     func_t d_func;
+     bool d_initialized{false};
+   };
+   static std::atomic<uint64_t> s_functionsCounter;
+   static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+   const std::string d_functionCode;
+   const uint64_t d_functionID;
+ };
+ std::atomic<uint64_t> LuaFFIPerThreadResponseAction::s_functionsCounter = 0;
+ thread_local std::map<uint64_t, LuaFFIPerThreadResponseAction::PerThreadState> LuaFFIPerThreadResponseAction::t_perThreadStates;
  thread_local std::default_random_engine SpoofAction::t_randomEngine;
  
  DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresult) const
@@@ -727,25 -870,35 +870,25 @@@ class LogAction : public DNSAction, pub
  {
  public:
    // this action does not stop the processing
 -  LogAction(): d_fp(nullptr, fclose)
 +  LogAction()
    {
    }
  
 -  LogAction(const std::string& str, bool binary=true, bool append=false, bool buffered=true, bool verboseOnly=true, bool includeTimestamp=false): d_fname(str), d_binary(binary), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp)
 +  LogAction(const std::string& str, bool binary=true, bool append=false, bool buffered=true, bool verboseOnly=true, bool includeTimestamp=false): d_fname(str), d_binary(binary), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered)
    {
      if (str.empty()) {
        return;
      }
  
 -    if(append) {
 -      d_fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(str.c_str(), "a+"), fclose);
 -    }
 -    else {
 -      d_fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(str.c_str(), "w"), fclose);
 -    }
 -
 -    if (!d_fp) {
 -      throw std::runtime_error("Unable to open file '"+str+"' for logging: "+stringerror());
 -    }
 -
 -    if (!buffered) {
 -      setbuf(d_fp.get(), 0);
 +    if (!reopenLogFile())  {
 +      throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror());
      }
    }
  
    DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
    {
 -    if (!d_fp) {
 +    auto fp = std::atomic_load_explicit(&d_fp, std::memory_order_acquire);
 +    if (!fp) {
        if (!d_verboseOnly || g_verbose) {
          if (d_includeTimestamp) {
            infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast<unsigned long long>(dq->queryTime->tv_sec), static_cast<unsigned long>(dq->queryTime->tv_nsec), dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).toString(), dq->getHeader()->id);
          if (d_includeTimestamp) {
            uint64_t tv_sec = static_cast<uint64_t>(dq->queryTime->tv_sec);
            uint32_t tv_nsec = static_cast<uint32_t>(dq->queryTime->tv_nsec);
 -          fwrite(&tv_sec, sizeof(tv_sec), 1, d_fp.get());
 -          fwrite(&tv_nsec, sizeof(tv_nsec), 1, d_fp.get());
 +          fwrite(&tv_sec, sizeof(tv_sec), 1, fp.get());
 +          fwrite(&tv_nsec, sizeof(tv_nsec), 1, fp.get());
          }
          uint16_t id = dq->getHeader()->id;
 -        fwrite(&id, sizeof(id), 1, d_fp.get());
 -        fwrite(out.c_str(), 1, out.size(), d_fp.get());
 -        fwrite(&dq->qtype, sizeof(dq->qtype), 1, d_fp.get());
 -        fwrite(&dq->remote->sin4.sin_family, sizeof(dq->remote->sin4.sin_family), 1, d_fp.get());
 +        fwrite(&id, sizeof(id), 1, fp.get());
 +        fwrite(out.c_str(), 1, out.size(), fp.get());
 +        fwrite(&dq->qtype, sizeof(dq->qtype), 1, fp.get());
 +        fwrite(&dq->remote->sin4.sin_family, sizeof(dq->remote->sin4.sin_family), 1, fp.get());
          if (dq->remote->sin4.sin_family == AF_INET) {
 -          fwrite(&dq->remote->sin4.sin_addr.s_addr, sizeof(dq->remote->sin4.sin_addr.s_addr), 1, d_fp.get());
 +          fwrite(&dq->remote->sin4.sin_addr.s_addr, sizeof(dq->remote->sin4.sin_addr.s_addr), 1, fp.get());
          }
          else if (dq->remote->sin4.sin_family == AF_INET6) {
 -          fwrite(&dq->remote->sin6.sin6_addr.s6_addr, sizeof(dq->remote->sin6.sin6_addr.s6_addr), 1, d_fp.get());
 +          fwrite(&dq->remote->sin6.sin6_addr.s6_addr, sizeof(dq->remote->sin6.sin6_addr.s6_addr), 1, fp.get());
          }
 -        fwrite(&dq->remote->sin4.sin_port, sizeof(dq->remote->sin4.sin_port), 1, d_fp.get());
 +        fwrite(&dq->remote->sin4.sin_port, sizeof(dq->remote->sin4.sin_port), 1, fp.get());
        }
        else {
          if (d_includeTimestamp) {
 -          fprintf(d_fp.get(), "[%llu.%lu] Packet from %s for %s %s with id %d\n", static_cast<unsigned long long>(dq->queryTime->tv_sec), static_cast<unsigned long>(dq->queryTime->tv_nsec), dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).toString().c_str(), dq->getHeader()->id);
 +          fprintf(fp.get(), "[%llu.%lu] Packet from %s for %s %s with id %d\n", static_cast<unsigned long long>(dq->queryTime->tv_sec), static_cast<unsigned long>(dq->queryTime->tv_nsec), dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).toString().c_str(), dq->getHeader()->id);
          }
          else {
 -          fprintf(d_fp.get(), "Packet from %s for %s %s with id %d\n", dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).toString().c_str(), dq->getHeader()->id);
 +          fprintf(fp.get(), "Packet from %s for %s %s with id %d\n", dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).toString().c_str(), dq->getHeader()->id);
          }
        }
      }
      }
      return "log";
    }
 +
 +  void reload() override
 +  {
 +    if (!reopenLogFile()) {
 +      warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror());
 +    }
 +  }
 +
  private:
 +  bool reopenLogFile()
 +  {
 +    // we are using a naked pointer here because we don't want fclose to be called
 +    // with a nullptr, which would happen if we constructor a shared_ptr with fclose
 +    // as a custom deleter and nullptr as a FILE*
 +    auto nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w");
 +    if (!nfp) {
 +      /* don't fall on our sword when reopening */
 +      return false;
 +    }
 +
 +    auto fp = std::shared_ptr<FILE>(nfp, fclose);
 +    nfp = nullptr;
 +
 +    if (!d_buffered) {
 +      setbuf(fp.get(), 0);
 +    }
 +
 +    std::atomic_store_explicit(&d_fp, fp, std::memory_order_release);
 +    return true;
 +  }
 +
    std::string d_fname;
 -  std::unique_ptr<FILE, int(*)(FILE*)> d_fp{nullptr, fclose};
 +  std::shared_ptr<FILE> d_fp{nullptr};
    bool d_binary{true};
    bool d_verboseOnly{true};
    bool d_includeTimestamp{false};
 +  bool d_append{false};
 +  bool d_buffered{true};
  };
  
  class LogResponseAction : public DNSResponseAction, public boost::noncopyable
  {
  public:
 -  LogResponseAction(): d_fp(nullptr, fclose)
 +  LogResponseAction()
    {
    }
  
 -  LogResponseAction(const std::string& str, bool append=false, bool buffered=true, bool verboseOnly=true, bool includeTimestamp=false): d_fname(str), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp)
 +  LogResponseAction(const std::string& str, bool append=false, bool buffered=true, bool verboseOnly=true, bool includeTimestamp=false): d_fname(str), d_verboseOnly(verboseOnly), d_includeTimestamp(includeTimestamp), d_append(append), d_buffered(buffered)
    {
      if (str.empty()) {
        return;
      }
  
 -    if (append) {
 -      d_fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(str.c_str(), "a+"), fclose);
 -    }
 -    else {
 -      d_fp = std::unique_ptr<FILE, int(*)(FILE*)>(fopen(str.c_str(), "w"), fclose);
 -    }
 -
 -    if (!d_fp) {
 -      throw std::runtime_error("Unable to open file '"+str+"' for logging: "+stringerror());
 -    }
 -
 -    if (!buffered) {
 -      setbuf(d_fp.get(), 0);
 +    if (!reopenLogFile()) {
 +      throw std::runtime_error("Unable to open file '" + str + "' for logging: " + stringerror());
      }
    }
  
    DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
    {
 -    if (!d_fp) {
 +    auto fp = std::atomic_load_explicit(&d_fp, std::memory_order_acquire);
 +    if (!fp) {
        if (!d_verboseOnly || g_verbose) {
          if (d_includeTimestamp) {
            infolog("[%u.%u] Answer to %s for %s %s (%s) with id %d", static_cast<unsigned long long>(dr->queryTime->tv_sec), static_cast<unsigned long>(dr->queryTime->tv_nsec), dr->remote->toStringWithPort(), dr->qname->toString(), QType(dr->qtype).toString(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id);
      }
      else {
        if (d_includeTimestamp) {
 -        fprintf(d_fp.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %d\n", static_cast<unsigned long long>(dr->queryTime->tv_sec), static_cast<unsigned long>(dr->queryTime->tv_nsec), dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id);
 +        fprintf(fp.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %d\n", static_cast<unsigned long long>(dr->queryTime->tv_sec), static_cast<unsigned long>(dr->queryTime->tv_nsec), dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id);
        }
        else {
 -        fprintf(d_fp.get(), "Answer to %s for %s %s (%s) with id %d\n", dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id);
 +        fprintf(fp.get(), "Answer to %s for %s %s (%s) with id %d\n", dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id);
        }
      }
      return Action::None;
      }
      return "log";
    }
 +
 +  void reload() override
 +  {
 +    if (!reopenLogFile()) {
 +      warnlog("Unable to open file '%s' for logging: %s", d_fname, stringerror());
 +    }
 +  }
 +
  private:
 +  bool reopenLogFile()
 +  {
 +    // we are using a naked pointer here because we don't want fclose to be called
 +    // with a nullptr, which would happen if we constructor a shared_ptr with fclose
 +    // as a custom deleter and nullptr as a FILE*
 +    auto nfp = fopen(d_fname.c_str(), d_append ? "a+" : "w");
 +    if (!nfp) {
 +      /* don't fall on our sword when reopening */
 +      return false;
 +    }
 +
 +    auto fp = std::shared_ptr<FILE>(nfp, fclose);
 +    nfp = nullptr;
 +
 +    if (!d_buffered) {
 +      setbuf(fp.get(), 0);
 +    }
 +
 +    std::atomic_store_explicit(&d_fp, fp, std::memory_order_release);
 +    return true;
 +  }
 +
    std::string d_fname;
 -  std::unique_ptr<FILE, int(*)(FILE*)> d_fp{nullptr, fclose};
 +  std::shared_ptr<FILE> d_fp{nullptr};
    bool d_verboseOnly{true};
    bool d_includeTimestamp{false};
 +  bool d_append{false};
 +  bool d_buffered{true};
  };
  
  
@@@ -1087,28 -1186,6 +1230,28 @@@ private
    bool d_hasV6;
  };
  
 +static DnstapMessage::ProtocolType ProtocolToDNSTap(DNSQuestion::Protocol protocol)
 +{
 +  DnstapMessage::ProtocolType result;
 +  switch (protocol) {
 +  default:
 +  case DNSQuestion::Protocol::DoUDP:
 +  case DNSQuestion::Protocol::DNSCryptUDP:
 +    result = DnstapMessage::ProtocolType::DoUDP;
 +    break;
 +  case DNSQuestion::Protocol::DoTCP:
 +  case DNSQuestion::Protocol::DNSCryptTCP:
 +    result = DnstapMessage::ProtocolType::DoTCP;
 +    break;
 +  case DNSQuestion::Protocol::DoT:
 +    result = DnstapMessage::ProtocolType::DoT;
 +    break;
 +  case DNSQuestion::Protocol::DoH:
 +    result = DnstapMessage::ProtocolType::DoH;
 +    break;
 +  }
 +  return result;
 +}
  
  class DnstapLogAction : public DNSAction, public boost::noncopyable
  {
@@@ -1122,8 -1199,7 +1265,8 @@@ public
      static thread_local std::string data;
      data.clear();
  
 -    DnstapMessage message(data, !dq->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, dq->remote, dq->local, dq->tcp, reinterpret_cast<const char*>(dq->getData().data()), dq->getData().size(), dq->queryTime, nullptr);
 +    DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dq->getProtocol());
 +    DnstapMessage message(data, !dq->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, dq->remote, dq->local, protocol, reinterpret_cast<const char*>(dq->getData().data()), dq->getData().size(), dq->queryTime, nullptr);
      {
        if (d_alterFunc) {
          std::lock_guard<std::mutex> lock(g_luamutex);
@@@ -1256,8 -1332,7 +1399,8 @@@ public
      gettime(&now, true);
      data.clear();
  
 -    DnstapMessage message(data, DnstapMessage::MessageType::client_response, d_identity, dr->remote, dr->local, dr->tcp, reinterpret_cast<const char*>(dr->getData().data()), dr->getData().size(), dr->queryTime, &now);
 +    DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dr->getProtocol());
 +    DnstapMessage message(data, DnstapMessage::MessageType::client_response, d_identity, dr->remote, dr->local, protocol, reinterpret_cast<const char*>(dr->getData().data()), dr->getData().size(), dr->queryTime, &now);
      {
        if (d_alterFunc) {
          std::lock_guard<std::mutex> lock(g_luamutex);
@@@ -1749,8 -1824,6 +1892,8 @@@ void setupLuaActions(LuaContext& luaCtx
      });
  
    luaCtx.registerFunction("getStats", &DNSAction::getStats);
 +  luaCtx.registerFunction("reload", &DNSAction::reload);
 +  luaCtx.registerFunction("reload", &DNSResponseAction::reload);
  
    luaCtx.writeFunction("LuaAction", [](LuaAction::func_t func) {
        setLuaSideEffect();
        return std::shared_ptr<DNSAction>(new LuaFFIAction(func));
      });
  
+   luaCtx.writeFunction("LuaFFIPerThreadAction", [](std::string code) {
+       setLuaSideEffect();
+       return std::shared_ptr<DNSAction>(new LuaFFIPerThreadAction(code));
+     });
    luaCtx.writeFunction("SetNoRecurseAction", []() {
        return std::shared_ptr<DNSAction>(new SetNoRecurseAction);
      });
        return std::shared_ptr<DNSResponseAction>(new LuaFFIResponseAction(func));
      });
  
+   luaCtx.writeFunction("LuaFFIPerThreadResponseAction", [](std::string code) {
+       setLuaSideEffect();
+       return std::shared_ptr<DNSResponseAction>(new LuaFFIPerThreadResponseAction(code));
+     });
    luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
        if (logger) {
          // avoids potentially-evaluated-expression warning with clang.
index fca4ef67d7983dd737d237bdc0a2981988db1cc8,83065177e74ff9ca367de72cc785b36fa7ef6a6a..2153b63a31e81a929b80f2f82e63fd555275d05e
@@@ -444,7 -444,7 +444,7 @@@ void setupLuaRules(LuaContext& luaCtx
        sw.start();
        for(int n=0; n < times; ++n) {
          item& i = items[n % items.size()];
 -        DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, i.packet, false, &sw.d_start);
 +        DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, i.packet, DNSQuestion::Protocol::DoUDP, &sw.d_start);
          if (rule->matches(&dq)) {
            matches++;
          }
        return std::shared_ptr<DNSRule>(new LuaFFIRule(func));
      });
  
+   luaCtx.writeFunction("LuaFFIPerThreadRule", [](std::string code) {
+     return std::shared_ptr<DNSRule>(new LuaFFIPerThreadRule(code));
+   });
    luaCtx.writeFunction("ProxyProtocolValueRule", [](uint8_t type, boost::optional<std::string> value) {
        return std::shared_ptr<DNSRule>(new ProxyProtocolValueRule(type, value));
      });
index b4b06761e179a6e469009a416090457373a18313,1aa1129a7ef5100b0005969582400be7126a8549..67be9c8afa3e775179153341b0f8683f683d6a00
@@@ -122,7 -122,7 +122,7 @@@ uint8_t dnsdist_ffi_dnsquestion_get_opc
  
  bool dnsdist_ffi_dnsquestion_get_tcp(const dnsdist_ffi_dnsquestion_t* dq)
  {
 -  return dq->dq->tcp;
 +  return dq->dq->overTCP();
  }
  
  bool dnsdist_ffi_dnsquestion_get_skip_cache(const dnsdist_ffi_dnsquestion_t* dq)
@@@ -438,49 -438,6 +438,49 @@@ void dnsdist_ffi_dnsquestion_send_trap(
    }
  }
  
 +void dnsdist_ffi_dnsquestion_spoof_raw(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_raw_value_t* values, size_t valuesCount)
 +{
 +  std::vector<std::string> data;
 +  data.reserve(valuesCount);
 +
 +  for (size_t idx = 0; idx < valuesCount; idx++) {
 +    data.emplace_back(values[idx].value, values[idx].size);
 +  }
 +
 +  std::string result;
 +  SpoofAction sa(data);
 +  sa(dq->dq, &result);
 +}
 +
 +void dnsdist_ffi_dnsquestion_spoof_addrs(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_raw_value_t* values, size_t valuesCount)
 +{
 +  std::vector<ComboAddress> data;
 +  data.reserve(valuesCount);
 +
 +  for (size_t idx = 0; idx < valuesCount; idx++) {
 +    if (values[idx].size == 4) {
 +      sockaddr_in sin;
 +      sin.sin_family = AF_INET;
 +      sin.sin_port = 0;
 +      memcpy(&sin.sin_addr.s_addr, values[idx].value, sizeof(sin.sin_addr.s_addr));
 +      data.emplace_back(&sin);
 +    }
 +    else if (values[idx].size == 16) {
 +      sockaddr_in6 sin6;
 +      sin6.sin6_family = AF_INET6;
 +      sin6.sin6_port = 0;
 +      sin6.sin6_scope_id = 0;
 +      sin6.sin6_flowinfo = 0;
 +      memcpy(&sin6.sin6_addr.s6_addr, values[idx].value, sizeof(sin6.sin6_addr.s6_addr));
 +      data.emplace_back(&sin6);
 +    }
 +  }
 +
 +  std::string result;
 +  SpoofAction sa(data);
 +  sa(dq->dq, &result);
 +}
 +
  size_t dnsdist_ffi_servers_list_get_count(const dnsdist_ffi_servers_list_t* list)
  {
    return list->ffiServers.size();
@@@ -576,3 -533,12 +576,12 @@@ void setupLuaLoadBalancingContext(LuaCo
    luaCtx.executeCode(getLuaFFIWrappers());
  #endif
  }
+ void setupLuaFFIPerThreadContext(LuaContext& luaCtx)
+ {
+   setupLuaVars(luaCtx);
+ #ifdef LUAJIT_VERSION
+   luaCtx.executeCode(getLuaFFIWrappers());
+ #endif
+ }
index 4c739570d27fd917d5d79449dc124bf7bc4bb40e,06147fa8b20eb4d59aef6eefb487e3c452300184..be35f8027dd63d9bee6e098125bd38b80941e49f
@@@ -703,7 -703,7 +703,7 @@@ public
    }
    bool matches(const DNSQuestion* dq) const override
    {
 -    return dq->tcp == d_tcp;
 +    return dq->overTCP() == d_tcp;
    }
    string toString() const override
    {
@@@ -1178,6 -1178,62 +1178,62 @@@ private
    func_t d_func;
  };
  
+ class LuaFFIPerThreadRule : public DNSRule
+ {
+ public:
+   typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+   LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+   {
+   }
+   bool matches(const DNSQuestion* dq) const override
+   {
+     try {
+       auto& state = t_perThreadStates[d_functionID];
+       if (!state.d_initialized) {
+         setupLuaFFIPerThreadContext(state.d_luaContext);
+         /* mark the state as initialized first so if there is a syntax error
+            we only try to execute the code once */
+         state.d_initialized = true;
+         state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+       }
+       if (!state.d_func) {
+         /* the function was not properly initialized */
+         return false;
+       }
+       dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
+       return state.d_func(&dqffi);
+     }
+     catch (const std::exception &e) {
+       warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what());
+     }
+     catch (...) {
+       warnlog("LuaFFIPerThreadRule failed inside Lua: [unknown exception]");
+     }
+     return false;
+   }
+   string toString() const override
+   {
+     return "Lua FFI per-thread script";
+   }
+ private:
+   struct PerThreadState
+   {
+     LuaContext d_luaContext;
+     func_t d_func;
+     bool d_initialized{false};
+   };
+   static std::atomic<uint64_t> s_functionsCounter;
+   static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+   const std::string d_functionCode;
+   const uint64_t d_functionID;
+ };
  class ProxyProtocolValueRule : public DNSRule
  {
  public:
index 1dba0dbc725abf2048037c325505d088d4441ce1,38408e8317768d7653d6dccb5c99d75ce5924728..835d99631540e0bb45ef79dd7d1d40e0f39ecf09
@@@ -473,7 -473,7 +473,7 @@@ These ``DNSRule``\ s be one of the foll
    .. versionadded:: 1.4.0
  
    Matches DNS over HTTPS queries with a HTTP path matching the regular expression supplied in ``regex``. For example, if the query has been sent to the https://192.0.2.1:443/PowerDNS?dns=... URL, the path would be '/PowerDNS'.
 -  Only valid DNS over HTTPS queries are matched. If you want to match all HTTP queries, see :meth:`DOHFrontend.setResponsesMap` instead.
 +  Only valid DNS over HTTPS queries are matched. If you want to match all HTTP queries, see :meth:`DOHFrontend:setResponsesMap` instead.
  
    :param str regex: The regex to match on
  
    .. versionadded:: 1.4.0
  
    Matches DNS over HTTPS queries with a HTTP path of ``path``. For example, if the query has been sent to the https://192.0.2.1:443/PowerDNS?dns=... URL, the path would be '/PowerDNS'.
 -  Only valid DNS over HTTPS queries are matched. If you want to match all HTTP queries, see :meth:`DOHFrontend.setResponsesMap` instead.
 +  Only valid DNS over HTTPS queries are matched. If you want to match all HTTP queries, see :meth:`DOHFrontend:setResponsesMap` instead.
  
    :param str path: The exact HTTP path to match on
  
    :param KeyValueStore kvs: The key value store to query
    :param KeyValueLookupKey lookupKey: The key to use for the lookup
  
+ .. function:: LuaFFIPerThreadRule(function)
+   .. versionadded:: 1.7.0
+   Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+   The ``function`` should return true if the query matches, or false otherwise. If the Lua code fails, false is returned.
+   The function will be invoked in a per-thread Lua state, without access to the global Lua state. All constants (:ref:`DNSQType`, :ref:`DNSRCode`, ...) are available in that per-thread context,
+   as well as all FFI functions. Objects and their bindings that are not usable in a FFI context (:class:`DNSQuestion`, :class:`DNSDistProtoBufMessage`, :class:`PacketCache`, ...)
+   are not available.
+   :param string function: a Lua string returning a Lua function
  .. function:: LuaFFIRule(function)
  
    .. versionadded:: 1.5.0
@@@ -943,9 -957,6 +957,9 @@@ The following actions exist
    .. versionchanged:: 1.4.0
      Added the optional parameters ``verboseOnly`` and ``includeTimestamp``, made ``filename`` optional.
  
 +  .. versionchanged:: 1.7.0
 +    Added the ``reload`` method.
 +
    Log a line for each query, to the specified ``file`` if any, to the console (require verbose) if the empty string is given as filename.
  
    If an empty string is supplied in the file name, the logging is done to stdout, and only in verbose mode by default. This can be changed by setting ``verboseOnly`` to false.
    The ``append`` optional parameter specifies whether we open the file for appending or truncate each time (default).
    The ``buffered`` optional parameter specifies whether writes to the file are buffered (default) or not.
  
 +  Since 1.7.0 calling the ``reload()`` method on the object will cause it to close and re-open the log file, for rotation purposes.
 +
    Subsequent rules are processed after this action.
  
    :param string filename: File to log to. Set to an empty string to log to the normal stdout log, this only works when ``-v`` is set on the command line.
  
    .. versionadded:: 1.5.0
  
 +  .. versionchanged:: 1.7.0
 +    Added the ``reload`` method.
 +
    Log a line for each response, to the specified ``file`` if any, to the console (require verbose) if the empty string is given as filename.
  
    If an empty string is supplied in the file name, the logging is done to stdout, and only in verbose mode by default. This can be changed by setting ``verboseOnly`` to false.
    The ``append`` optional parameter specifies whether we open the file for appending or truncate each time (default).
    The ``buffered`` optional parameter specifies whether writes to the file are buffered (default) or not.
  
 +  Since 1.7.0 calling the ``reload()`` method on the object will cause it to close and re-open the log file, for rotation purposes.
 +
    Subsequent rules are processed after this action.
  
    :param string filename: File to log to. Set to an empty string to log to the normal stdout log, this only works when ``-v`` is set on the command line.
  
    :param string function: the name of a Lua function
  
+ .. function:: LuaFFIPerThreadAction(function)
+   .. versionadded:: 1.7.0
+   Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+   The ``function`` should return a :ref:`DNSAction`. If the Lua code fails, ServFail is returned.
+   The function will be invoked in a per-thread Lua state, without access to the global Lua state. All constants (:ref:`DNSQType`, :ref:`DNSRCode`, ...) are available in that per-thread context,
+   as well as all FFI functions. Objects and their bindings that are not usable in a FFI context (:class:`DNSQuestion`, :class:`DNSDistProtoBufMessage`, :class:`PacketCache`, ...)
+   are not available.
+   :param string function: a Lua string returning a Lua function
+ .. function:: LuaFFIPerThreadResponseAction(function)
+   .. versionadded:: 1.7.0
+   Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+   The ``function`` should return a :ref:`DNSResponseAction`. If the Lua code fails, ServFail is returned.
+   The function will be invoked in a per-thread Lua state, without access to the global Lua state. All constants (:ref:`DNSQType`, :ref:`DNSRCode`, ...) are available in that per-thread context,
+   as well as all FFI functions. Objects and their bindings that are not usable in a FFI context (:class:`DNSQuestion`, :class:`DNSDistProtoBufMessage`, :class:`PacketCache`, ...)
+   are not available.
+   :param string function: a Lua string returning a Lua function
  .. function:: LuaFFIResponseAction(function)
  
    .. versionadded:: 1.5.0
index f784f9b33b4956ed6a35c4d22e22856170bd6935,5fbbd517b552c0687a74be387a827786305eef10..27a9f6b0f534e561d3981078b998b7cbc47228c1
@@@ -2130,6 -2130,155 +2130,155 @@@ class TestAdvancedLuaFFI(DNSDistTest)
              (_, receivedResponse) = sender(query, response=None, useQueue=False)
              self.assertEqual(receivedResponse, response)
  
+ class TestAdvancedLuaFFIPerThread(DNSDistTest):
+     _config_template = """
+     local rulefunction = [[
+       local ffi = require("ffi")
+       return function(dq)
+         local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq)
+         if qtype ~= DNSQType.A and qtype ~= DNSQType.SOA then
+           print('invalid qtype')
+           return false
+         end
+         local qclass = ffi.C.dnsdist_ffi_dnsquestion_get_qclass(dq)
+         if qclass ~= DNSClass.IN then
+           print('invalid qclass')
+           return false
+         end
+         local ret_ptr = ffi.new("char *[1]")
+         local ret_ptr_param = ffi.cast("const char **", ret_ptr)
+         local ret_size = ffi.new("size_t[1]")
+         local ret_size_param = ffi.cast("size_t*", ret_size)
+         ffi.C.dnsdist_ffi_dnsquestion_get_qname_raw(dq, ret_ptr_param, ret_size_param)
+         if ret_size[0] ~= 45 then
+           print('invalid length for the qname ')
+           print(ret_size[0])
+           return false
+         end
+         local expectedQname = string.char(15)..'luaffiperthread'..string.char(8)..'advanced'..string.char(5)..'tests'..string.char(8)..'powerdns'..string.char(3)..'com'
+         if ffi.string(ret_ptr[0]) ~= expectedQname then
+           print('invalid qname')
+           print(ffi.string(ret_ptr[0]))
+           return false
+         end
+         local rcode = ffi.C.dnsdist_ffi_dnsquestion_get_rcode(dq)
+         if rcode ~= 0 then
+           print('invalid rcode')
+           return false
+         end
+         local opcode = ffi.C.dnsdist_ffi_dnsquestion_get_opcode(dq)
+         if qtype == DNSQType.A and opcode ~= DNSOpcode.Query then
+           print('invalid opcode')
+           return false
+         elseif qtype == DNSQType.SOA and opcode ~= DNSOpcode.Update then
+           print('invalid opcode')
+           return false
+         end
+         local dnssecok = ffi.C.dnsdist_ffi_dnsquestion_get_do(dq)
+         if dnssecok ~= false then
+           print('invalid DNSSEC OK')
+           return false
+         end
+         local len = ffi.C.dnsdist_ffi_dnsquestion_get_len(dq)
+         if len ~= 61 then
+           print('invalid length')
+           print(len)
+           return false
+         end
+         local tag = ffi.C.dnsdist_ffi_dnsquestion_get_tag(dq, 'a-tag')
+         if ffi.string(tag) ~= 'a-value' then
+           print('invalid tag value')
+           print(ffi.string(tag))
+           return false
+         end
+         return true
+       end
+     ]]
+     local actionfunction = [[
+       local ffi = require("ffi")
+       return function(dq)
+         local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq)
+         if qtype == DNSQType.A then
+           local str = "192.0.2.1"
+           local buf = ffi.new("char[?]", #str + 1)
+           ffi.copy(buf, str)
+           ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+           return DNSAction.Spoof
+         elseif qtype == DNSQType.SOA then
+           ffi.C.dnsdist_ffi_dnsquestion_set_rcode(dq, DNSRCode.REFUSED)
+           return DNSAction.Refused
+         end
+       end
+     ]]
+     local settagfunction = [[
+       local ffi = require("ffi")
+       return function(dq)
+         ffi.C.dnsdist_ffi_dnsquestion_set_tag(dq, 'a-tag', 'a-value')
+         return DNSAction.None
+       end
+     ]]
+     addAction(AllRule(), LuaFFIPerThreadAction(settagfunction))
+     addAction(LuaFFIPerThreadRule(rulefunction), LuaFFIPerThreadAction(actionfunction))
+     -- newServer{address="127.0.0.1:%s"}
+     """
+     def testAdvancedLuaPerthreadFFI(self):
+         """
+         Advanced: Test the Lua FFI per-thread interface
+         """
+         name = 'luaffiperthread.advanced.tests.powerdns.com.'
+         query = dns.message.make_query(name, 'A', 'IN')
+         # dnsdist set RA = RD for spoofed responses
+         query.flags &= ~dns.flags.RD
+         response = dns.message.make_response(query)
+         rrset = dns.rrset.from_text(name,
+                                     60,
+                                     dns.rdataclass.IN,
+                                     dns.rdatatype.A,
+                                     '192.0.2.1')
+         response.answer.append(rrset)
+         for method in ("sendUDPQuery", "sendTCPQuery"):
+             sender = getattr(self, method)
+             (_, receivedResponse) = sender(query, response=None, useQueue=False)
+             self.assertEqual(receivedResponse, response)
+     def testAdvancedLuaFFIPerThreadUpdate(self):
+         """
+         Advanced: Test the Lua FFI per-thread interface via an update
+         """
+         name = 'luaffiperthread.advanced.tests.powerdns.com.'
+         query = dns.message.make_query(name, 'SOA', 'IN')
+         query.set_opcode(dns.opcode.UPDATE)
+         # dnsdist set RA = RD for spoofed responses
+         query.flags &= ~dns.flags.RD
+         response = dns.message.make_response(query)
+         response.set_rcode(dns.rcode.REFUSED)
+         for method in ("sendUDPQuery", "sendTCPQuery"):
+             sender = getattr(self, method)
+             (_, receivedResponse) = sender(query, response=None, useQueue=False)
+             self.assertEqual(receivedResponse, response)
  class TestAdvancedDropEmptyQueries(DNSDistTest):
  
      _config_template = """
              sender = getattr(self, method)
              (_, receivedResponse) = sender(query, response=None, useQueue=False)
              self.assertEqual(receivedResponse, None)
 +
 +class TestProtocols(DNSDistTest):
 +    _config_template = """
 +    function checkUDP(dq)
 +      if dq:getProtocol() ~= "Do53 UDP" then
 +        return DNSAction.Spoof, '1.2.3.4'
 +      end
 +      return DNSAction.None
 +    end
 +
 +    function checkTCP(dq)
 +      if dq:getProtocol() ~= "Do53 TCP" then
 +        return DNSAction.Spoof, '1.2.3.4'
 +      end
 +      return DNSAction.None
 +    end
 +
 +    addAction("udp.protocols.advanced.tests.powerdns.com.", LuaAction(checkUDP))
 +    addAction("tcp.protocols.advanced.tests.powerdns.com.", LuaAction(checkTCP))
 +    newServer{address="127.0.0.1:%s"}
 +    """
 +
 +    def testProtocolUDP(self):
 +        """
 +        Advanced: Test DNSQuestion.Protocol over UDP
 +        """
 +        name = 'udp.protocols.advanced.tests.powerdns.com.'
 +        query = dns.message.make_query(name, 'A', 'IN')
 +        response = dns.message.make_response(query)
 +
 +        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
 +        receivedQuery.id = query.id
 +        self.assertEqual(receivedQuery, query)
 +        self.assertEqual(receivedResponse, response)
 +
 +    def testProtocolTCP(self):
 +        """
 +        Advanced: Test DNSQuestion.Protocol over TCP
 +        """
 +        name = 'tcp.protocols.advanced.tests.powerdns.com.'
 +        query = dns.message.make_query(name, 'A', 'IN')
 +        response = dns.message.make_response(query)
 +
 +        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
 +        receivedQuery.id = query.id
 +        self.assertEqual(receivedQuery, query)
 +        self.assertEqual(receivedResponse, response)