From: Remi Gacogne Date: Wed, 21 Dec 2022 13:13:10 +0000 (+0100) Subject: dnsdist: Add regular Lua bindings for async handling of queries X-Git-Tag: dnsdist-1.8.0-rc1~86^2~14 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=27f38b46fa11a81fd88af8f536cf10fc212a8755;p=thirdparty%2Fpdns.git dnsdist: Add regular Lua bindings for async handling of queries --- diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index 52b4b57152..51052fc24d 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -1114,3 +1114,41 @@ bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsDa return true; } + +namespace dnsdist { +bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uint8_t rcode, bool clearAnswers) +{ + const auto qnameLength = state.qname.wirelength(); + if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) { + return false; + } + + EDNS0Record edns0; + bool hadEDNS = false; + if (clearAnswers) { + hadEDNS = getEDNS0Record(buffer, edns0); + } + + auto dh = reinterpret_cast(buffer.data()); + dh->rcode = rcode; + dh->ad = false; + dh->aa = false; + dh->ra = dh->rd; + dh->qr = true; + + if (clearAnswers) { + dh->ancount = 0; + dh->nscount = 0; + dh->arcount = 0; + buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)); + if (hadEDNS) { + DNSQuestion dq(state, buffer); + if (!addEDNS(buffer, dq.getMaximumSize(), edns0.extFlags & htons(EDNS_HEADER_FLAG_DO), g_PayloadSizeSelfGenAnswers, 0)) { + return false; + } + } + } + + return true; +} +} diff --git a/pdns/dnsdist-ecs.hh b/pdns/dnsdist-ecs.hh index f5dbc56c25..653052df81 100644 --- a/pdns/dnsdist-ecs.hh +++ b/pdns/dnsdist-ecs.hh @@ -57,3 +57,7 @@ bool queryHasEDNS(const DNSQuestion& dq); bool getEDNS0Record(const PacketBuffer& packet, EDNS0Record& edns0); bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& data); + +namespace dnsdist { +bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uint8_t rcode, bool clearAnswers); +} diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index 5e16d49de9..ba3b95ff87 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -20,6 +20,7 @@ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ #include "dnsdist.hh" +#include "dnsdist-async.hh" #include "dnsdist-ecs.hh" #include "dnsdist-lua.hh" #include "dnsparser.hh" @@ -58,6 +59,13 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) luaCtx.registerFunction("getContent", [](const DNSQuestion& dq) { return std::string(reinterpret_cast(dq.getData().data()), dq.getData().size()); }); + luaCtx.registerFunction("setContent", [](DNSQuestion& dq, const std::string& raw) { + uint16_t oldID = dq.getHeader()->id; + auto& buffer = dq.getMutableData(); + buffer.clear(); + buffer.insert(buffer.begin(), raw.begin(), raw.end()); + reinterpret_cast(buffer.data())->id = oldID; + }); luaCtx.registerFunction(DNSQuestion::*)()const>("getEDNSOptions", [](const DNSQuestion& dq) { if (dq.ednsOptions == nullptr) { parseEDNSOptions(dq); @@ -188,6 +196,86 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) setEDNSOption(dq, code, data); }); + luaCtx.registerFunction("suspend", [](DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) { + dq.asynchronous = true; + return dnsdist::suspendQuery(dq, asyncID, queryID, timeoutMs); + }); + +class AsynchronousObject +{ +public: + AsynchronousObject(std::unique_ptr&& obj_): object(std::move(obj_)) + { + } + + DNSQuestion getDQ() const + { + return object->getDQ(); + } + + DNSResponse getDR() const + { + return object->getDR(); + } + + bool resume() + { + return dnsdist::queueQueryResumptionEvent(std::move(object)); + } + + bool drop() + { + auto sender = object->getTCPQuerySender(); + if (!sender) { + return false; + } + + struct timeval now; + gettimeofday(&now, nullptr); + sender->notifyIOError(std::move(object->query.d_idstate), now); + return true; + } + + bool setRCode(uint8_t rcode, bool clearAnswers) + { + return dnsdist::setInternalQueryRCode(object->query.d_idstate, object->query.d_buffer, rcode, clearAnswers); + } + +private: + std::unique_ptr object; +}; + + luaCtx.registerFunction("getDQ", [](const AsynchronousObject& obj) { + return obj.getDQ(); + }); + + luaCtx.registerFunction("getDR", [](const AsynchronousObject& obj) { + return obj.getDR(); + }); + + luaCtx.registerFunction("resume", [](AsynchronousObject& obj) { + return obj.resume(); + }); + + luaCtx.registerFunction("drop", [](AsynchronousObject& obj) { + return obj.drop(); + }); + + luaCtx.registerFunction("setRCode", [](AsynchronousObject& obj, uint8_t rcode, bool clearAnswers) { + return obj.setRCode(rcode, clearAnswers); + }); + + luaCtx.writeFunction("getAsynchronousObject", [](uint16_t asyncID, uint16_t queryID) -> AsynchronousObject { + if (!dnsdist::g_asyncHolder) { + throw std::runtime_error("Unable to resume, no asynchronous holder"); + } + auto query = dnsdist::g_asyncHolder->get(asyncID, queryID); + if (!query) { + throw std::runtime_error("Unable to find asynchronous object"); + } + return AsynchronousObject(std::move(query)); + }); + /* LuaWrapper doesn't support inheritance */ luaCtx.registerMember("localaddr", [](const DNSResponse& dq) -> const ComboAddress { return dq.ids.origDest; }, [](DNSResponse& dq, const ComboAddress newLocal) { (void) newLocal; }); luaCtx.registerMember("qname", [](const DNSResponse& dq) -> const DNSName { return dq.ids.qname; }, [](DNSResponse& dq, const DNSName newName) { (void) newName; }); @@ -209,6 +297,14 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) luaCtx.registerFunction("getContent", [](const DNSResponse& dq) { return std::string(reinterpret_cast(dq.getData().data()), dq.getData().size()); }); + luaCtx.registerFunction("setContent", [](DNSResponse& dr, const std::string& raw) { + uint16_t oldID = dr.getHeader()->id; + auto& buffer = dr.getMutableData(); + buffer.clear(); + buffer.insert(buffer.begin(), raw.begin(), raw.end()); + reinterpret_cast(buffer.data())->id = oldID; + }); + luaCtx.registerFunction(DNSResponse::*)()const>("getEDNSOptions", [](const DNSResponse& dq) { if (dq.ednsOptions == nullptr) { parseEDNSOptions(dq); @@ -325,5 +421,10 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) return setNegativeAndAdditionalSOA(dq, nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), serial, refresh, retry, expire, minimum, false); }); + + luaCtx.registerFunction("suspend", [](DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) { + dr.asynchronous = true; + return dnsdist::suspendResponse(dr, asyncID, queryID, timeoutMs); + }); #endif /* DISABLE_NON_FFI_DQ_BINDINGS */ } diff --git a/pdns/dnsdist-lua-bindings.cc b/pdns/dnsdist-lua-bindings.cc index 63c9be3ae2..555bbc2d0a 100644 --- a/pdns/dnsdist-lua-bindings.cc +++ b/pdns/dnsdist-lua-bindings.cc @@ -156,7 +156,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client) dh.rd=v; }); - luaCtx.registerFunction("getRD", [](dnsheader& dh) { + luaCtx.registerFunction("getRD", [](const dnsheader& dh) { return (bool)dh.rd; }); @@ -164,7 +164,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client) dh.ra=v; }); - luaCtx.registerFunction("getRA", [](dnsheader& dh) { + luaCtx.registerFunction("getRA", [](const dnsheader& dh) { return (bool)dh.ra; }); @@ -172,7 +172,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client) dh.ad=v; }); - luaCtx.registerFunction("getAD", [](dnsheader& dh) { + luaCtx.registerFunction("getAD", [](const dnsheader& dh) { return (bool)dh.ad; }); @@ -180,7 +180,7 @@ void setupLuaBindings(LuaContext& luaCtx, bool client) dh.aa=v; }); - luaCtx.registerFunction("getAA", [](dnsheader& dh) { + luaCtx.registerFunction("getAA", [](const dnsheader& dh) { return (bool)dh.aa; }); @@ -188,10 +188,14 @@ void setupLuaBindings(LuaContext& luaCtx, bool client) dh.cd=v; }); - luaCtx.registerFunction("getCD", [](dnsheader& dh) { + luaCtx.registerFunction("getCD", [](const dnsheader& dh) { return (bool)dh.cd; }); + luaCtx.registerFunction("getID", [](const dnsheader& dh) { + return ntohs(dh.id); + }); + luaCtx.registerFunction("setTC", [](dnsheader& dh, bool v) { dh.tc=v; if(v) dh.ra = dh.rd; // you'll always need this, otherwise TC=1 gets ignored diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index f204b0aca0..71dd9fb625 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -22,6 +22,7 @@ #include "dnsdist-async.hh" #include "dnsdist-dnsparser.hh" +#include "dnsdist-ecs.hh" #include "dnsdist-lua-ffi.hh" #include "dnsdist-mac-address.hh" #include "dnsdist-lua-network.hh" @@ -694,8 +695,7 @@ bool dnsdist_ffi_dnsquestion_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t a { try { dq->dq->asynchronous = true; - dnsdist::suspendQuery(*dq->dq, asyncID, queryID, timeoutMs); - return true; + return dnsdist::suspendQuery(*dq->dq, asyncID, queryID, timeoutMs); } catch (const std::exception& e) { vinfolog("Error in dnsdist_ffi_dnsquestion_set_async: %s", e.what()); @@ -717,8 +717,7 @@ bool dnsdist_ffi_dnsresponse_set_async(dnsdist_ffi_dnsquestion_t* dq, uint16_t a return false; } - dnsdist::suspendResponse(*dr, asyncID, queryID, timeoutMs); - return true; + return dnsdist::suspendResponse(*dr, asyncID, queryID, timeoutMs); } catch (const std::exception& e) { vinfolog("Error in dnsdist_ffi_dnsresponse_set_async: %s", e.what()); @@ -767,37 +766,10 @@ bool dnsdist_ffi_set_rcode_from_async(uint16_t asyncID, uint16_t queryID, uint8_ return false; } - const auto qnameLength = query->query.d_idstate.qname.wirelength(); - auto& buffer = query->query.d_buffer; - if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) { + if (!dnsdist::setInternalQueryRCode(query->query.d_idstate, query->query.d_buffer, rcode, clearAnswers)) { return false; } - EDNS0Record edns0; - bool hadEDNS = false; - if (clearAnswers) { - hadEDNS = getEDNS0Record(buffer, edns0); - } - - auto dh = reinterpret_cast(buffer.data()); - dh->rcode = rcode; - dh->ad = false; - dh->aa = false; - dh->ra = dh->rd; - dh->qr = true; - - if (clearAnswers) { - dh->ancount = 0; - dh->nscount = 0; - dh->arcount = 0; - buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)); - if (hadEDNS) { - if (!addEDNS(buffer, query->query.d_idstate.protocol.isUDP() ? 4096 : std::numeric_limits::max(), edns0.extFlags & htons(EDNS_HEADER_FLAG_DO), g_PayloadSizeSelfGenAnswers, 0)) { - return false; - } - } - } - query->query.d_idstate.skipCache = true; return dnsdist::queueQueryResumptionEvent(std::move(query));