]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add a DNSHeader:getTC() Lua binding
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 14 Aug 2023 15:02:39 +0000 (17:02 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 14 Aug 2023 15:03:48 +0000 (17:03 +0200)
pdns/dnsdist-lua-bindings.cc
pdns/dnsdistdist/docs/reference/dq.rst
regression-tests.dnsdist/test_Lua.py

index 225f48b9556f90d7d397231d6eecf60d9a47828e..35a950590bc3d96511578dff8672f53bc1533d1e 100644 (file)
@@ -192,10 +192,14 @@ void setupLuaBindings(LuaContext& luaCtx, bool client)
       return (bool)dh.cd;
     });
 
-    luaCtx.registerFunction<uint16_t(dnsheader::*)()const>("getID", [](const dnsheader& dh) {
+  luaCtx.registerFunction<uint16_t(dnsheader::*)()const>("getID", [](const dnsheader& dh) {
       return ntohs(dh.id);
     });
 
+  luaCtx.registerFunction<bool(dnsheader::*)()const>("getTC", [](const dnsheader& dh) {
+      return (bool)dh.tc;
+    });
+
   luaCtx.registerFunction<void(dnsheader::*)(bool)>("setTC", [](dnsheader& dh, bool v) {
       dh.tc=v;
       if(v) dh.ra = dh.rd; // you'll always need this, otherwise TC=1 gets ignored
index 6025c64e166a2b168e3acd0bb0fed756ff1da9b7..f90d67ac1b9368b76b847fa4bd34e66519947022 100644 (file)
@@ -504,6 +504,12 @@ DNSHeader (``dh``) object
 
     Get recursion desired flag.
 
+  .. method:: DNSHeader:getTC() -> int
+
+    .. versionadded:: 1.8.1
+
+    Get the TC flag.
+
   .. method:: DNSHeader:setAA(aa)
 
     Set authoritative answer flag.
index 656c287ec9d1351bddaf39640430d87554ffe2ed..3d689477c220ba1fe21e3e6b02554eeab7e74935 100644 (file)
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 
 import base64
+import dns
 import time
 import unittest
 from dnsdisttests import DNSDistTest
@@ -40,3 +41,55 @@ class TestLuaThread(DNSDistTest):
         time.sleep(3)
         count2 = self.sendConsoleCommand('counter')
         self.assertTrue(count2 > count1)
+
+class TestLuaDNSHeaderBindings(DNSDistTest):
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+
+    function checkTCSet(dq)
+      local tc = dq.dh:getTC()
+      if not tc then
+        return DNSAction.Spoof, 'tc-not-set.check-tc.lua-dnsheaders.tests.powerdns.com.'
+      end
+      return DNSAction.Allow
+    end
+
+    addAction('check-tc.lua-dnsheaders.tests.powerdns.com.', LuaAction(checkTCSet))
+    """
+
+    def testLuaGetTC(self):
+        """
+        LuaDNSHeaders: TC
+        """
+        name = 'notset.check-tc.lua-dnsheaders.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    'tc-not-set.check-tc.lua-dnsheaders.tests.powerdns.com.')
+        response.answer.append(rrset)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertEqual(response, receivedResponse)
+
+        name = 'set.check-tc.lua-dnsheaders.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+        query.flags |= dns.flags.TC
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(response, receivedResponse)