From: Remi Gacogne Date: Wed, 16 Jun 2021 09:50:04 +0000 (+0200) Subject: dnsdist: Add support for Lua per-thread FFI rules and actions X-Git-Tag: dnsdist-1.7.0-alpha1~88^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=79ac0575815b3a222fb687e0008dad1d5878ac57;p=thirdparty%2Fpdns.git dnsdist: Add support for Lua per-thread FFI rules and actions --- diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index c3bd5e20c4..99af9f07c6 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -489,6 +489,8 @@ const std::vector g_consoleKeywords{ { "LogResponseAction", true, "[filename], [append], [buffered]", "Log a line for each response, to the specified file if any, to the console (require verbose) otherwise. The `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not." }, { "LuaAction", true, "function", "Invoke a Lua function that accepts a DNSQuestion" }, { "LuaFFIAction", true, "function", "Invoke a Lua FFI function that accepts a DNSQuestion" }, + { "LuaFFIPerThreadAction", true, "function", "Invoke a Lua FFI function that accepts a DNSQuestion, with a per-thread Lua context" }, + { "LuaFFIPerThreadResponseAction", true, "function", "Invoke a Lua FFI function that accepts a DNSResponse, with a per-thread Lua context" }, { "LuaFFIResponseAction", true, "function", "Invoke a Lua FFI function that accepts a DNSResponse" }, { "LuaFFIRule", true, "function", "Invoke a Lua FFI function that filters DNS questions" }, { "LuaResponseAction", true, "function", "Invoke a Lua function that accepts a DNSResponse" }, diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index ce1e714be3..f83d9a9ab7 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -489,6 +489,65 @@ private: func_t d_func; }; +class LuaFFIPerThreadAction: public DNSAction +{ +public: + typedef std::function func_t; + + LuaFFIPerThreadAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) + { + } + + DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override + { + try { + auto& state = t_perThreadStates[d_functionID]; + if (!state.d_initialized) { + setupLuaFFIPerThreadContext(state.d_luaContext); + state.d_func = state.d_luaContext.executeCode(d_functionCode); + state.d_initialized = true; + } + + dnsdist_ffi_dnsquestion_t dqffi(dq); + auto ret = state.d_func(&dqffi); + if (ruleresult) { + if (dqffi.result) { + *ruleresult = *dqffi.result; + } + else { + // default to empty string + ruleresult->clear(); + } + } + return static_cast(ret); + } catch (const std::exception &e) { + warnlog("LuaFFIPerThreadAction failed inside Lua, returning ServFail: %s", e.what()); + } catch (...) { + warnlog("LuaFFIPerthreadAction failed inside Lua, returning ServFail: [unknown exception]"); + } + return DNSAction::Action::ServFail; + } + + string toString() const override + { + return "Lua FFI per-thread script"; + } + +private: + struct PerThreadState + { + LuaContext d_luaContext; + func_t d_func; + bool d_initialized{false}; + }; + static uint64_t s_functionsCounter; + static thread_local std::map t_perThreadStates; + std::string d_functionCode; + uint64_t d_functionID; +}; + +uint64_t LuaFFIPerThreadAction::s_functionsCounter = 0; +thread_local std::map LuaFFIPerThreadAction::t_perThreadStates; class LuaFFIResponseAction: public DNSResponseAction { @@ -537,6 +596,72 @@ private: func_t d_func; }; +class LuaFFIPerThreadResponseAction: public DNSResponseAction +{ +public: + typedef std::function func_t; + + LuaFFIPerThreadResponseAction(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) + { + } + + DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override + { + DNSQuestion* dq = dynamic_cast(dr); + if (dq == nullptr) { + return DNSResponseAction::Action::ServFail; + } + + try { + auto& state = t_perThreadStates[d_functionID]; + if (!state.d_initialized) { + setupLuaFFIPerThreadContext(state.d_luaContext); + state.d_func = state.d_luaContext.executeCode(d_functionCode); + state.d_initialized = true; + } + + dnsdist_ffi_dnsquestion_t dqffi(dq); + auto ret = state.d_func(&dqffi); + if (ruleresult) { + if (dqffi.result) { + *ruleresult = *dqffi.result; + } + else { + // default to empty string + ruleresult->clear(); + } + } + return static_cast(ret); + } catch (const std::exception &e) { + warnlog("LuaFFIPerThreadResponseAction failed inside Lua, returning ServFail: %s", e.what()); + } catch (...) { + warnlog("LuaFFIPerthreadResponseAction failed inside Lua, returning ServFail: [unknown exception]"); + } + return DNSResponseAction::Action::ServFail; + } + + string toString() const override + { + return "Lua FFI per-thread script"; + } + +private: + struct PerThreadState + { + LuaContext d_luaContext; + func_t d_func; + bool d_initialized{false}; + }; + + static uint64_t s_functionsCounter; + static thread_local std::map t_perThreadStates; + std::string d_functionCode; + uint64_t d_functionID; +}; + +uint64_t LuaFFIPerThreadResponseAction::s_functionsCounter = 0; +thread_local std::map LuaFFIPerThreadResponseAction::t_perThreadStates; + thread_local std::default_random_engine SpoofAction::t_randomEngine; DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresult) const @@ -1692,6 +1817,11 @@ void setupLuaActions(LuaContext& luaCtx) return std::shared_ptr(new LuaFFIAction(func)); }); + luaCtx.writeFunction("LuaFFIPerThreadAction", [](std::string code) { + setLuaSideEffect(); + return std::shared_ptr(new LuaFFIPerThreadAction(code)); + }); + luaCtx.writeFunction("SetNoRecurseAction", []() { return std::shared_ptr(new SetNoRecurseAction); }); @@ -1858,6 +1988,11 @@ void setupLuaActions(LuaContext& luaCtx) return std::shared_ptr(new LuaFFIResponseAction(func)); }); + luaCtx.writeFunction("LuaFFIPerThreadResponseAction", [](std::string code) { + setLuaSideEffect(); + return std::shared_ptr(new LuaFFIPerThreadResponseAction(code)); + }); + luaCtx.writeFunction("RemoteLogAction", [](std::shared_ptr logger, boost::optional > alterFunc, boost::optional> vars) { if (logger) { // avoids potentially-evaluated-expression warning with clang. diff --git a/pdns/dnsdist-lua-rules.cc b/pdns/dnsdist-lua-rules.cc index 7fc73b84f3..83065177e7 100644 --- a/pdns/dnsdist-lua-rules.cc +++ b/pdns/dnsdist-lua-rules.cc @@ -603,6 +603,10 @@ void setupLuaRules(LuaContext& luaCtx) return std::shared_ptr(new LuaFFIRule(func)); }); + luaCtx.writeFunction("LuaFFIPerThreadRule", [](std::string code) { + return std::shared_ptr(new LuaFFIPerThreadRule(code)); + }); + luaCtx.writeFunction("ProxyProtocolValueRule", [](uint8_t type, boost::optional value) { return std::shared_ptr(new ProxyProtocolValueRule(type, value)); }); diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index f635d4e6b6..42b16a3de5 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -162,7 +162,7 @@ dnsdist_SOURCES = \ dnsdist-protobuf.cc dnsdist-protobuf.hh \ dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \ dnsdist-rings.cc dnsdist-rings.hh \ - dnsdist-rules.hh \ + dnsdist-rules.cc dnsdist-rules.hh \ dnsdist-secpoll.cc dnsdist-secpoll.hh \ dnsdist-snmp.cc dnsdist-snmp.hh \ dnsdist-systemd.cc dnsdist-systemd.hh \ @@ -240,6 +240,7 @@ testrunner_SOURCES = \ dnsdist-lua-vars.cc \ dnsdist-proxy-protocol.cc dnsdist-proxy-protocol.hh \ dnsdist-rings.cc dnsdist-rings.hh \ + dnsdist-rules.cc dnsdist-rules.hh \ dnsdist-tcp-downstream.cc \ dnsdist-tcp.cc \ dnsdist-xpf.cc dnsdist-xpf.hh \ diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index a90a952ebc..1aa1129a7e 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -533,3 +533,12 @@ void setupLuaLoadBalancingContext(LuaContext& luaCtx) luaCtx.executeCode(getLuaFFIWrappers()); #endif } + +void setupLuaFFIPerThreadContext(LuaContext& luaCtx) +{ + setupLuaVars(luaCtx); + +#ifdef LUAJIT_VERSION + luaCtx.executeCode(getLuaFFIWrappers()); +#endif +} diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.hh b/pdns/dnsdistdist/dnsdist-lua-ffi.hh index 63156db080..1030659500 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.hh +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.hh @@ -107,3 +107,4 @@ struct dnsdist_ffi_servers_list_t }; const std::string& getLuaFFIWrappers(); +void setupLuaFFIPerThreadContext(LuaContext& luaCtx); diff --git a/pdns/dnsdistdist/dnsdist-rules.cc b/pdns/dnsdistdist/dnsdist-rules.cc new file mode 100644 index 0000000000..f1e6eece97 --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-rules.cc @@ -0,0 +1,26 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "dnsdist-rules.hh" + +uint64_t LuaFFIPerThreadRule::s_functionsCounter = 0; +thread_local std::map LuaFFIPerThreadRule::t_perThreadStates; diff --git a/pdns/dnsdistdist/dnsdist-rules.hh b/pdns/dnsdistdist/dnsdist-rules.hh index e19d363e12..bbbcabe0a9 100644 --- a/pdns/dnsdistdist/dnsdist-rules.hh +++ b/pdns/dnsdistdist/dnsdist-rules.hh @@ -1178,6 +1178,53 @@ private: func_t d_func; }; +class LuaFFIPerThreadRule : public DNSRule +{ +public: + typedef std::function func_t; + + LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) + { + } + + bool matches(const DNSQuestion* dq) const override + { + try { + auto& state = t_perThreadStates[d_functionID]; + if (!state.d_initialized) { + setupLuaFFIPerThreadContext(state.d_luaContext); + state.d_func = state.d_luaContext.executeCode(d_functionCode); + state.d_initialized = true; + } + + dnsdist_ffi_dnsquestion_t dqffi(const_cast(dq)); + return state.d_func(&dqffi); + } catch (const std::exception &e) { + warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what()); + } catch (...) { + warnlog("LuaFFIPerThreadRule failed inside Lua: [unknown exception]"); + } + return false; + } + + string toString() const override + { + return "Lua FFI per-thread script"; + } +private: + struct PerThreadState + { + LuaContext d_luaContext; + func_t d_func; + bool d_initialized{false}; + }; + + static uint64_t s_functionsCounter; + static thread_local std::map t_perThreadStates; + std::string d_functionCode; + uint64_t d_functionID; +}; + class ProxyProtocolValueRule : public DNSRule { public: diff --git a/pdns/dnsdistdist/docs/advanced/tuning.rst b/pdns/dnsdistdist/docs/advanced/tuning.rst index 3c84d9aa6e..2bf9aa50a2 100644 --- a/pdns/dnsdistdist/docs/advanced/tuning.rst +++ b/pdns/dnsdistdist/docs/advanced/tuning.rst @@ -111,6 +111,8 @@ When Lua inspection is needed, the best course of action is to restrict the quer +------------------------------+-------------+-----------------+ | Lua FFI rule | fast | global Lua lock | +------------------------------+-------------+-----------------+ +| Lua per-thread FFI rule | fast | none | ++------------------------------+-------------+-----------------+ | C++ LB policy | fast | none | +------------------------------+-------------+-----------------+ | Lua LB policy | slow | global Lua lock | diff --git a/pdns/dnsdistdist/docs/rules-actions.rst b/pdns/dnsdistdist/docs/rules-actions.rst index 35aa125756..1ce55f34bf 100644 --- a/pdns/dnsdistdist/docs/rules-actions.rst +++ b/pdns/dnsdistdist/docs/rules-actions.rst @@ -498,6 +498,18 @@ These ``DNSRule``\ s be one of the following items: :param KeyValueStore kvs: The key value store to query :param KeyValueLookupKey lookupKey: The key to use for the lookup +.. function:: LuaFFIPerThreadRule(function) + + .. versionadded:: 1.7.0 + + Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``. + + The ``function`` should return true if the query matches, or false otherwise. If the Lua code fails, false is returned. + + The function will be invoked in a per-thread Lua state, without access to the global Lua state. + + :param string function: the name of a Lua function + .. function:: LuaFFIRule(function) .. versionadded:: 1.5.0 @@ -998,6 +1010,30 @@ The following actions exist. :param string function: the name of a Lua function +.. function:: LuaFFIPerThreadAction(function) + + .. versionadded:: 1.7.0 + + Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``. + + The ``function`` should return a :ref:`DNSAction`. If the Lua code fails, ServFail is returned. + + The function will be invoked in a per-thread Lua state, without access to the global Lua state. + + :param string function: the name of a Lua function + +.. function:: LuaFFIPerThreadResponseAction(function) + + .. versionadded:: 1.7.0 + + Invoke a Lua FFI function that accepts a pointer to a ``dnsdist_ffi_dnsquestion_t`` object, whose bindings are defined in ``dnsdist-lua-ffi.hh``. + + The ``function`` should return a :ref:`DNSResponseAction`. If the Lua code fails, ServFail is returned. + + The function will be invoked in a per-thread Lua state, without access to the global Lua state. + + :param string function: the name of a Lua function + .. function:: LuaFFIResponseAction(function) .. versionadded:: 1.5.0 diff --git a/regression-tests.dnsdist/test_Advanced.py b/regression-tests.dnsdist/test_Advanced.py index c5c5012d08..5fbbd517b5 100644 --- a/regression-tests.dnsdist/test_Advanced.py +++ b/regression-tests.dnsdist/test_Advanced.py @@ -2130,6 +2130,155 @@ class TestAdvancedLuaFFI(DNSDistTest): (_, receivedResponse) = sender(query, response=None, useQueue=False) self.assertEqual(receivedResponse, response) +class TestAdvancedLuaFFIPerThread(DNSDistTest): + + _config_template = """ + + local rulefunction = [[ + local ffi = require("ffi") + + return function(dq) + local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq) + if qtype ~= DNSQType.A and qtype ~= DNSQType.SOA then + print('invalid qtype') + return false + end + + local qclass = ffi.C.dnsdist_ffi_dnsquestion_get_qclass(dq) + if qclass ~= DNSClass.IN then + print('invalid qclass') + return false + end + + local ret_ptr = ffi.new("char *[1]") + local ret_ptr_param = ffi.cast("const char **", ret_ptr) + local ret_size = ffi.new("size_t[1]") + local ret_size_param = ffi.cast("size_t*", ret_size) + ffi.C.dnsdist_ffi_dnsquestion_get_qname_raw(dq, ret_ptr_param, ret_size_param) + if ret_size[0] ~= 45 then + print('invalid length for the qname ') + print(ret_size[0]) + return false + end + + local expectedQname = string.char(15)..'luaffiperthread'..string.char(8)..'advanced'..string.char(5)..'tests'..string.char(8)..'powerdns'..string.char(3)..'com' + if ffi.string(ret_ptr[0]) ~= expectedQname then + print('invalid qname') + print(ffi.string(ret_ptr[0])) + return false + end + + local rcode = ffi.C.dnsdist_ffi_dnsquestion_get_rcode(dq) + if rcode ~= 0 then + print('invalid rcode') + return false + end + + local opcode = ffi.C.dnsdist_ffi_dnsquestion_get_opcode(dq) + if qtype == DNSQType.A and opcode ~= DNSOpcode.Query then + print('invalid opcode') + return false + elseif qtype == DNSQType.SOA and opcode ~= DNSOpcode.Update then + print('invalid opcode') + return false + end + + local dnssecok = ffi.C.dnsdist_ffi_dnsquestion_get_do(dq) + if dnssecok ~= false then + print('invalid DNSSEC OK') + return false + end + + local len = ffi.C.dnsdist_ffi_dnsquestion_get_len(dq) + if len ~= 61 then + print('invalid length') + print(len) + return false + end + + local tag = ffi.C.dnsdist_ffi_dnsquestion_get_tag(dq, 'a-tag') + if ffi.string(tag) ~= 'a-value' then + print('invalid tag value') + print(ffi.string(tag)) + return false + end + + return true + end + ]] + + local actionfunction = [[ + local ffi = require("ffi") + + return function(dq) + local qtype = ffi.C.dnsdist_ffi_dnsquestion_get_qtype(dq) + if qtype == DNSQType.A then + local str = "192.0.2.1" + local buf = ffi.new("char[?]", #str + 1) + ffi.copy(buf, str) + ffi.C.dnsdist_ffi_dnsquestion_set_result(dq, buf, #str) + return DNSAction.Spoof + elseif qtype == DNSQType.SOA then + ffi.C.dnsdist_ffi_dnsquestion_set_rcode(dq, DNSRCode.REFUSED) + return DNSAction.Refused + end + end + ]] + + local settagfunction = [[ + local ffi = require("ffi") + + return function(dq) + ffi.C.dnsdist_ffi_dnsquestion_set_tag(dq, 'a-tag', 'a-value') + return DNSAction.None + end + ]] + + addAction(AllRule(), LuaFFIPerThreadAction(settagfunction)) + addAction(LuaFFIPerThreadRule(rulefunction), LuaFFIPerThreadAction(actionfunction)) + -- newServer{address="127.0.0.1:%s"} + """ + + def testAdvancedLuaPerthreadFFI(self): + """ + Advanced: Test the Lua FFI per-thread interface + """ + name = 'luaffiperthread.advanced.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, response) + + def testAdvancedLuaFFIPerThreadUpdate(self): + """ + Advanced: Test the Lua FFI per-thread interface via an update + """ + name = 'luaffiperthread.advanced.tests.powerdns.com.' + query = dns.message.make_query(name, 'SOA', 'IN') + query.set_opcode(dns.opcode.UPDATE) + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + + response = dns.message.make_response(query) + response.set_rcode(dns.rcode.REFUSED) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, response) + class TestAdvancedDropEmptyQueries(DNSDistTest): _config_template = """