]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add Lua bindings to access the DNS payload as a string 11606/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 5 May 2022 14:20:07 +0000 (16:20 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 5 May 2022 14:20:07 +0000 (16:20 +0200)
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdistdist/docs/reference/dq.rst
regression-tests.dnsdist/test_RulesActions.py

index 04e37de43a711e9eb5ad56220df3c021b3a49c00..a6b80078cd8049e6f876ce0d3b1d2a1a5d7e4d40 100644 (file)
@@ -55,7 +55,9 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
   luaCtx.registerFunction<bool(DNSQuestion::*)()const>("getDO", [](const DNSQuestion& dq) {
       return getEDNSZ(dq) & EDNS_HEADER_FLAG_DO;
     });
-
+  luaCtx.registerFunction<std::string(DNSQuestion::*)()const>("getContent", [](const DNSQuestion& dq) {
+    return std::string(reinterpret_cast<const char*>(dq.getData().data()), dq.getData().size());
+  });
   luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSQuestion::*)()const>("getEDNSOptions", [](const DNSQuestion& dq) {
       if (dq.ednsOptions == nullptr) {
         parseEDNSOptions(dq);
@@ -196,6 +198,9 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
   luaCtx.registerFunction<bool(DNSResponse::*)()const>("getDO", [](const DNSResponse& dq) {
       return getEDNSZ(dq) & EDNS_HEADER_FLAG_DO;
     });
+  luaCtx.registerFunction<std::string(DNSResponse::*)()const>("getContent", [](const DNSResponse& dq) {
+    return std::string(reinterpret_cast<const char*>(dq.getData().data()), dq.getData().size());
+  });
   luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSResponse::*)()const>("getEDNSOptions", [](const DNSResponse& dq) {
       if (dq.ednsOptions == nullptr) {
         parseEDNSOptions(dq);
index 7271f13974178b636a5b5a114d8d88fee7e59d71..002a95fc478d3a06504db68fec0d1f8494fe8643 100644 (file)
@@ -85,6 +85,12 @@ This state can be modified from the various hooks.
     :param int type: The type of the new value, ranging from 0 to 255 (both included)
     :param str value: The binary-safe value
 
+  .. method:: DNSQuestion:getContent() -> str
+
+    .. versionadded:: 1.8.0
+
+    Get the content of the DNS packet as a string
+
   .. method:: DNSQuestion:getDO() -> bool
 
     Get the value of the DNSSEC OK bit.
index 6468502f0479d4eea628053a9c50aa6b28e8d5b3..f709fe49d97c071e2f3ac3d97d2753fd128c93ea 100644 (file)
@@ -1564,3 +1564,53 @@ class TestAdvancedSetEDNSOptionAction(DNSDistTest):
             self.assertEqual(expectedQuery, receivedQuery)
             self.checkResponseNoEDNS(response, receivedResponse)
             self.checkQueryEDNS(expectedQuery, receivedQuery)
+
+class TestAdvancedLuaGetContent(DNSDistTest):
+
+    _config_template = """
+    function accessContentLua(dq)
+        local expectedSize = 57
+        local content = dq:getContent()
+        if content == nil or #content == 0 then
+            errlog('No content')
+            return DNSAction.Nxdomain, ""
+        end
+        if #content ~= expectedSize then
+            errlog('Invalid content size'..#content)
+            return DNSAction.Nxdomain, ""
+        end
+        -- the qname is right after the header, and we have only the qtype and qclass after that
+        local qname = string.sub(content, 13, -5)
+        local expectedQName = '\\011get-content\\008advanced\\005tests\\008powerdns\\003com\\000'
+        if qname ~= expectedQName then
+            errlog('Invalid qname '..qname..', expecting '..expectedQName)
+            return DNSAction.Nxdomain, ""
+        end
+        return DNSAction.None, ""
+    end
+    addAction(AllRule(), LuaAction(accessContentLua))
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testGetContentViaLua(self):
+        """
+        Advanced: Test getContent() via Lua
+        """
+        name = 'get-content.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'AAAA', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '::1')
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            self.assertTrue(receivedQuery)
+            self.assertTrue(receivedResponse)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(receivedResponse, response)