g_lua.registerMember<bool (DNSQuestion::*)>("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; });
g_lua.registerMember<bool (DNSQuestion::*)>("ecsOverride", [](const DNSQuestion& dq) -> bool { return dq.ecsOverride; }, [](DNSQuestion& dq, bool ecsOverride) { dq.ecsOverride = ecsOverride; });
g_lua.registerMember<uint16_t (DNSQuestion::*)>("ecsPrefixLength", [](const DNSQuestion& dq) -> uint16_t { return dq.ecsPrefixLength; }, [](DNSQuestion& dq, uint16_t newPrefixLength) { dq.ecsPrefixLength = newPrefixLength; });
+ g_lua.registerFunction<bool(DNSQuestion::*)()>("getDO", [](const DNSQuestion& dq) {
+ return getEDNSZ((const char*)dq.dh, dq.len) & EDNS_HEADER_FLAG_DO;
+ });
/* LuaWrapper doesn't support inheritance */
g_lua.registerMember<const ComboAddress (DNSResponse::*)>("localaddr", [](const DNSResponse& dq) -> const ComboAddress { return *dq.local; }, [](DNSResponse& dq, const ComboAddress newLocal) { (void) newLocal; });
(_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
self.assertEquals(receivedResponse, expectedResponse)
+
+class TestAdvancedLuaDO(DNSDistTest):
+
+ _config_template = """
+ function nxDOLua(dq)
+ if dq:getDO() then
+ return DNSAction.Nxdomain, ""
+ end
+ return DNSAction.None, ""
+ end
+ addLuaAction(AllRule(), nxDOLua)
+ newServer{address="127.0.0.1:%s"}
+ """
+
+ def testNxDOViaLua(self):
+ """
+ Advanced: Nx DO queries via Lua
+ """
+ name = 'nxdo.advanced.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.AAAA,
+ '::1')
+ response.answer.append(rrset)
+ queryWithDO = dns.message.make_query(name, 'A', 'IN', want_dnssec=True)
+ doResponse = dns.message.make_response(queryWithDO)
+ doResponse.set_rcode(dns.rcode.NXDOMAIN)
+
+ # without DO
+ (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+
+ (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+ self.assertTrue(receivedQuery)
+ self.assertTrue(receivedResponse)
+ receivedQuery.id = query.id
+ self.assertEquals(query, receivedQuery)
+ self.assertEquals(receivedResponse, response)
+
+ # with DO
+ (_, receivedResponse) = self.sendUDPQuery(queryWithDO, response=None, useQueue=False)
+ self.assertTrue(receivedResponse)
+ doResponse.id = receivedResponse.id
+ print(doResponse)
+ print(receivedResponse)
+ self.assertEquals(receivedResponse, doResponse)
+
+ (_, receivedResponse) = self.sendTCPQuery(queryWithDO, response=None, useQueue=False)
+ self.assertTrue(receivedResponse)
+ doResponse.id = receivedResponse.id
+ self.assertEquals(receivedResponse, doResponse)