]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add the infrastructure for restartable queries
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 22 Dec 2022 16:32:05 +0000 (17:32 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 13 Jan 2023 15:57:51 +0000 (16:57 +0100)
pdns/dnsdist-idstate.hh
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdistdist/dnsdist-lua-ffi-interface.h
pdns/dnsdistdist/dnsdist-lua-ffi.cc
regression-tests.dnsdist/test_RestartQuery.py [new file with mode: 0644]

index cec3a329ab470d341c7c5126eda1c6038423200a..43e1290b2fb5aa1291847508111a2241fe10d843 100644 (file)
@@ -127,6 +127,7 @@ struct InternalQueryState
   std::shared_ptr<DNSDistPacketCache> packetCache{nullptr}; // 16
   std::unique_ptr<DNSCryptQuery> dnsCryptQuery{nullptr}; // 8
   std::unique_ptr<QTag> qTag{nullptr}; // 8
+  std::unique_ptr<PacketBuffer> d_packet; // Initial packet, so we can restart the query from the response path if needed // 8
   boost::optional<uint32_t> tempFailureTTL{boost::none}; // 8
   ClientState* cs{nullptr}; // 8
   std::unique_ptr<DOHUnit, void (*)(DOHUnit*)> du; // 8
index ba3b95ff8740d575577df0fd865653d335c6db3e..3d2603261b8dd028a68d17a63e937107dbde87a3 100644 (file)
@@ -22,6 +22,7 @@
 #include "dnsdist.hh"
 #include "dnsdist-async.hh"
 #include "dnsdist-ecs.hh"
+#include "dnsdist-internal-queries.hh"
 #include "dnsdist-lua.hh"
 #include "dnsparser.hh"
 
@@ -42,6 +43,7 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
   luaCtx.registerMember<uint8_t (DNSQuestion::*)>("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; });
   luaCtx.registerMember<bool (DNSQuestion::*)>("tcp", [](const DNSQuestion& dq) -> bool { return dq.overTCP(); }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; });
   luaCtx.registerMember<bool (DNSQuestion::*)>("skipCache", [](const DNSQuestion& dq) -> bool { return dq.ids.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.ids.skipCache = newSkipCache; });
+  luaCtx.registerMember<std::string (DNSQuestion::*)>("pool", [](const DNSQuestion& dq) -> std::string { return dq.ids.poolName; }, [](DNSQuestion& dq, const std::string& newPoolName) { dq.ids.poolName = newPoolName; });
   luaCtx.registerMember<bool (DNSQuestion::*)>("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; });
   luaCtx.registerMember<bool (DNSQuestion::*)>("ecsOverride", [](const DNSQuestion& dq) -> bool { return dq.ecsOverride; }, [](DNSQuestion& dq, bool ecsOverride) { dq.ecsOverride = ecsOverride; });
   luaCtx.registerMember<uint16_t (DNSQuestion::*)>("ecsPrefixLength", [](const DNSQuestion& dq) -> uint16_t { return dq.ecsPrefixLength; }, [](DNSQuestion& dq, uint16_t newPrefixLength) { dq.ecsPrefixLength = newPrefixLength; });
@@ -201,6 +203,11 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     return dnsdist::suspendQuery(dq, asyncID, queryID, timeoutMs);
   });
 
+  luaCtx.registerFunction<bool(DNSQuestion::*)()>("setRestartable", [](DNSQuestion& dq) {
+    dq.ids.d_packet = std::make_unique<PacketBuffer>(dq.getData());
+    return true;
+  });
+
 class AsynchronousObject
 {
 public:
@@ -288,6 +295,7 @@ private:
   luaCtx.registerMember<uint8_t (DNSResponse::*)>("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; });
   luaCtx.registerMember<bool (DNSResponse::*)>("tcp", [](const DNSResponse& dq) -> bool { return dq.overTCP(); }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; });
   luaCtx.registerMember<bool (DNSResponse::*)>("skipCache", [](const DNSResponse& dq) -> bool { return dq.ids.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.ids.skipCache = newSkipCache; });
+  luaCtx.registerMember<std::string (DNSResponse::*)>("pool", [](const DNSResponse& dq) -> std::string { return dq.ids.poolName; }, [](DNSResponse& dq, const std::string& newPoolName) { dq.ids.poolName = newPoolName; });
   luaCtx.registerFunction<void(DNSResponse::*)(std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc)>("editTTLs", [](DNSResponse& dr, std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc) {
     editDNSPacketTTL(reinterpret_cast<char *>(dr.getMutableData().data()), dr.getData().size(), editFunc);
       });
@@ -426,5 +434,15 @@ private:
     dr.asynchronous = true;
     return dnsdist::suspendResponse(dr, asyncID, queryID, timeoutMs);
   });
+
+  luaCtx.registerFunction<bool(DNSResponse::*)()>("restart", [](DNSResponse& dr) {
+    if (!dr.ids.d_packet) {
+      return false;
+    }
+    dr.asynchronous = true;
+    dr.getMutableData() = *dr.ids.d_packet;
+    auto query = dnsdist::getInternalQueryFromDQ(dr, false);
+    return dnsdist::queueQueryResumptionEvent(std::move(query));
+  });
 #endif /* DISABLE_NON_FFI_DQ_BINDINGS */
 }
index 741bd3aedc3c180bfe57a2193f942274b61343b7..533fc5ea3fb382a34a729782fcd2b1e8e37bc10d 100644 (file)
@@ -128,6 +128,7 @@ void dnsdist_ffi_dnsquestion_spoof_packet(dnsdist_ffi_dnsquestion_t* dq, const c
 
 /* decrease the returned TTL but _after_ inserting the original response into the packet cache */
 void dnsdist_ffi_dnsquestion_set_max_returned_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t max) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_set_restartable(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 
 typedef struct dnsdist_ffi_servers_list_t dnsdist_ffi_servers_list_t;
 typedef struct dnsdist_ffi_server_t dnsdist_ffi_server_t;
index 71dd9fb6250fdc767af292456565ca33f6dde92f..8bfb877ca66f15ffc6efcd891faf27945add9397 100644 (file)
@@ -591,6 +591,16 @@ void dnsdist_ffi_dnsquestion_set_max_returned_ttl(dnsdist_ffi_dnsquestion_t* dq,
   }
 }
 
+bool dnsdist_ffi_dnsquestion_set_restartable(dnsdist_ffi_dnsquestion_t* dq)
+{
+  if (dq == nullptr || dq->dq == nullptr) {
+    return false;
+  }
+
+  dq->dq->ids.d_packet = std::make_unique<PacketBuffer>(dq->dq->getData());
+  return true;
+}
+
 size_t dnsdist_ffi_servers_list_get_count(const dnsdist_ffi_servers_list_t* list)
 {
   return list->ffiServers.size();
diff --git a/regression-tests.dnsdist/test_RestartQuery.py b/regression-tests.dnsdist/test_RestartQuery.py
new file mode 100644 (file)
index 0000000..bd077fa
--- /dev/null
@@ -0,0 +1,88 @@
+#!/usr/bin/env python
+import threading
+import clientsubnetoption
+import dns
+from dnsdisttests import DNSDistTest
+
+def servFailResponseCallback(request):
+    response = dns.message.make_response(request)
+    response.set_rcode(dns.rcode.SERVFAIL)
+    return response.to_wire()
+
+def normalResponseCallback(request):
+    response = dns.message.make_response(request)
+    rrset = dns.rrset.from_text(request.question[0].name,
+                                3600,
+                                dns.rdataclass.IN,
+                                dns.rdatatype.A,
+                                '127.0.0.1')
+    response.answer.append(rrset)
+    return response.to_wire()
+
+class TestRestartQuery(DNSDistTest):
+
+    # this test suite uses different responder ports
+    _testNormalServerPort = 5420
+    _testServfailServerPort = 5421
+    _config_template = """
+    newServer{address="127.0.0.1:%d", pool='restarted'}:setUp()
+    newServer{address="127.0.0.1:%d", pool=''}:setUp()
+
+    function makeQueryRestartable(dq)
+      dq:setRestartable()
+      return DNSAction.None
+    end
+
+    function restartOnServFail(dr)
+      if dr.rcode == DNSRCode.SERVFAIL then
+        dr.pool = 'restarted'
+        dr:restart()
+      end
+
+      return DNSResponseAction.None
+    end
+
+    addAction(AllRule(), LuaAction(makeQueryRestartable))
+    addResponseAction(AllRule(), LuaResponseAction(restartOnServFail))
+    """
+    _config_params = ['_testNormalServerPort', '_testServfailServerPort']
+    _verboseMode = True
+
+    @classmethod
+    def startResponders(cls):
+        print("Launching responders..")
+
+        # servfail
+        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServfailServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, servFailResponseCallback])
+        cls._UDPResponder.setDaemon(True)
+        cls._UDPResponder.start()
+        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServfailServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, servFailResponseCallback])
+        cls._TCPResponder.setDaemon(True)
+        cls._TCPResponder.start()
+        cls._UDPResponderNormal = threading.Thread(name='UDP ResponderNormal', target=cls.UDPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, normalResponseCallback])
+        cls._UDPResponderNormal.setDaemon(True)
+        cls._UDPResponderNormal.start()
+        cls._TCPResponderNormal = threading.Thread(name='TCP ResponderNormal', target=cls.TCPResponder, args=[cls._testNormalServerPort, cls._toResponderQueue, cls._fromResponderQueue, False, False, normalResponseCallback])
+        cls._TCPResponderNormal.setDaemon(True)
+        cls._TCPResponderNormal.start()
+
+    def testRestartingQuery(self):
+        """
+        Restart: ServFail then restarted to a second pool
+        """
+        name = 'restart.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEquals(receivedResponse, expectedResponse)