std::shared_ptr<DNSDistPacketCache> packetCache{nullptr}; // 16
std::unique_ptr<DNSCryptQuery> dnsCryptQuery{nullptr}; // 8
std::unique_ptr<QTag> qTag{nullptr}; // 8
+ std::unique_ptr<PacketBuffer> d_packet; // Initial packet, so we can restart the query from the response path if needed // 8
boost::optional<uint32_t> tempFailureTTL{boost::none}; // 8
ClientState* cs{nullptr}; // 8
std::unique_ptr<DOHUnit, void (*)(DOHUnit*)> du; // 8
#include "dnsdist.hh"
#include "dnsdist-async.hh"
#include "dnsdist-ecs.hh"
+#include "dnsdist-internal-queries.hh"
#include "dnsdist-lua.hh"
#include "dnsparser.hh"
luaCtx.registerMember<uint8_t (DNSQuestion::*)>("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; });
luaCtx.registerMember<bool (DNSQuestion::*)>("tcp", [](const DNSQuestion& dq) -> bool { return dq.overTCP(); }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; });
luaCtx.registerMember<bool (DNSQuestion::*)>("skipCache", [](const DNSQuestion& dq) -> bool { return dq.ids.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.ids.skipCache = newSkipCache; });
+ luaCtx.registerMember<std::string (DNSQuestion::*)>("pool", [](const DNSQuestion& dq) -> std::string { return dq.ids.poolName; }, [](DNSQuestion& dq, const std::string& newPoolName) { dq.ids.poolName = newPoolName; });
luaCtx.registerMember<bool (DNSQuestion::*)>("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; });
luaCtx.registerMember<bool (DNSQuestion::*)>("ecsOverride", [](const DNSQuestion& dq) -> bool { return dq.ecsOverride; }, [](DNSQuestion& dq, bool ecsOverride) { dq.ecsOverride = ecsOverride; });
luaCtx.registerMember<uint16_t (DNSQuestion::*)>("ecsPrefixLength", [](const DNSQuestion& dq) -> uint16_t { return dq.ecsPrefixLength; }, [](DNSQuestion& dq, uint16_t newPrefixLength) { dq.ecsPrefixLength = newPrefixLength; });
return dnsdist::suspendQuery(dq, asyncID, queryID, timeoutMs);
});
+ luaCtx.registerFunction<bool(DNSQuestion::*)()>("setRestartable", [](DNSQuestion& dq) {
+ dq.ids.d_packet = std::make_unique<PacketBuffer>(dq.getData());
+ return true;
+ });
+
class AsynchronousObject
{
public:
luaCtx.registerMember<uint8_t (DNSResponse::*)>("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; });
luaCtx.registerMember<bool (DNSResponse::*)>("tcp", [](const DNSResponse& dq) -> bool { return dq.overTCP(); }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; });
luaCtx.registerMember<bool (DNSResponse::*)>("skipCache", [](const DNSResponse& dq) -> bool { return dq.ids.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.ids.skipCache = newSkipCache; });
+ luaCtx.registerMember<std::string (DNSResponse::*)>("pool", [](const DNSResponse& dq) -> std::string { return dq.ids.poolName; }, [](DNSResponse& dq, const std::string& newPoolName) { dq.ids.poolName = newPoolName; });
luaCtx.registerFunction<void(DNSResponse::*)(std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc)>("editTTLs", [](DNSResponse& dr, std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc) {
editDNSPacketTTL(reinterpret_cast<char *>(dr.getMutableData().data()), dr.getData().size(), editFunc);
});
dr.asynchronous = true;
return dnsdist::suspendResponse(dr, asyncID, queryID, timeoutMs);
});
+
+ luaCtx.registerFunction<bool(DNSResponse::*)()>("restart", [](DNSResponse& dr) {
+ if (!dr.ids.d_packet) {
+ return false;
+ }
+ dr.asynchronous = true;
+ dr.getMutableData() = *dr.ids.d_packet;
+ auto query = dnsdist::getInternalQueryFromDQ(dr, false);
+ return dnsdist::queueQueryResumptionEvent(std::move(query));
+ });
#endif /* DISABLE_NON_FFI_DQ_BINDINGS */
}
/* decrease the returned TTL but _after_ inserting the original response into the packet cache */
void dnsdist_ffi_dnsquestion_set_max_returned_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t max) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_set_restartable(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
typedef struct dnsdist_ffi_servers_list_t dnsdist_ffi_servers_list_t;
typedef struct dnsdist_ffi_server_t dnsdist_ffi_server_t;
}
}
+bool dnsdist_ffi_dnsquestion_set_restartable(dnsdist_ffi_dnsquestion_t* dq)
+{
+ if (dq == nullptr || dq->dq == nullptr) {
+ return false;
+ }
+
+ dq->dq->ids.d_packet = std::make_unique<PacketBuffer>(dq->dq->getData());
+ return true;
+}
+
size_t dnsdist_ffi_servers_list_get_count(const dnsdist_ffi_servers_list_t* list)
{
return list->ffiServers.size();
--- /dev/null
+#!/usr/bin/env python
+import threading
+import clientsubnetoption
+import dns
+from dnsdisttests import DNSDistTest
+
+def servFailResponseCallback(request):
+ response = dns.message.make_response(request)
+ response.set_rcode(dns.rcode.SERVFAIL)
+ return response.to_wire()
+
+def normalResponseCallback(request):
+ response = dns.message.make_response(request)
+ rrset = dns.rrset.from_text(request.question[0].name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+ return response.to_wire()
+
+class TestRestartQuery(DNSDistTest):
+
+ # this test suite uses different responder ports
+ _testNormalServerPort = 5420
+ _testServfailServerPort = 5421
+ _config_template = """
+ newServer{address="127.0.0.1:%d", pool='restarted'}:setUp()
+ newServer{address="127.0.0.1:%d", pool=''}:setUp()
+
+ function makeQueryRestartable(dq)
+ dq:setRestartable()
+ return DNSAction.None
+ end
+
+ function restartOnServFail(dr)
+ if dr.rcode == DNSRCode.SERVFAIL then
+ dr.pool = 'restarted'
+ dr:restart()
+ end
+
+ return DNSResponseAction.None
+ end
+
+ addAction(AllRule(), LuaAction(makeQueryRestartable))
+ addResponseAction(AllRule(), LuaResponseAction(restartOnServFail))
+ """
+ _config_params = ['_testNormalServerPort', '_testServfailServerPort']
+ _verboseMode = True
+
+ @classmethod
+ def startResponders(cls):
+ print("Launching responders..")
+
+ # servfail
+ cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServfailServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, servFailResponseCallback])
+ cls._UDPResponder.setDaemon(True)
+ cls._UDPResponder.start()
+ cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServfailServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, servFailResponseCallback])
+ cls._TCPResponder.setDaemon(True)
+ cls._TCPResponder.start()
+ cls._UDPResponderNormal = threading.Thread(name='UDP ResponderNormal', target=cls.UDPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, normalResponseCallback])
+ cls._UDPResponderNormal.setDaemon(True)
+ cls._UDPResponderNormal.start()
+ cls._TCPResponderNormal = threading.Thread(name='TCP ResponderNormal', target=cls.TCPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, normalResponseCallback])
+ cls._TCPResponderNormal.setDaemon(True)
+ cls._TCPResponderNormal.start()
+
+ def testRestartingQuery(self):
+ """
+ Restart: ServFail then restarted to a second pool
+ """
+ name = 'restart.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ expectedResponse = dns.message.make_response(query)
+ expectedResponse.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (_, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertTrue(receivedResponse)
+ self.assertEquals(receivedResponse, expectedResponse)
+