return result;
});
- luaCtx.registerFunction<std::shared_ptr<DNSCryptCertificatePair> (std::shared_ptr<DNSCryptContext>::*)(size_t idx)>("getCertificatePair", [](std::shared_ptr<DNSCryptContext>& ctx, size_t idx) {
+ luaCtx.registerFunction<boost::optional<std::shared_ptr<DNSCryptCertificatePair>> (std::shared_ptr<DNSCryptContext>::*)(size_t idx)>("getCertificatePair", [](std::shared_ptr<DNSCryptContext>& ctx, size_t idx) -> boost::optional<std::shared_ptr<DNSCryptCertificatePair>> {
if (ctx == nullptr) {
throw std::runtime_error("DNSCryptContext::getCertificatePair() called on a nil value");
}
- std::shared_ptr<DNSCryptCertificatePair> result = nullptr;
+ boost::optional<std::shared_ptr<DNSCryptCertificatePair>> result{boost::none};
auto pairs = ctx->getCertificates();
if (idx < pairs.size()) {
result = pairs.at(idx);
return dnsdist::queueQueryResumptionEvent(std::move(query));
});
- luaCtx.registerFunction<std::shared_ptr<DownstreamState> (DNSResponse::*)(void) const>("getSelectedBackend", [](const DNSResponse& dnsResponse) {
- return dnsResponse.d_downstream;
+ luaCtx.registerFunction<boost::optional<std::shared_ptr<DownstreamState>> (DNSResponse::*)(void) const>("getSelectedBackend", [](const DNSResponse& dnsResponse) -> boost::optional<std::shared_ptr<DownstreamState>> {
+ return dnsResponse.d_downstream ? dnsResponse.d_downstream : boost::optional<std::shared_ptr<DownstreamState>>();
});
luaCtx.registerFunction<bool (DNSResponse::*)()>("getStaleCacheHit", [](DNSResponse& dnsResponse) {
return *poolServers;
});
- luaCtx.writeFunction("getServer", [client](boost::variant<unsigned int, std::string> identifier) {
+ luaCtx.writeFunction("getServer", [client](boost::variant<unsigned int, std::string> identifier) -> boost::optional<std::shared_ptr<DownstreamState>> {
if (client) {
return std::make_shared<DownstreamState>(ComboAddress());
}
}
}
else if (auto* pos = boost::get<unsigned int>(&identifier)) {
- return states.at(*pos);
+ if (*pos < states.size()) {
+ return states.at(*pos);
+ }
+ g_outputBuffer = "Error: trying to retrieve server " + std::to_string(*pos) + " while there is only " + std::to_string(states.size()) + "servers\n";
+ return boost::none;
}
- g_outputBuffer = "Error: no rule matched\n";
- return std::shared_ptr<DownstreamState>(nullptr);
+ g_outputBuffer = "Error: no server matched\n";
+ return boost::none;
});
#ifndef DISABLE_CARBON
g_outputBuffer = ret.str();
});
- luaCtx.writeFunction("getDNSCryptBind", [](uint64_t idx) {
+ luaCtx.writeFunction("getDNSCryptBind", [](uint64_t idx) -> boost::optional<std::shared_ptr<DNSCryptContext>> {
setLuaNoSideEffect();
- std::shared_ptr<DNSCryptContext> ret = nullptr;
+ boost::optional<std::shared_ptr<DNSCryptContext>> ret{boost::none};
/* we are only interested in distinct DNSCrypt binds,
and we have two frontends (UDP and TCP) per bind
sharing the same context so we need to retrieve
}
});
- luaCtx.writeFunction("getBind", [](uint64_t num) {
+ luaCtx.writeFunction("getBind", [](uint64_t num) -> boost::optional<ClientState*> {
setLuaNoSideEffect();
- ClientState* ret = nullptr;
+ boost::optional<ClientState*> ret{boost::none};
auto frontends = dnsdist::getFrontends();
if (num < frontends.size()) {
ret = frontends[num].get();
});
#ifdef HAVE_DNS_OVER_QUIC
- luaCtx.writeFunction("getDOQFrontend", [client](uint64_t index) {
- std::shared_ptr<DOQFrontend> result = nullptr;
+ luaCtx.writeFunction("getDOQFrontend", [client](uint64_t index) -> boost::optional<std::shared_ptr<DOQFrontend>> {
+ boost::optional<std::shared_ptr<DOQFrontend>> result{boost::none};
if (client) {
return result;
}
});
#ifdef HAVE_DNS_OVER_HTTP3
- luaCtx.writeFunction("getDOH3Frontend", [client](uint64_t index) {
- std::shared_ptr<DOH3Frontend> result = nullptr;
+ luaCtx.writeFunction("getDOH3Frontend", [client](uint64_t index) -> boost::optional<std::shared_ptr<DOH3Frontend>> {
+ boost::optional<std::shared_ptr<DOH3Frontend>> result{boost::none};
if (client) {
return result;
}
#endif
});
- luaCtx.writeFunction("getDOHFrontend", [client]([[maybe_unused]] uint64_t index) {
- std::shared_ptr<DOHFrontend> result = nullptr;
+ luaCtx.writeFunction("getDOHFrontend", [client]([[maybe_unused]] uint64_t index) -> boost::optional<std::shared_ptr<DOHFrontend>> {
+ boost::optional<std::shared_ptr<DOHFrontend>> result{boost::none};
if (client) {
return result;
}
#endif
});
- luaCtx.writeFunction("getTLSFrontend", []([[maybe_unused]] uint64_t index) {
- std::shared_ptr<TLSFrontend> result = nullptr;
+ luaCtx.writeFunction("getTLSFrontend", []([[maybe_unused]] uint64_t index) -> boost::optional<std::shared_ptr<TLSFrontend>> {
+ boost::optional<std::shared_ptr<TLSFrontend>> result{boost::none};
#ifdef HAVE_DNS_OVER_TLS
setLuaNoSideEffect();
try {
receivedQuery.id = query.id
self.assertEqual(query, receivedQuery)
self.assertEqual(response, receivedResponse)
+
+class TestLuaFrontendBindings(DNSDistTest):
+ _config_template = """
+ newServer{address="127.0.0.1:%d"}
+
+ -- check that all these methods return nil on a non-existing entry
+ functions = { 'getServer', 'getDNSCryptBind', 'getBind', 'getDOQFrontend', 'getDOH3Frontend', 'getDOHFrontend', 'getTLSFrontend'}
+ for _, func in ipairs(functions) do
+ assert(_G[func](42) == nil, "function "..func.." did not return nil as expected")
+ end
+
+ addAction('basic.lua-frontend-bindings.tests.powerdns.com.', RCodeAction(DNSRCode.REFUSED))
+ -- also test that getSelectedBackend() returns nil on self-answered responses
+ function checkSelectedBackend(dr)
+ local backend = dr:getSelectedBackend()
+ assert(backend == nil, "DNSResponse::getSelectedBackend() should return nil on self-answered responses")
+ return DNSResponseAction.None
+ end
+ addSelfAnsweredResponseAction(AllRule(), LuaResponseAction(checkSelectedBackend))
+ """
+ _checkConfigExpectedOutput = b"Error: trying to get DOQ frontend with index 42 but we only have 0 frontend(s)\n\nError: trying to get DOH3 frontend with index 42 but we only have 0 frontend(s)\n\nError: trying to get DOH frontend with index 42 but we only have 0 frontend(s)\n\nError: trying to get TLS frontend with index 42 but we only have 0 frontends\n\nConfiguration 'configs/dnsdist_TestLuaFrontendBindings.conf' OK!\n"
+
+ def testLuaBindings(self):
+ """
+ LuaFrontendBindings: Test Lua frontend bindings
+ """
+ name = 'basic.lua-frontend-bindings.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+ expectedResponse = dns.message.make_response(query)
+ expectedResponse.set_rcode(dns.rcode.REFUSED)
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertEqual(receivedResponse, expectedResponse)