]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add support for Lua per-thread FFI rules and actions
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 16 Jun 2021 09:50:04 +0000 (11:50 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 16 Jun 2021 09:50:04 +0000 (11:50 +0200)
pdns/dnsdist-console.cc
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist-lua-ffi.hh
pdns/dnsdistdist/dnsdist-rules.cc [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/docs/advanced/tuning.rst
pdns/dnsdistdist/docs/rules-actions.rst
regression-tests.dnsdist/test_Advanced.py

index c3bd5e20c4e0318e881bbe25cb9ff27dccb05c3f..99af9f07c6a972cac7757702b2622a434074a78e 100644 (file)
@@ -489,6 +489,8 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "LogResponseAction", true, "[filename], [append], [buffered]", "Log a line for each response, to the specified file if any, to the console (require verbose) otherwise. The `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not." },
   { "LuaAction", true, "function", "Invoke a Lua function that accepts a DNSQuestion" },
   { "LuaFFIAction", true, "function", "Invoke a Lua FFI function that accepts a DNSQuestion" },
+  { "LuaFFIPerThreadAction", true, "function", "Invoke a Lua FFI function that accepts a DNSQuestion, with a per-thread Lua context" },
+  { "LuaFFIPerThreadResponseAction", true, "function", "Invoke a Lua FFI function that accepts a DNSResponse, with a per-thread Lua context" },
   { "LuaFFIResponseAction", true, "function", "Invoke a Lua FFI function that accepts a DNSResponse" },
   { "LuaFFIRule", true, "function", "Invoke a Lua FFI function that filters DNS questions" },
   { "LuaResponseAction", true, "function", "Invoke a Lua function that accepts a DNSResponse" },
index ce1e714be3e8ea22bb974fa8e3b3bdcc0d27dee3..f83d9a9ab71d664fe8e208e131bce95b0dbde24c 100644 (file)
@@ -489,6 +489,65 @@ private:
   func_t d_func;
 };
 
+class LuaFFIPerThreadAction: public DNSAction
+{
+public:
+  typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+  LuaFFIPerThreadAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+  {
+  }
+
+  DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+  {
+    try {
+      auto& state = t_perThreadStates[d_functionID];
+      if (!state.d_initialized) {
+        setupLuaFFIPerThreadContext(state.d_luaContext);
+        state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+        state.d_initialized = true;
+      }
+
+      dnsdist_ffi_dnsquestion_t dqffi(dq);
+      auto ret = state.d_func(&dqffi);
+      if (ruleresult) {
+        if (dqffi.result) {
+          *ruleresult = *dqffi.result;
+        }
+        else {
+          // default to empty string
+          ruleresult->clear();
+        }
+      }
+      return static_cast<DNSAction::Action>(ret);
+    } catch (const std::exception &e) {
+      warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what());
+    } catch (...) {
+      warnlog("LuaFFIPerthreadAction failed inside Lua, returning ServFail: [unknown exception]");
+    }
+    return DNSAction::Action::ServFail;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI per-thread script";
+  }
+
+private:
+  struct PerThreadState
+  {
+    LuaContext d_luaContext;
+    func_t d_func;
+    bool d_initialized{false};
+  };
+  static uint64_t s_functionsCounter;
+  static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+  std::string d_functionCode;
+  uint64_t d_functionID;
+};
+
+uint64_t LuaFFIPerThreadAction::s_functionsCounter = 0;
+thread_local std::map<uint64_t, LuaFFIPerThreadAction::PerThreadState> LuaFFIPerThreadAction::t_perThreadStates;
 
 class LuaFFIResponseAction: public DNSResponseAction
 {
@@ -537,6 +596,72 @@ private:
   func_t d_func;
 };
 
+class LuaFFIPerThreadResponseAction: public DNSResponseAction
+{
+public:
+  typedef std::function<int(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+  LuaFFIPerThreadResponseAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+  {
+  }
+
+  DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+  {
+    DNSQuestion* dq = dynamic_cast<DNSQuestion*>(dr);
+    if (dq == nullptr) {
+      return DNSResponseAction::Action::ServFail;
+    }
+
+    try {
+      auto& state = t_perThreadStates[d_functionID];
+      if (!state.d_initialized) {
+        setupLuaFFIPerThreadContext(state.d_luaContext);
+        state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+        state.d_initialized = true;
+      }
+
+      dnsdist_ffi_dnsquestion_t dqffi(dq);
+      auto ret = state.d_func(&dqffi);
+      if (ruleresult) {
+        if (dqffi.result) {
+          *ruleresult = *dqffi.result;
+        }
+        else {
+          // default to empty string
+          ruleresult->clear();
+        }
+      }
+      return static_cast<DNSResponseAction::Action>(ret);
+    } catch (const std::exception &e) {
+      warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what());
+    } catch (...) {
+      warnlog("LuaFFIPerthreadResponseAction failed inside Lua, returning ServFail: [unknown exception]");
+    }
+    return DNSResponseAction::Action::ServFail;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI per-thread script";
+  }
+
+private:
+  struct PerThreadState
+  {
+    LuaContext d_luaContext;
+    func_t d_func;
+    bool d_initialized{false};
+  };
+
+  static uint64_t s_functionsCounter;
+  static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+  std::string d_functionCode;
+  uint64_t d_functionID;
+};
+
+uint64_t LuaFFIPerThreadResponseAction::s_functionsCounter = 0;
+thread_local std::map<uint64_t, LuaFFIPerThreadResponseAction::PerThreadState> LuaFFIPerThreadResponseAction::t_perThreadStates;
+
 thread_local std::default_random_engine SpoofAction::t_randomEngine;
 
 DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresult) const
@@ -1692,6 +1817,11 @@ void setupLuaActions(LuaContext& luaCtx)
       return std::shared_ptr<DNSAction>(new LuaFFIAction(func));
     });
 
+  luaCtx.writeFunction("LuaFFIPerThreadAction", [](std::string code) {
+      setLuaSideEffect();
+      return std::shared_ptr<DNSAction>(new LuaFFIPerThreadAction(code));
+    });
+
   luaCtx.writeFunction("SetNoRecurseAction", []() {
       return std::shared_ptr<DNSAction>(new SetNoRecurseAction);
     });
@@ -1858,6 +1988,11 @@ void setupLuaActions(LuaContext& luaCtx)
       return std::shared_ptr<DNSResponseAction>(new LuaFFIResponseAction(func));
     });
 
+  luaCtx.writeFunction("LuaFFIPerThreadResponseAction", [](std::string code) {
+      setLuaSideEffect();
+      return std::shared_ptr<DNSResponseAction>(new LuaFFIPerThreadResponseAction(code));
+    });
+
   luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
       if (logger) {
         // avoids potentially-evaluated-expression warning with clang.
index 7fc73b84f3493bb9ba201a4f75bd26ad3ca5e9c0..83065177e74ff9ca367de72cc785b36fa7ef6a6a 100644 (file)
@@ -603,6 +603,10 @@ void setupLuaRules(LuaContext& luaCtx)
       return std::shared_ptr<DNSRule>(new LuaFFIRule(func));
     });
 
+  luaCtx.writeFunction("LuaFFIPerThreadRule", [](std::string code) {
+    return std::shared_ptr<DNSRule>(new LuaFFIPerThreadRule(code));
+  });
+
   luaCtx.writeFunction("ProxyProtocolValueRule", [](uint8_t type, boost::optional<std::string> value) {
       return std::shared_ptr<DNSRule>(new ProxyProtocolValueRule(type, value));
     });
index f635d4e6b667eaa84fcec9bdc93fc56383f00c2b..42b16a3de596b9020cb9f9522405e8ae639d4ebb 100644 (file)
@@ -162,7 +162,7 @@ dnsdist_SOURCES = \
        dnsdist-protobuf.cc dnsdist-protobuf.hh \
        dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \
        dnsdist-rings.cc dnsdist-rings.hh \
-       dnsdist-rules.hh \
+       dnsdist-rules.cc dnsdist-rules.hh \
        dnsdist-secpoll.cc dnsdist-secpoll.hh \
        dnsdist-snmp.cc dnsdist-snmp.hh \
        dnsdist-systemd.cc dnsdist-systemd.hh \
@@ -240,6 +240,7 @@ testrunner_SOURCES = \
        dnsdist-lua-vars.cc \
        dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \
        dnsdist-rings.cc dnsdist-rings.hh \
+       dnsdist-rules.cc dnsdist-rules.hh \
        dnsdist-tcp-downstream.cc \
        dnsdist-tcp.cc \
        dnsdist-xpf.cc dnsdist-xpf.hh \
index a90a952ebc2b22ee870fb0c7af8e9f988f0a3ef5..1aa1129a7ef5100b0005969582400be7126a8549 100644 (file)
@@ -533,3 +533,12 @@ void setupLuaLoadBalancingContext(LuaContext& luaCtx)
   luaCtx.executeCode(getLuaFFIWrappers());
 #endif
 }
+
+void setupLuaFFIPerThreadContext(LuaContext& luaCtx)
+{
+  setupLuaVars(luaCtx);
+
+#ifdef LUAJIT_VERSION
+  luaCtx.executeCode(getLuaFFIWrappers());
+#endif
+}
index 63156db080aedf11cb10766b3107ea248012ebe8..103065950082b585853a7953220dada1707a1692 100644 (file)
@@ -107,3 +107,4 @@ struct dnsdist_ffi_servers_list_t
 };
 
 const std::string& getLuaFFIWrappers();
+void setupLuaFFIPerThreadContext(LuaContext& luaCtx);
diff --git a/pdns/dnsdistdist/dnsdist-rules.cc b/pdns/dnsdistdist/dnsdist-rules.cc
new file mode 100644 (file)
index 0000000..f1e6eec
--- /dev/null
@@ -0,0 +1,26 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist-rules.hh"
+
+uint64_t LuaFFIPerThreadRule::s_functionsCounter = 0;
+thread_local std::map<uint64_t, LuaFFIPerThreadRule::PerThreadState> LuaFFIPerThreadRule::t_perThreadStates;
index e19d363e126eb3d0221f20dae614b2f60a66763e..bbbcabe0a9c68a27c8e1a5be27cd74941600ac3d 100644 (file)
@@ -1178,6 +1178,53 @@ private:
   func_t d_func;
 };
 
+class LuaFFIPerThreadRule : public DNSRule
+{
+public:
+  typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+  LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    try {
+      auto& state = t_perThreadStates[d_functionID];
+      if (!state.d_initialized) {
+        setupLuaFFIPerThreadContext(state.d_luaContext);
+        state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+        state.d_initialized = true;
+      }
+
+      dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
+      return state.d_func(&dqffi);
+    } catch (const std::exception &e) {
+      warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what());
+    } catch (...) {
+      warnlog("LuaFFIPerThreadRule failed inside Lua: [unknown exception]");
+    }
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI per-thread script";
+  }
+private:
+  struct PerThreadState
+  {
+    LuaContext d_luaContext;
+    func_t d_func;
+    bool d_initialized{false};
+  };
+
+  static uint64_t s_functionsCounter;
+  static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+  std::string d_functionCode;
+  uint64_t d_functionID;
+};
+
 class ProxyProtocolValueRule : public DNSRule
 {
 public:
index 3c84d9aa6ea1480f489be0911a9d5f91dc9e5d77..2bf9aa50a2b87e9208993a4adfd315dfbe4ae936 100644 (file)
@@ -111,6 +111,8 @@ When Lua inspection is needed, the best course of action is to restrict the quer
 +------------------------------+-------------+-----------------+
 | Lua FFI rule                 | fast        | global Lua lock |
 +------------------------------+-------------+-----------------+
+| Lua per-thread FFI rule      | fast        | none            |
++------------------------------+-------------+-----------------+
 | C++ LB policy                | fast        | none            |
 +------------------------------+-------------+-----------------+
 | Lua LB policy                | slow        | global Lua lock |
index 35aa1257566488ca8e7d4425da824ef141f244f2..1ce55f34bf1b035768f709dae3962964d6e448de 100644 (file)
@@ -498,6 +498,18 @@ These ``DNSRule``\ s be one of the following items:
   :param KeyValueStore kvs: The key value store to query
   :param KeyValueLookupKey lookupKey: The key to use for the lookup
 
+.. function:: LuaFFIPerThreadRule(function)
+
+  .. versionadded:: 1.7.0
+
+  Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+  The ``function`` should return true if the query matches, or false otherwise. If the Lua code fails, false is returned.
+
+  The function will be invoked in a per-thread Lua state, without access to the global Lua state.
+
+  :param string function: the name of a Lua function
+
 .. function:: LuaFFIRule(function)
 
   .. versionadded:: 1.5.0
@@ -998,6 +1010,30 @@ The following actions exist.
 
   :param string function: the name of a Lua function
 
+.. function:: LuaFFIPerThreadAction(function)
+
+  .. versionadded:: 1.7.0
+
+  Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+  The ``function`` should return a :ref:`DNSAction`. If the Lua code fails, ServFail is returned.
+
+  The function will be invoked in a per-thread Lua state, without access to the global Lua state.
+
+  :param string function: the name of a Lua function
+
+.. function:: LuaFFIPerThreadResponseAction(function)
+
+  .. versionadded:: 1.7.0
+
+  Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``.
+
+  The ``function`` should return a :ref:`DNSResponseAction`. If the Lua code fails, ServFail is returned.
+
+  The function will be invoked in a per-thread Lua state, without access to the global Lua state.
+
+  :param string function: the name of a Lua function
+
 .. function:: LuaFFIResponseAction(function)
 
   .. versionadded:: 1.5.0
index c5c5012d082c4fdf65161d6adf4cc0efbdbdde65..5fbbd517b552c0687a74be387a827786305eef10 100644 (file)
@@ -2130,6 +2130,155 @@ class TestAdvancedLuaFFI(DNSDistTest):
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertEqual(receivedResponse, response)
 
+class TestAdvancedLuaFFIPerThread(DNSDistTest):
+
+    _config_template = """
+
+    local rulefunction = [[
+      local ffi = require("ffi")
+
+      return function(dq)
+        local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq)
+        if qtype ~= DNSQType.A and qtype ~= DNSQType.SOA then
+          print('invalid qtype')
+          return false
+        end
+
+        local qclass = ffi.C.dnsdist_ffi_dnsquestion_get_qclass(dq)
+        if qclass ~= DNSClass.IN then
+          print('invalid qclass')
+          return false
+        end
+
+        local ret_ptr = ffi.new("char *[1]")
+        local ret_ptr_param = ffi.cast("const char **", ret_ptr)
+        local ret_size = ffi.new("size_t[1]")
+        local ret_size_param = ffi.cast("size_t*", ret_size)
+        ffi.C.dnsdist_ffi_dnsquestion_get_qname_raw(dq, ret_ptr_param, ret_size_param)
+        if ret_size[0] ~= 45 then
+          print('invalid length for the qname ')
+          print(ret_size[0])
+          return false
+        end
+
+        local expectedQname = string.char(15)..'luaffiperthread'..string.char(8)..'advanced'..string.char(5)..'tests'..string.char(8)..'powerdns'..string.char(3)..'com'
+        if ffi.string(ret_ptr[0]) ~= expectedQname then
+          print('invalid qname')
+          print(ffi.string(ret_ptr[0]))
+          return false
+        end
+
+        local rcode = ffi.C.dnsdist_ffi_dnsquestion_get_rcode(dq)
+        if rcode ~= 0 then
+          print('invalid rcode')
+          return false
+        end
+
+        local opcode = ffi.C.dnsdist_ffi_dnsquestion_get_opcode(dq)
+        if qtype == DNSQType.A and opcode ~= DNSOpcode.Query then
+          print('invalid opcode')
+          return false
+        elseif qtype == DNSQType.SOA and opcode ~= DNSOpcode.Update then
+          print('invalid opcode')
+          return false
+        end
+
+        local dnssecok = ffi.C.dnsdist_ffi_dnsquestion_get_do(dq)
+        if dnssecok ~= false then
+          print('invalid DNSSEC OK')
+          return false
+        end
+
+        local len = ffi.C.dnsdist_ffi_dnsquestion_get_len(dq)
+        if len ~= 61 then
+          print('invalid length')
+          print(len)
+          return false
+        end
+
+        local tag = ffi.C.dnsdist_ffi_dnsquestion_get_tag(dq, 'a-tag')
+        if ffi.string(tag) ~= 'a-value' then
+          print('invalid tag value')
+          print(ffi.string(tag))
+          return false
+        end
+
+        return true
+      end
+    ]]
+
+    local actionfunction = [[
+      local ffi = require("ffi")
+
+      return function(dq)
+        local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq)
+        if qtype == DNSQType.A then
+          local str = "192.0.2.1"
+          local buf = ffi.new("char[?]", #str + 1)
+          ffi.copy(buf, str)
+          ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str)
+          return DNSAction.Spoof
+        elseif qtype == DNSQType.SOA then
+          ffi.C.dnsdist_ffi_dnsquestion_set_rcode(dq, DNSRCode.REFUSED)
+          return DNSAction.Refused
+        end
+      end
+    ]]
+
+    local settagfunction = [[
+      local ffi = require("ffi")
+
+      return function(dq)
+        ffi.C.dnsdist_ffi_dnsquestion_set_tag(dq, 'a-tag', 'a-value')
+        return DNSAction.None
+      end
+    ]]
+
+    addAction(AllRule(), LuaFFIPerThreadAction(settagfunction))
+    addAction(LuaFFIPerThreadRule(rulefunction), LuaFFIPerThreadAction(actionfunction))
+    -- newServer{address="127.0.0.1:%s"}
+    """
+
+    def testAdvancedLuaPerthreadFFI(self):
+        """
+        Advanced: Test the Lua FFI per-thread interface
+        """
+        name = 'luaffiperthread.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEqual(receivedResponse, response)
+
+    def testAdvancedLuaFFIPerThreadUpdate(self):
+        """
+        Advanced: Test the Lua FFI per-thread interface via an update
+        """
+        name = 'luaffiperthread.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'SOA', 'IN')
+        query.set_opcode(dns.opcode.UPDATE)
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        response = dns.message.make_response(query)
+        response.set_rcode(dns.rcode.REFUSED)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEqual(receivedResponse, response)
+
 class TestAdvancedDropEmptyQueries(DNSDistTest):
 
     _config_template = """