From: Remi Gacogne Date: Thu, 22 Dec 2022 16:32:05 +0000 (+0100) Subject: dnsdist: Add the infrastructure for restartable queries X-Git-Tag: dnsdist-1.8.0-rc1~86^2~12 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=b37671bc8abe56c376e2e6e7201ad2bb77cd1239;p=thirdparty%2Fpdns.git dnsdist: Add the infrastructure for restartable queries --- diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index cec3a329ab..43e1290b2f 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -127,6 +127,7 @@ struct InternalQueryState std::shared_ptr packetCache{nullptr}; // 16 std::unique_ptr dnsCryptQuery{nullptr}; // 8 std::unique_ptr qTag{nullptr}; // 8 + std::unique_ptr d_packet; // Initial packet, so we can restart the query from the response path if needed // 8 boost::optional tempFailureTTL{boost::none}; // 8 ClientState* cs{nullptr}; // 8 std::unique_ptr du; // 8 diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index ba3b95ff87..3d2603261b 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -22,6 +22,7 @@ #include "dnsdist.hh" #include "dnsdist-async.hh" #include "dnsdist-ecs.hh" +#include "dnsdist-internal-queries.hh" #include "dnsdist-lua.hh" #include "dnsparser.hh" @@ -42,6 +43,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) luaCtx.registerMember("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; }); luaCtx.registerMember("tcp", [](const DNSQuestion& dq) -> bool { return dq.overTCP(); }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; }); luaCtx.registerMember("skipCache", [](const DNSQuestion& dq) -> bool { return dq.ids.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.ids.skipCache = newSkipCache; }); + luaCtx.registerMember("pool", [](const DNSQuestion& dq) -> std::string { return dq.ids.poolName; }, [](DNSQuestion& dq, const std::string& newPoolName) { dq.ids.poolName = newPoolName; }); luaCtx.registerMember("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; }); luaCtx.registerMember("ecsOverride", [](const DNSQuestion& dq) -> bool { return dq.ecsOverride; }, [](DNSQuestion& dq, bool ecsOverride) { dq.ecsOverride = ecsOverride; }); luaCtx.registerMember("ecsPrefixLength", [](const DNSQuestion& dq) -> uint16_t { return dq.ecsPrefixLength; }, [](DNSQuestion& dq, uint16_t newPrefixLength) { dq.ecsPrefixLength = newPrefixLength; }); @@ -201,6 +203,11 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) return dnsdist::suspendQuery(dq, asyncID, queryID, timeoutMs); }); + luaCtx.registerFunction("setRestartable", [](DNSQuestion& dq) { + dq.ids.d_packet = std::make_unique(dq.getData()); + return true; + }); + class AsynchronousObject { public: @@ -288,6 +295,7 @@ private: luaCtx.registerMember("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; }); luaCtx.registerMember("tcp", [](const DNSResponse& dq) -> bool { return dq.overTCP(); }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; }); luaCtx.registerMember("skipCache", [](const DNSResponse& dq) -> bool { return dq.ids.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.ids.skipCache = newSkipCache; }); + luaCtx.registerMember("pool", [](const DNSResponse& dq) -> std::string { return dq.ids.poolName; }, [](DNSResponse& dq, const std::string& newPoolName) { dq.ids.poolName = newPoolName; }); luaCtx.registerFunction editFunc)>("editTTLs", [](DNSResponse& dr, std::function editFunc) { editDNSPacketTTL(reinterpret_cast(dr.getMutableData().data()), dr.getData().size(), editFunc); }); @@ -426,5 +434,15 @@ private: dr.asynchronous = true; return dnsdist::suspendResponse(dr, asyncID, queryID, timeoutMs); }); + + luaCtx.registerFunction("restart", [](DNSResponse& dr) { + if (!dr.ids.d_packet) { + return false; + } + dr.asynchronous = true; + dr.getMutableData() = *dr.ids.d_packet; + auto query = dnsdist::getInternalQueryFromDQ(dr, false); + return dnsdist::queueQueryResumptionEvent(std::move(query)); + }); #endif /* DISABLE_NON_FFI_DQ_BINDINGS */ } diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h b/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h index 741bd3aedc..533fc5ea3f 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h +++ b/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h @@ -128,6 +128,7 @@ void dnsdist_ffi_dnsquestion_spoof_packet(dnsdist_ffi_dnsquestion_t* dq, const c /* decrease the returned TTL but _after_ inserting the original response into the packet cache */ void dnsdist_ffi_dnsquestion_set_max_returned_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t max) __attribute__ ((visibility ("default"))); +bool dnsdist_ffi_dnsquestion_set_restartable(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); typedef struct dnsdist_ffi_servers_list_t dnsdist_ffi_servers_list_t; typedef struct dnsdist_ffi_server_t dnsdist_ffi_server_t; diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index 71dd9fb625..8bfb877ca6 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -591,6 +591,16 @@ void dnsdist_ffi_dnsquestion_set_max_returned_ttl(dnsdist_ffi_dnsquestion_t* dq, } } +bool dnsdist_ffi_dnsquestion_set_restartable(dnsdist_ffi_dnsquestion_t* dq) +{ + if (dq == nullptr || dq->dq == nullptr) { + return false; + } + + dq->dq->ids.d_packet = std::make_unique(dq->dq->getData()); + return true; +} + size_t dnsdist_ffi_servers_list_get_count(const dnsdist_ffi_servers_list_t* list) { return list->ffiServers.size(); diff --git a/regression-tests.dnsdist/test_RestartQuery.py b/regression-tests.dnsdist/test_RestartQuery.py new file mode 100644 index 0000000000..bd077fa3b4 --- /dev/null +++ b/regression-tests.dnsdist/test_RestartQuery.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +import threading +import clientsubnetoption +import dns +from dnsdisttests import DNSDistTest + +def servFailResponseCallback(request): + response = dns.message.make_response(request) + response.set_rcode(dns.rcode.SERVFAIL) + return response.to_wire() + +def normalResponseCallback(request): + response = dns.message.make_response(request) + rrset = dns.rrset.from_text(request.question[0].name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + return response.to_wire() + +class TestRestartQuery(DNSDistTest): + + # this test suite uses different responder ports + _testNormalServerPort = 5420 + _testServfailServerPort = 5421 + _config_template = """ + newServer{address="127.0.0.1:%d", pool='restarted'}:setUp() + newServer{address="127.0.0.1:%d", pool=''}:setUp() + + function makeQueryRestartable(dq) + dq:setRestartable() + return DNSAction.None + end + + function restartOnServFail(dr) + if dr.rcode == DNSRCode.SERVFAIL then + dr.pool = 'restarted' + dr:restart() + end + + return DNSResponseAction.None + end + + addAction(AllRule(), LuaAction(makeQueryRestartable)) + addResponseAction(AllRule(), LuaResponseAction(restartOnServFail)) + """ + _config_params = ['_testNormalServerPort', '_testServfailServerPort'] + _verboseMode = True + + @classmethod + def startResponders(cls): + print("Launching responders..") + + # servfail + cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServfailServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, servFailResponseCallback]) + cls._UDPResponder.setDaemon(True) + cls._UDPResponder.start() + cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServfailServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, servFailResponseCallback]) + cls._TCPResponder.setDaemon(True) + cls._TCPResponder.start() + cls._UDPResponderNormal = threading.Thread(name='UDP ResponderNormal', target=cls.UDPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, normalResponseCallback]) + cls._UDPResponderNormal.setDaemon(True) + cls._UDPResponderNormal.start() + cls._TCPResponderNormal = threading.Thread(name='TCP ResponderNormal', target=cls.TCPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, normalResponseCallback]) + cls._TCPResponderNormal.setDaemon(True) + cls._TCPResponderNormal.start() + + def testRestartingQuery(self): + """ + Restart: ServFail then restarted to a second pool + """ + name = 'restart.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + expectedResponse = dns.message.make_response(query) + expectedResponse.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEquals(receivedResponse, expectedResponse) +