luaCtx.registerFunction<bool(DNSQuestion::*)()const>("getDO", [](const DNSQuestion& dq) {
return getEDNSZ(dq) & EDNS_HEADER_FLAG_DO;
});
-
+ luaCtx.registerFunction<std::string(DNSQuestion::*)()const>("getContent", [](const DNSQuestion& dq) {
+ return std::string(reinterpret_cast<const char*>(dq.getData().data()), dq.getData().size());
+ });
luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSQuestion::*)()const>("getEDNSOptions", [](const DNSQuestion& dq) {
if (dq.ednsOptions == nullptr) {
parseEDNSOptions(dq);
luaCtx.registerFunction<bool(DNSResponse::*)()const>("getDO", [](const DNSResponse& dq) {
return getEDNSZ(dq) & EDNS_HEADER_FLAG_DO;
});
+ luaCtx.registerFunction<std::string(DNSResponse::*)()const>("getContent", [](const DNSResponse& dq) {
+ return std::string(reinterpret_cast<const char*>(dq.getData().data()), dq.getData().size());
+ });
luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSResponse::*)()const>("getEDNSOptions", [](const DNSResponse& dq) {
if (dq.ednsOptions == nullptr) {
parseEDNSOptions(dq);
self.assertEqual(expectedQuery, receivedQuery)
self.checkResponseNoEDNS(response, receivedResponse)
self.checkQueryEDNS(expectedQuery, receivedQuery)
+
+class TestAdvancedLuaGetContent(DNSDistTest):
+
+ _config_template = """
+ function accessContentLua(dq)
+ local expectedSize = 57
+ local content = dq:getContent()
+ if content == nil or #content == 0 then
+ errlog('No content')
+ return DNSAction.Nxdomain, ""
+ end
+ if #content ~= expectedSize then
+ errlog('Invalid content size'..#content)
+ return DNSAction.Nxdomain, ""
+ end
+ -- the qname is right after the header, and we have only the qtype and qclass after that
+ local qname = string.sub(content, 13, -5)
+ local expectedQName = '\\011get-content\\008advanced\\005tests\\008powerdns\\003com\\000'
+ if qname ~= expectedQName then
+ errlog('Invalid qname '..qname..', expecting '..expectedQName)
+ return DNSAction.Nxdomain, ""
+ end
+ return DNSAction.None, ""
+ end
+ addAction(AllRule(), LuaAction(accessContentLua))
+ newServer{address="127.0.0.1:%s"}
+ """
+
+ def testGetContentViaLua(self):
+ """
+ Advanced: Test getContent() via Lua
+ """
+ name = 'get-content.advanced.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'AAAA', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.AAAA,
+ '::1')
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(receivedResponse, response)