]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add regular Lua bindings for async handling of queries
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 21 Dec 2022 13:13:10 +0000 (14:13 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 13 Jan 2023 15:57:50 +0000 (16:57 +0100)
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist-lua-bindings.cc
pdns/dnsdistdist/dnsdist-lua-ffi.cc

index 52b4b571526a810c41bac41d5eefbcaf4565dd46..51052fc24da65bf18215eb1fd81d7498d95ba323 100644 (file)
@@ -1114,3 +1114,41 @@ bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsDa
 
   return true;
 }
+
+namespace dnsdist {
+bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer,  uint8_t rcode, bool clearAnswers)
+{
+  const auto qnameLength = state.qname.wirelength();
+  if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) {
+    return false;
+  }
+
+  EDNS0Record edns0;
+  bool hadEDNS = false;
+  if (clearAnswers) {
+    hadEDNS = getEDNS0Record(buffer, edns0);
+  }
+
+  auto dh = reinterpret_cast<dnsheader*>(buffer.data());
+  dh->rcode = rcode;
+  dh->ad = false;
+  dh->aa = false;
+  dh->ra = dh->rd;
+  dh->qr = true;
+
+  if (clearAnswers) {
+    dh->ancount = 0;
+    dh->nscount = 0;
+    dh->arcount = 0;
+    buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
+    if (hadEDNS) {
+      DNSQuestion dq(state, buffer);
+      if (!addEDNS(buffer, dq.getMaximumSize(), edns0.extFlags & htons(EDNS_HEADER_FLAG_DO), g_PayloadSizeSelfGenAnswers, 0)) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+}
index f5dbc56c253ba99c8a2f5448c0736c125629aa18..653052df81965d3ae1db422ed4a245e828d82bf5 100644 (file)
@@ -57,3 +57,7 @@ bool queryHasEDNS(const DNSQuestion& dq);
 bool getEDNS0Record(const PacketBuffer& packet, EDNS0Record& edns0);
 
 bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& data);
+
+namespace dnsdist {
+bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer,  uint8_t rcode, bool clearAnswers);
+}
index 5e16d49de9322843dd5275c18b6f6a924578cb29..ba3b95ff8740d575577df0fd865653d335c6db3e 100644 (file)
@@ -20,6 +20,7 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 #include "dnsdist.hh"
+#include "dnsdist-async.hh"
 #include "dnsdist-ecs.hh"
 #include "dnsdist-lua.hh"
 #include "dnsparser.hh"
@@ -58,6 +59,13 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
   luaCtx.registerFunction<std::string(DNSQuestion::*)()const>("getContent", [](const DNSQuestion& dq) {
     return std::string(reinterpret_cast<const char*>(dq.getData().data()), dq.getData().size());
   });
+  luaCtx.registerFunction<void(DNSQuestion::*)(const std::string&)>("setContent", [](DNSQuestion& dq, const std::string& raw) {
+    uint16_t oldID = dq.getHeader()->id;
+    auto& buffer = dq.getMutableData();
+    buffer.clear();
+    buffer.insert(buffer.begin(), raw.begin(), raw.end());
+    reinterpret_cast<dnsheader*>(buffer.data())->id = oldID;
+  });
   luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSQuestion::*)()const>("getEDNSOptions", [](const DNSQuestion& dq) {
       if (dq.ednsOptions == nullptr) {
         parseEDNSOptions(dq);
@@ -188,6 +196,86 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     setEDNSOption(dq, code, data);
   });
 
+  luaCtx.registerFunction<bool(DNSQuestion::*)(uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)>("suspend", [](DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) {
+    dq.asynchronous = true;
+    return dnsdist::suspendQuery(dq, asyncID, queryID, timeoutMs);
+  });
+
+class AsynchronousObject
+{
+public:
+  AsynchronousObject(std::unique_ptr<CrossProtocolQuery>&& obj_): object(std::move(obj_))
+  {
+  }
+
+  DNSQuestion getDQ() const
+  {
+    return object->getDQ();
+  }
+
+  DNSResponse getDR() const
+  {
+    return object->getDR();
+  }
+
+  bool resume()
+  {
+    return dnsdist::queueQueryResumptionEvent(std::move(object));
+  }
+
+  bool drop()
+  {
+    auto sender = object->getTCPQuerySender();
+    if (!sender) {
+      return false;
+    }
+
+    struct timeval now;
+    gettimeofday(&now, nullptr);
+    sender->notifyIOError(std::move(object->query.d_idstate), now);
+    return true;
+  }
+
+  bool setRCode(uint8_t rcode, bool clearAnswers)
+  {
+    return dnsdist::setInternalQueryRCode(object->query.d_idstate, object->query.d_buffer, rcode, clearAnswers);
+  }
+
+private:
+  std::unique_ptr<CrossProtocolQuery> object;
+};
+
+  luaCtx.registerFunction<DNSQuestion(AsynchronousObject::*)(void) const>("getDQ", [](const AsynchronousObject& obj) {
+      return obj.getDQ();
+    });
+
+  luaCtx.registerFunction<DNSQuestion(AsynchronousObject::*)(void) const>("getDR", [](const AsynchronousObject& obj) {
+      return obj.getDR();
+    });
+
+  luaCtx.registerFunction<bool(AsynchronousObject::*)(void)>("resume", [](AsynchronousObject& obj) {
+      return obj.resume();
+    });
+
+  luaCtx.registerFunction<bool(AsynchronousObject::*)(void)>("drop", [](AsynchronousObject& obj) {
+      return obj.drop();
+    });
+
+  luaCtx.registerFunction<bool(AsynchronousObject::*)(uint8_t, bool)>("setRCode", [](AsynchronousObject& obj, uint8_t rcode, bool clearAnswers) {
+    return obj.setRCode(rcode, clearAnswers);
+  });
+
+  luaCtx.writeFunction("getAsynchronousObject", [](uint16_t asyncID, uint16_t queryID) -> AsynchronousObject {
+    if (!dnsdist::g_asyncHolder) {
+      throw std::runtime_error("Unable to resume, no asynchronous holder");
+    }
+    auto query = dnsdist::g_asyncHolder->get(asyncID, queryID);
+    if (!query) {
+      throw std::runtime_error("Unable to find asynchronous object");
+    }
+    return AsynchronousObject(std::move(query));
+  });
+
   /* LuaWrapper doesn't support inheritance */
   luaCtx.registerMember<const ComboAddress (DNSResponse::*)>("localaddr", [](const DNSResponse& dq) -> const ComboAddress { return dq.ids.origDest; }, [](DNSResponse& dq, const ComboAddress newLocal) { (void) newLocal; });
   luaCtx.registerMember<const DNSName (DNSResponse::*)>("qname", [](const DNSResponse& dq) -> const DNSName { return dq.ids.qname; }, [](DNSResponse& dq, const DNSName newName) { (void) newName; });
@@ -209,6 +297,14 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
   luaCtx.registerFunction<std::string(DNSResponse::*)()const>("getContent", [](const DNSResponse& dq) {
     return std::string(reinterpret_cast<const char*>(dq.getData().data()), dq.getData().size());
   });
+  luaCtx.registerFunction<void(DNSResponse::*)(const std::string&)>("setContent", [](DNSResponse& dr, const std::string& raw) {
+    uint16_t oldID = dr.getHeader()->id;
+    auto& buffer = dr.getMutableData();
+    buffer.clear();
+    buffer.insert(buffer.begin(), raw.begin(), raw.end());
+    reinterpret_cast<dnsheader*>(buffer.data())->id = oldID;
+  });
+
   luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSResponse::*)()const>("getEDNSOptions", [](const DNSResponse& dq) {
       if (dq.ednsOptions == nullptr) {
         parseEDNSOptions(dq);
@@ -325,5 +421,10 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
 
       return setNegativeAndAdditionalSOA(dq, nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), serial, refresh, retry, expire, minimum, false);
     });
+
+  luaCtx.registerFunction<bool(DNSResponse::*)(uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)>("suspend", [](DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) {
+    dr.asynchronous = true;
+    return dnsdist::suspendResponse(dr, asyncID, queryID, timeoutMs);
+  });
 #endif /* DISABLE_NON_FFI_DQ_BINDINGS */
 }
index 63c9be3ae2716eff887b315f81d2f60ab7ca0f1f..555bbc2d0a62d95f43bf56a4c50f0abb2bc7254f 100644 (file)
@@ -156,7 +156,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
       dh.rd=v;
     });
 
-  luaCtx.registerFunction<bool(dnsheader::*)()>("getRD", [](dnsheader& dh) {
+  luaCtx.registerFunction<bool(dnsheader::*)()const>("getRD", [](const dnsheader& dh) {
       return (bool)dh.rd;
     });
 
@@ -164,7 +164,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
       dh.ra=v;
     });
 
-  luaCtx.registerFunction<bool(dnsheader::*)()>("getRA", [](dnsheader& dh) {
+  luaCtx.registerFunction<bool(dnsheader::*)()const>("getRA", [](const dnsheader& dh) {
       return (bool)dh.ra;
     });
 
@@ -172,7 +172,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
       dh.ad=v;
     });
 
-  luaCtx.registerFunction<bool(dnsheader::*)()>("getAD", [](dnsheader& dh) {
+  luaCtx.registerFunction<bool(dnsheader::*)()const>("getAD", [](const dnsheader& dh) {
       return (bool)dh.ad;
     });
 
@@ -180,7 +180,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
       dh.aa=v;
     });
 
-  luaCtx.registerFunction<bool(dnsheader::*)()>("getAA", [](dnsheader& dh) {
+  luaCtx.registerFunction<bool(dnsheader::*)()const>("getAA", [](const dnsheader& dh) {
       return (bool)dh.aa;
     });
 
@@ -188,10 +188,14 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
       dh.cd=v;
     });
 
-  luaCtx.registerFunction<bool(dnsheader::*)()>("getCD", [](dnsheader& dh) {
+  luaCtx.registerFunction<bool(dnsheader::*)()const >("getCD", [](const dnsheader& dh) {
       return (bool)dh.cd;
     });
 
+    luaCtx.registerFunction<uint16_t(dnsheader::*)()const>("getID", [](const dnsheader& dh) {
+      return ntohs(dh.id);
+    });
+
   luaCtx.registerFunction<void(dnsheader::*)(bool)>("setTC", [](dnsheader& dh, bool v) {
       dh.tc=v;
       if(v) dh.ra = dh.rd; // you'll always need this, otherwise TC=1 gets ignored
index f204b0aca09d7063bf3c1a82ff4dcc531d003b4e..71dd9fb6250fdc767af292456565ca33f6dde92f 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "dnsdist-async.hh"
 #include "dnsdist-dnsparser.hh"
+#include "dnsdist-ecs.hh"
 #include "dnsdist-lua-ffi.hh"
 #include "dnsdist-mac-address.hh"
 #include "dnsdist-lua-network.hh"
@@ -694,8 +695,7 @@ bool dnsdist_ffi_dnsquestion_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t a
 {
   try {
     dq->dq->asynchronous = true;
-    dnsdist::suspendQuery(*dq->dq, asyncID, queryID, timeoutMs);
-    return true;
+    return dnsdist::suspendQuery(*dq->dq, asyncID, queryID, timeoutMs);
   }
   catch (const std::exception& e) {
     vinfolog("Error in dnsdist_ffi_dnsquestion_set_async: %s", e.what());
@@ -717,8 +717,7 @@ bool dnsdist_ffi_dnsresponse_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t a
       return false;
     }
 
-    dnsdist::suspendResponse(*dr, asyncID, queryID, timeoutMs);
-    return true;
+    return dnsdist::suspendResponse(*dr, asyncID, queryID, timeoutMs);
   }
   catch (const std::exception& e) {
     vinfolog("Error in dnsdist_ffi_dnsresponse_set_async: %s", e.what());
@@ -767,37 +766,10 @@ bool dnsdist_ffi_set_rcode_from_async(uint16_t asyncID, uint16_t queryID, uint8_
     return false;
   }
 
-  const auto qnameLength = query->query.d_idstate.qname.wirelength();
-  auto& buffer = query->query.d_buffer;
-  if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) {
+  if (!dnsdist::setInternalQueryRCode(query->query.d_idstate, query->query.d_buffer, rcode, clearAnswers)) {
     return false;
   }
 
-  EDNS0Record edns0;
-  bool hadEDNS = false;
-  if (clearAnswers) {
-    hadEDNS = getEDNS0Record(buffer, edns0);
-  }
-
-  auto dh = reinterpret_cast<dnsheader*>(buffer.data());
-  dh->rcode = rcode;
-  dh->ad = false;
-  dh->aa = false;
-  dh->ra = dh->rd;
-  dh->qr = true;
-
-  if (clearAnswers) {
-    dh->ancount = 0;
-    dh->nscount = 0;
-    dh->arcount = 0;
-    buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
-    if (hadEDNS) {
-      if (!addEDNS(buffer, query->query.d_idstate.protocol.isUDP() ? 4096 : std::numeric_limits<uint16_t>::max(), edns0.extFlags & htons(EDNS_HEADER_FLAG_DO), g_PayloadSizeSelfGenAnswers, 0)) {
-        return false;
-      }
-    }
-  }
-
   query->query.d_idstate.skipCache = true;
 
   return dnsdist::queueQueryResumptionEvent(std::move(query));