From: Remi Gacogne Date: Mon, 23 Jun 2025 12:30:33 +0000 (+0200) Subject: dnsdist: Return `nil` for non-existing Lua objects X-Git-Tag: dnsdist-2.0.0-rc1~7^2~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=38c11c663e6e250bb001ca2deef7aafc9361d32f;p=thirdparty%2Fpdns.git dnsdist: Return `nil` for non-existing Lua objects Until now we were returning an empty shared pointer, but unfortunately LuaWrapper is currently not smart enough to turn that into a `nil` value. Signed-off-by: Remi Gacogne (cherry picked from commit 4b07e08e8b5aeef617c016c00ca344669870a8ce) --- diff --git a/pdns/dnsdistdist/dnsdist-lua-bindings-dnscrypt.cc b/pdns/dnsdistdist/dnsdist-lua-bindings-dnscrypt.cc index c2a6aa71a4..8d8fe13eb6 100644 --- a/pdns/dnsdistdist/dnsdist-lua-bindings-dnscrypt.cc +++ b/pdns/dnsdistdist/dnsdist-lua-bindings-dnscrypt.cc @@ -62,12 +62,12 @@ void setupLuaBindingsDNSCrypt([[maybe_unused]] LuaContext& luaCtx, [[maybe_unuse return result; }); - luaCtx.registerFunction (std::shared_ptr::*)(size_t idx)>("getCertificatePair", [](std::shared_ptr& ctx, size_t idx) { + luaCtx.registerFunction> (std::shared_ptr::*)(size_t idx)>("getCertificatePair", [](std::shared_ptr& ctx, size_t idx) -> boost::optional> { if (ctx == nullptr) { throw std::runtime_error("DNSCryptContext::getCertificatePair() called on a nil value"); } - std::shared_ptr result = nullptr; + boost::optional> result{boost::none}; auto pairs = ctx->getCertificates(); if (idx < pairs.size()) { result = pairs.at(idx); diff --git a/pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc index 27c17a9466..9f2aa5fcb5 100644 --- a/pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdistdist/dnsdist-lua-bindings-dnsquestion.cc @@ -687,8 +687,8 @@ void setupLuaBindingsDNSQuestion([[maybe_unused]] LuaContext& luaCtx) return dnsdist::queueQueryResumptionEvent(std::move(query)); }); - luaCtx.registerFunction (DNSResponse::*)(void) const>("getSelectedBackend", [](const DNSResponse& dnsResponse) { - return dnsResponse.d_downstream; + luaCtx.registerFunction> (DNSResponse::*)(void) const>("getSelectedBackend", [](const DNSResponse& dnsResponse) -> boost::optional> { + return dnsResponse.d_downstream ? dnsResponse.d_downstream : boost::optional>(); }); luaCtx.registerFunction("getStaleCacheHit", [](DNSResponse& dnsResponse) { diff --git a/pdns/dnsdistdist/dnsdist-lua.cc b/pdns/dnsdistdist/dnsdist-lua.cc index 653bf78ec7..83a7860042 100644 --- a/pdns/dnsdistdist/dnsdist-lua.cc +++ b/pdns/dnsdistdist/dnsdist-lua.cc @@ -1007,7 +1007,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) return *poolServers; }); - luaCtx.writeFunction("getServer", [client](boost::variant identifier) { + luaCtx.writeFunction("getServer", [client](boost::variant identifier) -> boost::optional> { if (client) { return std::make_shared(ComboAddress()); } @@ -1021,11 +1021,15 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } } else if (auto* pos = boost::get(&identifier)) { - return states.at(*pos); + if (*pos < states.size()) { + return states.at(*pos); + } + g_outputBuffer = "Error: trying to retrieve server " + std::to_string(*pos) + " while there is only " + std::to_string(states.size()) + "servers\n"; + return boost::none; } - g_outputBuffer = "Error: no rule matched\n"; - return std::shared_ptr(nullptr); + g_outputBuffer = "Error: no server matched\n"; + return boost::none; }); #ifndef DISABLE_CARBON @@ -1588,9 +1592,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) g_outputBuffer = ret.str(); }); - luaCtx.writeFunction("getDNSCryptBind", [](uint64_t idx) { + luaCtx.writeFunction("getDNSCryptBind", [](uint64_t idx) -> boost::optional> { setLuaNoSideEffect(); - std::shared_ptr ret = nullptr; + boost::optional> ret{boost::none}; /* we are only interested in distinct DNSCrypt binds, and we have two frontends (UDP and TCP) per bind sharing the same context so we need to retrieve @@ -1731,9 +1735,9 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } }); - luaCtx.writeFunction("getBind", [](uint64_t num) { + luaCtx.writeFunction("getBind", [](uint64_t num) -> boost::optional { setLuaNoSideEffect(); - ClientState* ret = nullptr; + boost::optional ret{boost::none}; auto frontends = dnsdist::getFrontends(); if (num < frontends.size()) { ret = frontends[num].get(); @@ -2474,8 +2478,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) }); #ifdef HAVE_DNS_OVER_QUIC - luaCtx.writeFunction("getDOQFrontend", [client](uint64_t index) { - std::shared_ptr result = nullptr; + luaCtx.writeFunction("getDOQFrontend", [client](uint64_t index) -> boost::optional> { + boost::optional> result{boost::none}; if (client) { return result; } @@ -2556,8 +2560,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) }); #ifdef HAVE_DNS_OVER_HTTP3 - luaCtx.writeFunction("getDOH3Frontend", [client](uint64_t index) { - std::shared_ptr result = nullptr; + luaCtx.writeFunction("getDOH3Frontend", [client](uint64_t index) -> boost::optional> { + boost::optional> result{boost::none}; if (client) { return result; } @@ -2625,8 +2629,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) #endif }); - luaCtx.writeFunction("getDOHFrontend", [client]([[maybe_unused]] uint64_t index) { - std::shared_ptr result = nullptr; + luaCtx.writeFunction("getDOHFrontend", [client]([[maybe_unused]] uint64_t index) -> boost::optional> { + boost::optional> result{boost::none}; if (client) { return result; } @@ -2855,8 +2859,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) #endif }); - luaCtx.writeFunction("getTLSFrontend", []([[maybe_unused]] uint64_t index) { - std::shared_ptr result = nullptr; + luaCtx.writeFunction("getTLSFrontend", []([[maybe_unused]] uint64_t index) -> boost::optional> { + boost::optional> result{boost::none}; #ifdef HAVE_DNS_OVER_TLS setLuaNoSideEffect(); try { diff --git a/regression-tests.dnsdist/test_Lua.py b/regression-tests.dnsdist/test_Lua.py index 3d689477c2..153c490b91 100644 --- a/regression-tests.dnsdist/test_Lua.py +++ b/regression-tests.dnsdist/test_Lua.py @@ -93,3 +93,39 @@ class TestLuaDNSHeaderBindings(DNSDistTest): receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(response, receivedResponse) + +class TestLuaFrontendBindings(DNSDistTest): + _config_template = """ + newServer{address="127.0.0.1:%d"} + + -- check that all these methods return nil on a non-existing entry + functions = { 'getServer', 'getDNSCryptBind', 'getBind', 'getDOQFrontend', 'getDOH3Frontend', 'getDOHFrontend', 'getTLSFrontend'} + for _, func in ipairs(functions) do + assert(_G[func](42) == nil, "function "..func.." did not return nil as expected") + end + + addAction('basic.lua-frontend-bindings.tests.powerdns.com.', RCodeAction(DNSRCode.REFUSED)) + -- also test that getSelectedBackend() returns nil on self-answered responses + function checkSelectedBackend(dr) + local backend = dr:getSelectedBackend() + assert(backend == nil, "DNSResponse::getSelectedBackend() should return nil on self-answered responses") + return DNSResponseAction.None + end + addSelfAnsweredResponseAction(AllRule(), LuaResponseAction(checkSelectedBackend)) + """ + _checkConfigExpectedOutput = b"Error: trying to get DOQ frontend with index 42 but we only have 0 frontend(s)\n\nError: trying to get DOH3 frontend with index 42 but we only have 0 frontend(s)\n\nError: trying to get DOH frontend with index 42 but we only have 0 frontend(s)\n\nError: trying to get TLS frontend with index 42 but we only have 0 frontends\n\nConfiguration 'configs/dnsdist_TestLuaFrontendBindings.conf' OK!\n" + + def testLuaBindings(self): + """ + LuaFrontendBindings: Test Lua frontend bindings + """ + name = 'basic.lua-frontend-bindings.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + expectedResponse.set_rcode(dns.rcode.REFUSED) + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertEqual(receivedResponse, expectedResponse)