]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Expose trailing data as a Lua string
authorRichard Gibson <richard.gibson@gmail.com>
Tue, 16 Oct 2018 21:39:28 +0000 (17:39 -0400)
committerRichard Gibson <richard.gibson@gmail.com>
Tue, 16 Oct 2018 21:46:27 +0000 (17:46 -0400)
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdistdist/docs/reference/dq.rst
regression-tests.dnsdist/test_Trailing.py

index 147efb7ee0de05658bbefe1cf990bbf9acad3d4d..2c7d77416959fa04ed225a84b7362aac651e2240 100644 (file)
@@ -63,24 +63,24 @@ void setupLuaBindingsDNSQuestion()
 
       return *dq.ednsOptions;
     });
-  g_lua.registerFunction<vector<uint8_t>(DNSQuestion::*)(void)>("getTrailingData", [](const DNSQuestion& dq) {
-      const uint8_t* message = reinterpret_cast<const uint8_t*>(dq.dh);
-      const uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(message), dq.len);
-      vector<uint8_t> tail(message + length, message + dq.len);
+  g_lua.registerFunction<std::string(DNSQuestion::*)(void)>("getTrailingData", [](const DNSQuestion& dq) {
+      const char* message = reinterpret_cast<const char*>(dq.dh);
+      const uint16_t messageLen = getDNSPacketLength(message, dq.len);
+      const std::string tail = std::string(message + messageLen, dq.len - messageLen);
       return tail;
     });
-  g_lua.registerFunction<bool(DNSQuestion::*)(vector<pair<int, uint8_t>>)>("setTrailingData", [](DNSQuestion& dq, const vector<pair<int, uint8_t>>&data) {
-      uint8_t* message = reinterpret_cast<uint8_t*>(dq.dh);
-      const uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(message), dq.len);
-      if(length + data.size() > dq.size) {
+  g_lua.registerFunction<bool(DNSQuestion::*)(std::string)>("setTrailingData", [](DNSQuestion& dq, const std::string& tail) {
+      char* message = reinterpret_cast<char*>(dq.dh);
+      const uint16_t messageLen = getDNSPacketLength(message, dq.len);
+      const uint16_t tailLen = tail.size();
+      if(messageLen + tailLen > dq.size) {
         return false;
       }
 
-      /* Copy data from the Lua array, whose first index is 1 instead of 0. */
-      dq.len = length + data.size();
-      uint8_t* tail = message + length - 1;
-      for(const auto& pair : data) {
-        *(tail + pair.first) = pair.second;
+      /* Update length and copy data from the Lua string. */
+      dq.len = messageLen + tailLen;
+      if(tailLen > 0) {
+        tail.copy(message + messageLen, tailLen);
       }
       return true;
     });
@@ -144,24 +144,24 @@ void setupLuaBindingsDNSQuestion()
   g_lua.registerFunction<void(DNSResponse::*)(std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc)>("editTTLs", [](const DNSResponse& dr, std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc) {
         editDNSPacketTTL((char*) dr.dh, dr.len, editFunc);
       });
-  g_lua.registerFunction<vector<uint8_t>(DNSResponse::*)(void)>("getTrailingData", [](const DNSResponse& dq) {
-      const uint8_t* message = reinterpret_cast<const uint8_t*>(dq.dh);
-      const uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(message), dq.len);
-      vector<uint8_t> tail(message + length, message + dq.len);
+  g_lua.registerFunction<std::string(DNSResponse::*)(void)>("getTrailingData", [](const DNSResponse& dq) {
+      const char* message = reinterpret_cast<const char*>(dq.dh);
+      const uint16_t messageLen = getDNSPacketLength(message, dq.len);
+      const std::string tail = std::string(message + messageLen, dq.len - messageLen);
       return tail;
     });
-  g_lua.registerFunction<bool(DNSResponse::*)(vector<pair<int, uint8_t>>)>("setTrailingData", [](DNSResponse& dq, const vector<pair<int, uint8_t>>&data) {
-      uint8_t* message = reinterpret_cast<uint8_t*>(dq.dh);
-      const uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(message), dq.len);
-      if(length + data.size() > dq.size) {
+  g_lua.registerFunction<bool(DNSResponse::*)(std::string)>("setTrailingData", [](DNSResponse& dq, const std::string& tail) {
+      char* message = reinterpret_cast<char*>(dq.dh);
+      const uint16_t messageLen = getDNSPacketLength(message, dq.len);
+      const uint16_t tailLen = tail.size();
+      if(messageLen + tailLen > dq.size) {
         return false;
       }
 
-      /* Copy data from the Lua array, whose first index is 1 instead of 0. */
-      dq.len = length + data.size();
-      uint8_t* tail = message + length - 1;
-      for(const auto& pair : data) {
-        *(tail + pair.first) = pair.second;
+      /* Update length and copy data from the Lua string. */
+      dq.len = messageLen + tailLen;
+      if(tailLen > 0) {
+        tail.copy(message + messageLen, tailLen);
       }
       return true;
     });
index c5f2b7dcf141aa1a5203a12acad2a0cd4b796534..7d6febaab3e6025d6464386e87ad2d5601a2fa3e 100644 (file)
@@ -109,13 +109,13 @@ This state can be modified from the various hooks.
 
     :returns: A table of tags, using strings as keys and values
 
-  .. method:: DNSQuestion:getTrailingData() -> table
+  .. method:: DNSQuestion:getTrailingData() -> string
 
     .. versionadded:: 1.4.0
 
     Get all data following the DNS message.
 
-    :returns: A list of 8-bit integers
+    :returns: The trailing data as a null-safe string
 
   .. method:: DNSQuestion:sendTrap(reason)
 
@@ -142,13 +142,13 @@ This state can be modified from the various hooks.
 
     :param table tags: A table of tags, using strings as keys and values
 
-  .. method:: DNSQuestion:setTrailingData(bytes) -> bool
+  .. method:: DNSQuestion:setTrailingData(tail) -> bool
 
     .. versionadded:: 1.4.0
 
     Set the data following the DNS message, overwriting anything already present.
 
-    :param table bytes: The new data as a list of 8-bit integers
+    :param string tail: The new data
     :returns: true if the operation succeeded, false otherwise
 
 .. _DNSResponse:
index 33d23b1e9e7372da260ebc4c9663732372599dc6..3803e3797d76afe5644062e55f49db2211201957 100644 (file)
@@ -14,7 +14,7 @@ class TestTrailingDataToBackend(DNSDistTest):
     newServer{address="127.0.0.1:%s"}
 
     function replaceTrailingData(dq)
-        local success = dq:setTrailingData({65, 66, 67}) -- "ABC"
+        local success = dq:setTrailingData("ABC")
         if not success then
             return DNSAction.ServFail, ""
         end
@@ -24,7 +24,7 @@ class TestTrailingDataToBackend(DNSDistTest):
 
     function fillBuffer(dq)
         local available = dq.size - dq.len
-        local tail = extendTableBy({}, available)
+        local tail = string.rep("A", available)
         local success = dq:setTrailingData(tail)
         if not success then
             return DNSAction.ServFail, ""
@@ -35,7 +35,7 @@ class TestTrailingDataToBackend(DNSDistTest):
 
     function exceedBuffer(dq)
         local available = dq.size - dq.len
-        local tail = extendTableBy({}, available + 1)
+        local tail = string.rep("A", available + 1)
         local success = dq:setTrailingData(tail)
         if not success then
             return DNSAction.ServFail, ""
@@ -43,21 +43,6 @@ class TestTrailingDataToBackend(DNSDistTest):
         return DNSAction.None, ""
     end
     addLuaAction("limited.trailing.tests.powerdns.com.", exceedBuffer)
-
-    function extendTableBy(t, n)
-        if n <= 1 then
-            if n == 1 then
-                t[#t + 1] = 0
-            end
-            return t
-        end
-
-        local lower = math.floor(n / 2)
-        local upper = n - lower
-        t = extendTableBy(t, lower)
-        t = extendTableBy(t, upper)
-        return t
-    end
     """
     @classmethod
     def startResponders(cls):
@@ -192,7 +177,7 @@ class TestTrailingDataToDnsdist(DNSDistTest):
     addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())
 
     function removeTrailingData(dq)
-        local success = dq:setTrailingData({})
+        local success = dq:setTrailingData("")
         if not success then
             return DNSAction.ServFail, ""
         end
@@ -201,14 +186,13 @@ class TestTrailingDataToDnsdist(DNSDistTest):
     addLuaAction("removed.trailing.tests.powerdns.com.", removeTrailingData)
 
     function reportTrailingData(dq)
-        local tailBytes = dq:getTrailingData()
-        local tailChars = string.char(unpack(tailBytes))
-        return DNSAction.Spoof, "-" .. tailChars .. ".echoed.trailing.tests.powerdns.com."
+        local tail = dq:getTrailingData()
+        return DNSAction.Spoof, "-" .. tail .. ".echoed.trailing.tests.powerdns.com."
     end
     addLuaAction("echoed.trailing.tests.powerdns.com.", reportTrailingData)
 
     function replaceTrailingData(dq)
-        local success = dq:setTrailingData({65, 66, 67}) -- "ABC"
+        local success = dq:setTrailingData("ABC")
         if not success then
             return DNSAction.ServFail, ""
         end