]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
map: thorough error checking on control socket operations
authorPetr Špaček <petr.spacek@nic.cz>
Tue, 20 Oct 2020 11:04:23 +0000 (13:04 +0200)
committerTomas Krizek <tomas.krizek@nic.cz>
Mon, 26 Oct 2020 13:25:15 +0000 (14:25 +0100)
daemon/lua/sandbox.lua.in

index 8d7fb45b50c46196fe043ad855144179044e875b..e08c69b53896bc025e4a7c72807609c479e8194b 100644 (file)
@@ -671,9 +671,74 @@ function _map_luaobj_call_wrapper(cmd)
        end
 end
 
-function map(cmd, format)
+local function _sock_errmsg(path, desc)
+       return string.format(
+               'map() error while communicating with %s: %s',
+               path, desc)
+end
+
+local function _sock_check(sock, call, params, path, desc)
+       local errprefix = _sock_errmsg(path, desc) .. ': '
+       local retvals = kluautil.kr_table_pack(pcall(call, unpack(params)))
+       local ok = retvals[1]
+       if not ok then
+               error(errprefix .. tostring(retvals[2]))
+       end
+       local rerr, werr = sock:error()
+       if rerr or werr then
+               error(string.format('%sread error %s; write error %s', errprefix, rerr, werr))
+       end
+       if retvals[2] == nil then
+               error(errprefix .. 'unexpected nil result')
+       end
+       return unpack(retvals, 2, retvals.n)
+end
+
+local function _sock_assert(condition, path, desc)
+       if not condition then
+               error(_sock_errmsg(path, desc))
+       end
+end
+
+local function map_send_recv(cmd, path)
+       local bit = require('bit')
        local socket = require('cqueues.socket')
-       local bit = require("bit")
+       local s = socket.connect({ path = path })
+       s:setmaxerrs(0)
+       s:setmode('bn', 'bn')
+       local status, err = pcall(s.connect, s)
+       if not status then
+               log('map() error while connecting to control socket %s: '
+                       .. '%s (ignoring this socket)', path, err)
+               return nil
+       end
+       local ret = _sock_check(s, s.write, {s, '__binary\n'}, path,
+               'write __binary')
+       _sock_assert(ret, path,
+               'write __binary result')
+       local recv = _sock_check(s, s.read, {s, 2}, path,
+               'read reply to __binary')
+       _sock_assert(recv and recv == '> ', path,
+               'unexpected reply to __binary')
+       _sock_check(s, s.write, {s, cmd..'\n'}, path,
+               'command write')
+       recv = _sock_check(s, s.read, {s, 4}, path,
+               'response length read')
+       _sock_assert(recv and #recv == 4, path,
+               'length of response length preambule does not match')
+       local len = tonumber(recv:byte(1))
+       for i=2,4 do
+               len = bit.bor(bit.lshift(len, 8), tonumber(recv:byte(i)))
+       end
+       ret = _sock_check(s, s.read, {s, len}, path,
+               'read response')
+       _sock_assert(ret and #ret == len, path,
+               'actual response length does not match length in preambule')
+       s:close()
+       return ret
+end
+
+function map(cmd, format)
        local local_sockets = {}
        local results = {}
 
@@ -720,30 +785,14 @@ function map(cmd, format)
                if verbose() then
                        log('executing map() on %s: command %s', path_name, cmd)
                end
-
                local ret
                if local_exec then
                        ret = eval_cmd(cmd)
                else
-                       local s = socket.connect({ path = path })
-                       s:setmode('bn', 'bn')
-                       local status, err = pcall(s.connect, s)
-                       if not status then
-                               log('map() error while connecting to control socket %s: '
-                                       .. '%s (ignoring this socket)', path, err)
+                       ret = map_send_recv(cmd, path)
+                       -- skip dead sockets (leftovers from dead instances)
+                       if ret == nil then
                                goto continue
-                       else
-                               s:write('__binary\n')
-                               local recv = s:read(2)
-                               assert(recv == '> ', 'map() protocol error, undefined state')
-                               s:write(cmd..'\n')
-                               recv = s:read(4)
-                               local len = tonumber(recv:byte(1))
-                               for i=2,4 do
-                                       len = bit.bor(bit.lshift(len, 8), tonumber(recv:byte(i)))
-                               end
-                               ret = s:read(len)
-                               s:close()
                        end
                end
                result_count = result_count + 1