From: Remi Gacogne Date: Thu, 5 May 2022 14:20:07 +0000 (+0200) Subject: dnsdist: Add Lua bindings to access the DNS payload as a string X-Git-Tag: auth-4.8.0-alpha0~97^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F11606%2Fhead;p=thirdparty%2Fpdns.git dnsdist: Add Lua bindings to access the DNS payload as a string --- diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index 04e37de43a..a6b80078cd 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -55,7 +55,9 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) luaCtx.registerFunction("getDO", [](const DNSQuestion& dq) { return getEDNSZ(dq) & EDNS_HEADER_FLAG_DO; }); - + luaCtx.registerFunction("getContent", [](const DNSQuestion& dq) { + return std::string(reinterpret_cast(dq.getData().data()), dq.getData().size()); + }); luaCtx.registerFunction(DNSQuestion::*)()const>("getEDNSOptions", [](const DNSQuestion& dq) { if (dq.ednsOptions == nullptr) { parseEDNSOptions(dq); @@ -196,6 +198,9 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) luaCtx.registerFunction("getDO", [](const DNSResponse& dq) { return getEDNSZ(dq) & EDNS_HEADER_FLAG_DO; }); + luaCtx.registerFunction("getContent", [](const DNSResponse& dq) { + return std::string(reinterpret_cast(dq.getData().data()), dq.getData().size()); + }); luaCtx.registerFunction(DNSResponse::*)()const>("getEDNSOptions", [](const DNSResponse& dq) { if (dq.ednsOptions == nullptr) { parseEDNSOptions(dq); diff --git a/pdns/dnsdistdist/docs/reference/dq.rst b/pdns/dnsdistdist/docs/reference/dq.rst index 7271f13974..002a95fc47 100644 --- a/pdns/dnsdistdist/docs/reference/dq.rst +++ b/pdns/dnsdistdist/docs/reference/dq.rst @@ -85,6 +85,12 @@ This state can be modified from the various hooks. :param int type: The type of the new value, ranging from 0 to 255 (both included) :param str value: The binary-safe value + .. method:: DNSQuestion:getContent() -> str + + .. versionadded:: 1.8.0 + + Get the content of the DNS packet as a string + .. method:: DNSQuestion:getDO() -> bool Get the value of the DNSSEC OK bit. diff --git a/regression-tests.dnsdist/test_RulesActions.py b/regression-tests.dnsdist/test_RulesActions.py index 6468502f04..f709fe49d9 100644 --- a/regression-tests.dnsdist/test_RulesActions.py +++ b/regression-tests.dnsdist/test_RulesActions.py @@ -1564,3 +1564,53 @@ class TestAdvancedSetEDNSOptionAction(DNSDistTest): self.assertEqual(expectedQuery, receivedQuery) self.checkResponseNoEDNS(response, receivedResponse) self.checkQueryEDNS(expectedQuery, receivedQuery) + +class TestAdvancedLuaGetContent(DNSDistTest): + + _config_template = """ + function accessContentLua(dq) + local expectedSize = 57 + local content = dq:getContent() + if content == nil or #content == 0 then + errlog('No content') + return DNSAction.Nxdomain, "" + end + if #content ~= expectedSize then + errlog('Invalid content size'..#content) + return DNSAction.Nxdomain, "" + end + -- the qname is right after the header, and we have only the qtype and qclass after that + local qname = string.sub(content, 13, -5) + local expectedQName = '\\011get-content\\008advanced\\005tests\\008powerdns\\003com\\000' + if qname ~= expectedQName then + errlog('Invalid qname '..qname..', expecting '..expectedQName) + return DNSAction.Nxdomain, "" + end + return DNSAction.None, "" + end + addAction(AllRule(), LuaAction(accessContentLua)) + newServer{address="127.0.0.1:%s"} + """ + + def testGetContentViaLua(self): + """ + Advanced: Test getContent() via Lua + """ + name = 'get-content.advanced.tests.powerdns.com.' + query = dns.message.make_query(name, 'AAAA', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.AAAA, + '::1') + response.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(receivedResponse, response)