From: Remi Gacogne Date: Mon, 14 Aug 2023 15:02:39 +0000 (+0200) Subject: dnsdist: Add a DNSHeader:getTC() Lua binding X-Git-Tag: rec-5.0.0-alpha1~57^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2be955a17cd79d92dd5ab3975481c4a53c6369a1;p=thirdparty%2Fpdns.git dnsdist: Add a DNSHeader:getTC() Lua binding --- diff --git a/pdns/dnsdist-lua-bindings.cc b/pdns/dnsdist-lua-bindings.cc index 225f48b955..35a950590b 100644 --- a/pdns/dnsdist-lua-bindings.cc +++ b/pdns/dnsdist-lua-bindings.cc @@ -192,10 +192,14 @@ void setupLuaBindings(LuaContext& luaCtx, bool client) return (bool)dh.cd; }); - luaCtx.registerFunction("getID", [](const dnsheader& dh) { + luaCtx.registerFunction("getID", [](const dnsheader& dh) { return ntohs(dh.id); }); + luaCtx.registerFunction("getTC", [](const dnsheader& dh) { + return (bool)dh.tc; + }); + luaCtx.registerFunction("setTC", [](dnsheader& dh, bool v) { dh.tc=v; if(v) dh.ra = dh.rd; // you'll always need this, otherwise TC=1 gets ignored diff --git a/pdns/dnsdistdist/docs/reference/dq.rst b/pdns/dnsdistdist/docs/reference/dq.rst index 6025c64e16..f90d67ac1b 100644 --- a/pdns/dnsdistdist/docs/reference/dq.rst +++ b/pdns/dnsdistdist/docs/reference/dq.rst @@ -504,6 +504,12 @@ DNSHeader (``dh``) object Get recursion desired flag. + .. method:: DNSHeader:getTC() -> int + + .. versionadded:: 1.8.1 + + Get the TC flag. + .. method:: DNSHeader:setAA(aa) Set authoritative answer flag. diff --git a/regression-tests.dnsdist/test_Lua.py b/regression-tests.dnsdist/test_Lua.py index 656c287ec9..3d689477c2 100644 --- a/regression-tests.dnsdist/test_Lua.py +++ b/regression-tests.dnsdist/test_Lua.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import base64 +import dns import time import unittest from dnsdisttests import DNSDistTest @@ -40,3 +41,55 @@ class TestLuaThread(DNSDistTest): time.sleep(3) count2 = self.sendConsoleCommand('counter') self.assertTrue(count2 > count1) + +class TestLuaDNSHeaderBindings(DNSDistTest): + _config_template = """ + newServer{address="127.0.0.1:%s"} + + function checkTCSet(dq) + local tc = dq.dh:getTC() + if not tc then + return DNSAction.Spoof, 'tc-not-set.check-tc.lua-dnsheaders.tests.powerdns.com.' + end + return DNSAction.Allow + end + + addAction('check-tc.lua-dnsheaders.tests.powerdns.com.', LuaAction(checkTCSet)) + """ + + def testLuaGetTC(self): + """ + LuaDNSHeaders: TC + """ + name = 'notset.check-tc.lua-dnsheaders.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.CNAME, + 'tc-not-set.check-tc.lua-dnsheaders.tests.powerdns.com.') + response.answer.append(rrset) + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertEqual(response, receivedResponse) + + name = 'set.check-tc.lua-dnsheaders.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + query.flags |= dns.flags.TC + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (receivedQuery, receivedResponse) = sender(query, response) + receivedQuery.id = query.id + self.assertEqual(query, receivedQuery) + self.assertEqual(response, receivedResponse)