From b82f0bc2772f97bb0f01582ca4a1b3cb23cd5aaa Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Wed, 5 Oct 2016 12:52:07 +0200 Subject: [PATCH] dnsdist: Add `DNSQuestion:getDO()` --- pdns/README-dnsdist.md | 1 + pdns/dnsdist-lua.cc | 3 ++ regression-tests.dnsdist/test_Advanced.py | 58 +++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/pdns/README-dnsdist.md b/pdns/README-dnsdist.md index 12f2884183..8e2ef60e6c 100644 --- a/pdns/README-dnsdist.md +++ b/pdns/README-dnsdist.md @@ -1452,6 +1452,7 @@ instantiate a server with additional parameters * member `dh`: DNSHeader * member `ecsOverride`: whether an existing ECS value should be overriden (settable) * member `ecsPrefixLength`: the ECS prefix length to use (settable) + * member `getDO()`: return true if the DNSSEC OK (DO) bit is set * member `len`: the question length * member `localaddr`: ComboAddress of the local bind this question was received on * member `opcode`: the question opcode diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index ecfc53e800..f99bf45433 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -1484,6 +1484,9 @@ vector> setupLua(bool client, const std::string& confi g_lua.registerMember("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; }); g_lua.registerMember("ecsOverride", [](const DNSQuestion& dq) -> bool { return dq.ecsOverride; }, [](DNSQuestion& dq, bool ecsOverride) { dq.ecsOverride = ecsOverride; }); g_lua.registerMember("ecsPrefixLength", [](const DNSQuestion& dq) -> uint16_t { return dq.ecsPrefixLength; }, [](DNSQuestion& dq, uint16_t newPrefixLength) { dq.ecsPrefixLength = newPrefixLength; }); + g_lua.registerFunction("getDO", [](const DNSQuestion& dq) { + return getEDNSZ((const char*)dq.dh, dq.len) & EDNS_HEADER_FLAG_DO; + }); /* LuaWrapper doesn't support inheritance */ g_lua.registerMember("localaddr", [](const DNSResponse& dq) -> const ComboAddress { return *dq.local; }, [](DNSResponse& dq, const ComboAddress newLocal) { (void) newLocal; }); diff --git a/regression-tests.dnsdist/test_Advanced.py b/regression-tests.dnsdist/test_Advanced.py index 8f4b77906c..c0ae43d81d 100644 --- a/regression-tests.dnsdist/test_Advanced.py +++ b/regression-tests.dnsdist/test_Advanced.py @@ -1185,3 +1185,61 @@ class TestAdvancedIncludeDir(DNSDistTest): (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False) self.assertEquals(receivedResponse, expectedResponse) + +class TestAdvancedLuaDO(DNSDistTest): + + _config_template = """ + function nxDOLua(dq) + if dq:getDO() then + return DNSAction.Nxdomain, "" + end + return DNSAction.None, "" + end + addLuaAction(AllRule(), nxDOLua) + newServer{address="127.0.0.1:%s"} + """ + + def testNxDOViaLua(self): + """ + Advanced: Nx DO queries via Lua + """ + name = 'nxdo.advanced.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.AAAA, + '::1') + response.answer.append(rrset) + queryWithDO = dns.message.make_query(name, 'A', 'IN', want_dnssec=True) + doResponse = dns.message.make_response(queryWithDO) + doResponse.set_rcode(dns.rcode.NXDOMAIN) + + # without DO + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(receivedResponse, response) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(receivedResponse, response) + + # with DO + (_, receivedResponse) = self.sendUDPQuery(queryWithDO, response=None, useQueue=False) + self.assertTrue(receivedResponse) + doResponse.id = receivedResponse.id + print(doResponse) + print(receivedResponse) + self.assertEquals(receivedResponse, doResponse) + + (_, receivedResponse) = self.sendTCPQuery(queryWithDO, response=None, useQueue=False) + self.assertTrue(receivedResponse) + doResponse.id = receivedResponse.id + self.assertEquals(receivedResponse, doResponse) -- 2.47.2