From: Richard Gibson Date: Tue, 16 Oct 2018 21:39:28 +0000 (-0400) Subject: dnsdist: Expose trailing data as a Lua string X-Git-Tag: rec-4.2.0-alpha1~16^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7d243a5e15b3d5d38f36d1883a4d4520c6d6d5c2;p=thirdparty%2Fpdns.git dnsdist: Expose trailing data as a Lua string --- diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index 147efb7ee0..2c7d774169 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -63,24 +63,24 @@ void setupLuaBindingsDNSQuestion() return *dq.ednsOptions; }); - g_lua.registerFunction(DNSQuestion::*)(void)>("getTrailingData", [](const DNSQuestion& dq) { - const uint8_t* message = reinterpret_cast(dq.dh); - const uint16_t length = getDNSPacketLength(reinterpret_cast(message), dq.len); - vector tail(message + length, message + dq.len); + g_lua.registerFunction("getTrailingData", [](const DNSQuestion& dq) { + const char* message = reinterpret_cast(dq.dh); + const uint16_t messageLen = getDNSPacketLength(message, dq.len); + const std::string tail = std::string(message + messageLen, dq.len - messageLen); return tail; }); - g_lua.registerFunction>)>("setTrailingData", [](DNSQuestion& dq, const vector>&data) { - uint8_t* message = reinterpret_cast(dq.dh); - const uint16_t length = getDNSPacketLength(reinterpret_cast(message), dq.len); - if(length + data.size() > dq.size) { + g_lua.registerFunction("setTrailingData", [](DNSQuestion& dq, const std::string& tail) { + char* message = reinterpret_cast(dq.dh); + const uint16_t messageLen = getDNSPacketLength(message, dq.len); + const uint16_t tailLen = tail.size(); + if(messageLen + tailLen > dq.size) { return false; } - /* Copy data from the Lua array, whose first index is 1 instead of 0. */ - dq.len = length + data.size(); - uint8_t* tail = message + length - 1; - for(const auto& pair : data) { - *(tail + pair.first) = pair.second; + /* Update length and copy data from the Lua string. */ + dq.len = messageLen + tailLen; + if(tailLen > 0) { + tail.copy(message + messageLen, tailLen); } return true; }); @@ -144,24 +144,24 @@ void setupLuaBindingsDNSQuestion() g_lua.registerFunction editFunc)>("editTTLs", [](const DNSResponse& dr, std::function editFunc) { editDNSPacketTTL((char*) dr.dh, dr.len, editFunc); }); - g_lua.registerFunction(DNSResponse::*)(void)>("getTrailingData", [](const DNSResponse& dq) { - const uint8_t* message = reinterpret_cast(dq.dh); - const uint16_t length = getDNSPacketLength(reinterpret_cast(message), dq.len); - vector tail(message + length, message + dq.len); + g_lua.registerFunction("getTrailingData", [](const DNSResponse& dq) { + const char* message = reinterpret_cast(dq.dh); + const uint16_t messageLen = getDNSPacketLength(message, dq.len); + const std::string tail = std::string(message + messageLen, dq.len - messageLen); return tail; }); - g_lua.registerFunction>)>("setTrailingData", [](DNSResponse& dq, const vector>&data) { - uint8_t* message = reinterpret_cast(dq.dh); - const uint16_t length = getDNSPacketLength(reinterpret_cast(message), dq.len); - if(length + data.size() > dq.size) { + g_lua.registerFunction("setTrailingData", [](DNSResponse& dq, const std::string& tail) { + char* message = reinterpret_cast(dq.dh); + const uint16_t messageLen = getDNSPacketLength(message, dq.len); + const uint16_t tailLen = tail.size(); + if(messageLen + tailLen > dq.size) { return false; } - /* Copy data from the Lua array, whose first index is 1 instead of 0. */ - dq.len = length + data.size(); - uint8_t* tail = message + length - 1; - for(const auto& pair : data) { - *(tail + pair.first) = pair.second; + /* Update length and copy data from the Lua string. */ + dq.len = messageLen + tailLen; + if(tailLen > 0) { + tail.copy(message + messageLen, tailLen); } return true; }); diff --git a/pdns/dnsdistdist/docs/reference/dq.rst b/pdns/dnsdistdist/docs/reference/dq.rst index c5f2b7dcf1..7d6febaab3 100644 --- a/pdns/dnsdistdist/docs/reference/dq.rst +++ b/pdns/dnsdistdist/docs/reference/dq.rst @@ -109,13 +109,13 @@ This state can be modified from the various hooks. :returns: A table of tags, using strings as keys and values - .. method:: DNSQuestion:getTrailingData() -> table + .. method:: DNSQuestion:getTrailingData() -> string .. versionadded:: 1.4.0 Get all data following the DNS message. - :returns: A list of 8-bit integers + :returns: The trailing data as a null-safe string .. method:: DNSQuestion:sendTrap(reason) @@ -142,13 +142,13 @@ This state can be modified from the various hooks. :param table tags: A table of tags, using strings as keys and values - .. method:: DNSQuestion:setTrailingData(bytes) -> bool + .. method:: DNSQuestion:setTrailingData(tail) -> bool .. versionadded:: 1.4.0 Set the data following the DNS message, overwriting anything already present. - :param table bytes: The new data as a list of 8-bit integers + :param string tail: The new data :returns: true if the operation succeeded, false otherwise .. _DNSResponse: diff --git a/regression-tests.dnsdist/test_Trailing.py b/regression-tests.dnsdist/test_Trailing.py index 33d23b1e9e..3803e3797d 100644 --- a/regression-tests.dnsdist/test_Trailing.py +++ b/regression-tests.dnsdist/test_Trailing.py @@ -14,7 +14,7 @@ class TestTrailingDataToBackend(DNSDistTest): newServer{address="127.0.0.1:%s"} function replaceTrailingData(dq) - local success = dq:setTrailingData({65, 66, 67}) -- "ABC" + local success = dq:setTrailingData("ABC") if not success then return DNSAction.ServFail, "" end @@ -24,7 +24,7 @@ class TestTrailingDataToBackend(DNSDistTest): function fillBuffer(dq) local available = dq.size - dq.len - local tail = extendTableBy({}, available) + local tail = string.rep("A", available) local success = dq:setTrailingData(tail) if not success then return DNSAction.ServFail, "" @@ -35,7 +35,7 @@ class TestTrailingDataToBackend(DNSDistTest): function exceedBuffer(dq) local available = dq.size - dq.len - local tail = extendTableBy({}, available + 1) + local tail = string.rep("A", available + 1) local success = dq:setTrailingData(tail) if not success then return DNSAction.ServFail, "" @@ -43,21 +43,6 @@ class TestTrailingDataToBackend(DNSDistTest): return DNSAction.None, "" end addLuaAction("limited.trailing.tests.powerdns.com.", exceedBuffer) - - function extendTableBy(t, n) - if n <= 1 then - if n == 1 then - t[#t + 1] = 0 - end - return t - end - - local lower = math.floor(n / 2) - local upper = n - lower - t = extendTableBy(t, lower) - t = extendTableBy(t, upper) - return t - end """ @classmethod def startResponders(cls): @@ -192,7 +177,7 @@ class TestTrailingDataToDnsdist(DNSDistTest): addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction()) function removeTrailingData(dq) - local success = dq:setTrailingData({}) + local success = dq:setTrailingData("") if not success then return DNSAction.ServFail, "" end @@ -201,14 +186,13 @@ class TestTrailingDataToDnsdist(DNSDistTest): addLuaAction("removed.trailing.tests.powerdns.com.", removeTrailingData) function reportTrailingData(dq) - local tailBytes = dq:getTrailingData() - local tailChars = string.char(unpack(tailBytes)) - return DNSAction.Spoof, "-" .. tailChars .. ".echoed.trailing.tests.powerdns.com." + local tail = dq:getTrailingData() + return DNSAction.Spoof, "-" .. tail .. ".echoed.trailing.tests.powerdns.com." end addLuaAction("echoed.trailing.tests.powerdns.com.", reportTrailingData) function replaceTrailingData(dq) - local success = dq:setTrailingData({65, 66, 67}) -- "ABC" + local success = dq:setTrailingData("ABC") if not success then return DNSAction.ServFail, "" end