]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
map: rework result handling
authorPetr Špaček <petr.spacek@nic.cz>
Wed, 14 Oct 2020 09:08:43 +0000 (11:08 +0200)
committerTomas Krizek <tomas.krizek@nic.cz>
Mon, 26 Oct 2020 13:25:14 +0000 (14:25 +0100)
map() command on leader instance now:
- detects call errors on followers
- detects unsupported number of return values
- detects unsupported data types which cannot be serialized
- keeps nil return values (signaled by table counter "n")

Fixes: #662
daemon/lua/sandbox.lua.in

index 60c71ac7d2a6471e1ab189d0e212db290458fad0..1b0632e2096bab609fb4479ece1eecca0c0f6ead 100644 (file)
@@ -2,6 +2,8 @@
 
 local debug = require('debug')
 local ffi = require('ffi')
+local kluautil = require('kluautil')
+local krprint = require("krprint")
 
 -- Units
 kB = 1024
@@ -650,21 +652,43 @@ end
 
 -- Global commands for map()
 
-local krprint = require("krprint")
+-- must be public because it is called from eval_cmd()
+-- when map() commands are read from control socket
+function _map_luaobj_call_wrapper(cmd)
+       local func = eval_cmd_compile(cmd, true)
+       local ret = kluautil.kr_table_pack(xpcall(func, debug.traceback))
+       local ok, serial = pcall(krprint.serialize_lua, ret, 'error')
+       if not ok then
+               return krprint.serialize_lua(
+                       kluautil.kr_table_pack(false, "returned values cannot be serialized: "
+                               .. serial))
+       else
+               return serial
+       end
+end
+
 function map(cmd, format)
        local socket = require('cqueues.socket')
-       local kluautil = require('kluautil')
        local bit = require("bit")
        local local_sockets = {}
        local results = {}
 
-       format = format or 'luaobj'
        if (type(cmd) ~= 'string') then
                panic('map() command must be a string') end
        if (#cmd <= 0) then
                panic('map() command must be non-empty') end
+       -- syntax check on input command to detect typos early
+       local chunk, err = eval_cmd_compile(cmd, false)
+       if not chunk then
+               panic('failure when compiling map() command: %s', err)
+       end
+
+       format = format or 'luaobj'
        if (format ~= 'luaobj' and format ~= 'strings') then
                panic('map() output format must be luaobj or strings') end
+       if format == 'luaobj' then
+               cmd = '_map_luaobj_call_wrapper([=====[' .. cmd .. ']=====])'
+       end
 
        -- find out control socket paths
        for _,v in pairs(net.list()) do
@@ -678,55 +702,72 @@ function map(cmd, format)
                        worker.control_path)
        end
 
-       -- validate input command to detect typos early
-       local chunk, err = eval_cmd_compile(cmd, false)
-       if not chunk then
-               panic('failure when compiling map() command: %s', err)
-       end
-
+       local result_count = 0
        -- finally execute it on all instances
-       for _,file in ipairs(filetab) do
+       for _, file in ipairs(filetab) do
                local local_exec = false
-               for _,lsoc in ipairs(local_sockets) do
+               for _, lsoc in ipairs(local_sockets) do
                        if file == lsoc then
                                local_exec = true
                        end
                end
+               local path = worker.control_path..file
+               local path_name = (local_exec and 'this instance') or path
+               if verbose() then
+                       log('executing map() on %s: command %s', path_name, cmd)
+               end
 
+               local ret
                if local_exec then
-                       local ret = eval_cmd(cmd, format == 'luaobj')
-                       -- crop to a single return value similarly to original map()
-                       table.insert(results, ret)
+                       ret = eval_cmd(cmd)
                else
-                       local s = socket.connect({ path = worker.control_path..file })
+                       local s = socket.connect({ path = path })
                        s:setmode('bn', 'bn')
                        local status, err = pcall(s.connect, s)
                        if not status then
-                               print(err)
+                               log('map() error while connecting to control socket %s: '
+                                       .. '%s (ignoring this socket)', path, err)
+                               goto continue
                        else
                                s:write('__binary\n')
-                               recv = s:read(2)
-                               if format == 'luaobj' then
-                                       s:write('require("krprint").serialize_lua('..cmd..')\n')
-                               else
-                                       s:write(cmd..'\n')
-                               end
-                               local recv = s:read(4)
+                               local recv = s:read(2)
+                               assert(recv == '> ', 'map() protocol error, undefined state')
+                               s:write(cmd..'\n')
+                               recv = s:read(4)
                                local len = tonumber(recv:byte(1))
                                for i=2,4 do
                                        len = bit.bor(bit.lshift(len, 8), tonumber(recv:byte(i)))
                                end
-                               recv = s:read(len)
-                               if format == 'strings' then
-                                       table.insert(results, recv)
-                               else
-                                       table.insert(results, krprint.deserialize_lua(recv))
-                               end
-
+                               ret = s:read(len)
                                s:close()
                        end
                end
+               result_count = result_count + 1
+               if format == 'luaobj' then
+                       ret = krprint.deserialize_lua(ret)
+                       -- ret is now table with xpcall results
+                       assert(type(ret) == 'table', 'map() protocol error, '
+                               .. 'table with results not retured by follower')
+                       if (ret.n ~= 2) then
+                               panic('unexpected number of return values in map() response: '
+                                       .. 'only single return value is allowed, '
+                                       .. 'use kluautil.kr_table_pack() helper')
+                       end
+                       local ok, retval = ret[1], ret[2]
+                       if ok == false then
+                               panic('error when executing map() command on control socket %s: '
+                                       .. '%s. command execution state is now undefined!',
+                                       path, retval)
+                       end
+                       -- drop wrapper table and return only the actual return value
+                       ret = retval
+               else
+                       assert(type(ret) == 'string', 'map() protocol error, '
+                               .. 'string not retured by follower')
+               end
+               results[result_count] = ret
+               ::continue::
        end
-
+       results.n = result_count
        return results
 end