]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
new serialization method krprint.serialize_lua
author Petr Špaček <petr.spacek@nic.cz>
Fri, 10 Jul 2020 12:32:25 +0000 (14:32 +0200)
committer Tomas Krizek <tomas.krizek@nic.cz>
Mon, 26 Oct 2020 13:25:13 +0000 (14:25 +0100)
Serializes: boolean, nil, number, string, table.
Skips all other types (functions, cdata, thread ...) and repeated
references to tables.

Resulting string should Lua-evaluate to identical objects.

daemon/lua/krprint.lua [new file with mode: 0644]
daemon/lua/krprint.test.lua [new file with mode: 0644]
daemon/lua/meson.build
daemon/lua/sandbox.lua.in

diff --git a/daemon/lua/krprint.lua b/daemon/lua/krprint.lua
new file mode 100644 (file)
index 0000000..055c889
--- /dev/null
@@ -0,0 +1,169 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+
+-- Prototype table of the serializer; the per-type serializer methods
+-- are attached to it.
+local serializer_class = {}
+-- metatable for instances: fields missing on an instance are looked up
+-- on the class table itself
+serializer_class.__inst_mt = { __index = serializer_class }
+
+-- constructor
+-- unrepresentable: 'comment' (default) or 'error'; any other value raises
+-- an error. The option is stored on the instance together with an empty
+-- 'done' table.
+function serializer_class.new(unrepresentable)
+       if not unrepresentable then
+               unrepresentable = 'comment'
+       end
+       if unrepresentable ~= 'comment' and unrepresentable ~= 'error' then
+               error('unsupported val2expr unrepresentable option ' .. tostring(unrepresentable))
+       end
+       return setmetatable({
+               unrepresentable = unrepresentable,
+               done = {},
+       }, serializer_class.__inst_mt)
+end
+
+-- Wrap a note into a Lua block comment, surrounded by the given leading and
+-- trailing whitespace (both optional). No note (nil) gives an empty string.
+local function format_note(note, ws_prefix, ws_suffix)
+       if note == nil then
+               return ''
+       end
+       local prefix = ws_prefix or ''
+       local suffix = ws_suffix or ''
+       return string.format('%s--[[ %s ]]%s', prefix, note, suffix)
+end
+
+-- Serialize val using a fresh serializer instance; a note returned by the
+-- serializer is placed in front of the expression as a comment.
+local function static_serializer(val, unrepresentable)
+       local serializer = serializer_class.new(unrepresentable)
+       local expr, note = serializer:val2expr(val)
+       local comment = format_note(note, nil, ' ')
+       return string.format('%s%s', comment, expr)
+end
+
+-- Convert a value into a Lua expression.
+-- Dispatches to the method named after type(val). Returns the expression
+-- string and optionally a note which callers render as a comment.
+-- Types without such a method give 'nil' plus a note in 'comment' mode
+-- and raise an error in 'error' mode.
+function serializer_class.val2expr(self, val)
+       local val_type = type(val)
+       local val_repr = self[val_type]
+       if val_repr then
+               return val_repr(self, val)
+       end
+       -- function, thread, userdata
+       if self.unrepresentable == 'comment' then
+               -- explicit tostring(): not every Lua implementation accepts
+               -- a non-string argument for %s
+               return 'nil', string.format('missing %s', tostring(val))
+       elseif self.unrepresentable == 'error' then
+               error(string.format('cannot print %s', val_type), 2)
+       end
+end
+
+-- nil is written as the literal keyword
+serializer_class['nil'] = function(_, val)
+       assert(val == nil)
+       return 'nil'
+end
+
+-- Serialize a number into an expression evaluating to the same value.
+-- Infinities and NaN have no literal form so they are spelled as expressions.
+function serializer_class.number(_, val)
+       assert(type(val) == 'number')
+       if val == math.huge then
+               return 'math.huge'
+       elseif val == -math.huge then
+               return '-math.huge'
+       elseif val ~= val then
+               -- NaN is the only number not equal to itself; its tostring()
+               -- form differs between platforms so it is not compared
+               return 'tonumber(\'nan\')'
+       else
+               -- 17 significant digits round-trip any IEEE 754 double exactly;
+               -- fixed-point output would turn small magnitudes
+               -- (e.g. 1e-70) into zero
+               return string.format('%.17g', val)
+       end
+end
+
+-- Serialize a string as a single-quoted Lua literal.
+-- Printable ASCII bytes except the quote and the backslash are copied,
+-- every other byte becomes a three-digit decimal escape.
+function serializer_class.string(_, val)
+       assert(type(val) == 'string')
+       local QUOTE, BACKSLASH = 0x27, 0x5C
+       local out = {}
+       for pos = 1, #val do
+               local byte = val:byte(pos)
+               local printable = byte >= 0x20 and byte <= 0x7e
+               if printable and byte ~= QUOTE and byte ~= BACKSLASH then
+                       out[pos] = string.char(byte)
+               else
+                       out[pos] = string.format('\\%03d', byte)
+               end
+       end
+       return '\'' .. table.concat(out) .. '\''
+end
+
+-- booleans are written as the literal keywords
+function serializer_class.boolean(_, val)
+       assert(type(val) == 'boolean')
+       if val then
+               return 'true'
+       end
+       return 'false'
+end
+
+-- Serialize a table into a Lua table constructor.
+-- Returns the expression and a note ('<tostring(tab)> follows').
+-- Keys continuing the sequence 1, 2, 3... in the order visited by pairs()
+-- are emitted positionally, all other entries as [key] = value.
+-- A table already visited by this serializer instance raises
+-- 'cyclic reference'; this covers repeated references as well.
+-- An entry which cannot be serialized is replaced by a comment in
+-- 'comment' mode and raises an error in 'error' mode.
+function serializer_class.table(self, tab)
+       assert(type(tab) == 'table')
+       if self.done[tab] then
+               error('cyclic reference', 0)
+       end
+       self.done[tab] = true
+
+       local items = {'{'}
+       local previdx = 0
+       for idx, val in pairs(tab) do
+               local errors = {}
+               local valok, valexpr, valnote = pcall(self.val2expr, self, val)
+               if not valok then
+                       table.insert(errors, string.format('value: %s', tostring(valexpr)))
+               end
+
+               local addidx
+               if previdx and type(idx) == 'number' and idx - 1 == previdx then
+                       -- monotonic sequence, do not print key
+                       previdx = idx
+                       addidx = false
+               else
+                       -- end of monotonic sequence
+                       -- from now on print keys as well
+                       previdx = nil
+                       addidx = true
+               end
+
+               local idxok, idxexpr, idxnote
+               if addidx then
+                       idxok, idxexpr, idxnote = pcall(self.val2expr, self, idx)
+                       if not idxok or idxexpr == 'nil' then
+                               table.insert(errors, 'key: not serializable')
+                       end
+               end
+
+               if #errors == 0 then
+                       -- finally serialize one [key=]?value expression
+                       if addidx then
+                               table.insert(items,
+                                       string.format('%s[%s]', format_note(idxnote, nil, ' '), idxexpr))
+                               table.insert(items, '=')
+                       end
+                       table.insert(items, string.format('%s%s,', format_note(valnote, nil, ' '), valexpr))
+               else
+                       -- this entry is not emitted, so values after it would
+                       -- shift to a lower index if they stayed positional
+                       previdx = nil
+                       local errmsg = string.format('%s = %s (%s)',
+                               tostring(idx),
+                               tostring(val),
+                               table.concat(errors, ', '))
+                       if self.unrepresentable == 'error' then
+                               error(errmsg, 0)
+                       else
+                               -- keys and values are arbitrary text; do not let
+                               -- them terminate the block comment early
+                               local count
+                               repeat
+                                       errmsg, count = string.gsub(errmsg, '%]%]', '] ]')
+                               until count == 0
+                               errmsg = string.format('--[[ missing %s ]]', errmsg)
+                               table.insert(items, errmsg)
+                       end
+               end
+       end  -- one key+value
+       table.insert(items, '}')
+       return table.concat(items, ' '), string.format('%s follows', tostring(tab))
+end
+
+-- Evaluate a string holding a Lua expression and return its value(s).
+-- NOTE(review): the input is compiled and executed as Lua code, so it must
+-- come from a trusted source only.
+-- Raises an error if the input is not a string or does not compile.
+local function deserialize_lua(serial)
+       assert(type(serial) == 'string')
+       local compile = loadstring or load
+       local deserial_func, err = compile('return ' .. serial)
+       if type(deserial_func) ~= 'function' then
+               -- include the parser message, it says what is wrong with the input
+               error('input is not a valid Lua expression: ' .. tostring(err))
+       end
+       return deserial_func()
+end
+
+-- public interface of the module
+local M = {}
+M.serialize_lua = static_serializer
+M.deserialize_lua = deserialize_lua
+
+return M
diff --git a/daemon/lua/krprint.test.lua b/daemon/lua/krprint.test.lua
new file mode 100644 (file)
index 0000000..ddd39cb
--- /dev/null
@@ -0,0 +1,223 @@
+local serialize_lua = require('krprint').serialize_lua
+local deserialize_lua = require('krprint').deserialize_lua
+
+-- Generate a string of random bytes (0-255) with random length
+-- between 0 and maxlen (default 100).
+local function gen_string(maxlen)
+       local limit = maxlen or 100
+       local chars = {}
+       for pos = 1, math.random(0, limit) do
+               chars[pos] = string.char(math.random(0, 255))
+       end
+       return table.concat(chars)
+end
+
+-- Serialize orig_val, deserialize the result and check that the value and
+-- its type survived. Numbers get special treatment: NaN is compared by its
+-- string form, infinities exactly, other numbers within a tolerance.
+local function test_de_serialization(orig_val, desc)
+       local serial = serialize_lua(orig_val)
+       ok(type(serial) == 'string' and #serial > 0,
+               'serialization returns non-empty string: ' .. desc)
+       local deserial_val = deserialize_lua(serial)
+       same(type(orig_val), type(deserial_val),
+               'deserialized value has the same type: ' .. desc)
+
+       if type(orig_val) ~= 'number' then
+               same(orig_val, deserial_val,
+                       'deserialization returns the same value: ' .. desc)
+               return
+       end
+
+       -- nan cannot be compared using == operator
+       local both_nan = tostring(orig_val) == 'nan' and tostring(deserial_val) == 'nan'
+       local is_infinity = orig_val == math.huge or orig_val == -math.huge
+       if both_nan then
+               pass('nan value serialized and deserialized')
+       elseif not is_infinity then
+               -- tolerance measured experimentally on x86_64 LuaJIT 2.1.0-beta3
+               local tolerance = 1e-14
+               ok(math.abs(orig_val - deserial_val) <= tolerance,
+                       'deserialized number is within tolerance ' .. tolerance)
+       else
+               same(orig_val, deserial_val, 'deserialization returns the same infinity:' .. desc)
+       end
+end
+
+-- same as test_de_serialization() with the description derived from the value
+local function test_de_serialization_autodesc(orig_val)
+       local desc = tostring(orig_val)
+       test_de_serialization(orig_val, desc)
+end
+
+-- both boolean values must survive the round trip
+local function test_bool()
+       for _, flag in ipairs({true, false}) do
+               test_de_serialization_autodesc(flag)
+       end
+end
+
+-- nil must survive the round trip
+local function test_nil()
+       local nothing = nil
+       test_de_serialization_autodesc(nothing)
+end
+
+-- Generate a random integer; half of the results are limited to
+-- +-2^32, the rest to +-2^48.
+local function gen_number_int()
+       local bound = 2^48
+       -- make "small" numbers more likely so they actually happen
+       if math.random() < 0.5 then
+               bound = 2^32
+       end
+       return math.random(-bound, bound)
+end
+
+-- random float as produced by math.random() without arguments
+local function gen_number_float()
+       local value = math.random()
+       return value
+end
+
+-- Round-trip the special values (zero, infinities, NaN) followed by
+-- 100 random integers and 100 random floats.
+local function test_number()
+       local special = { 0, -math.huge, math.huge, tonumber('nan') }
+       for i = 1, 4 do
+               test_de_serialization_autodesc(special[i])
+       end
+       -- integers first, floats second
+       for _, generator in ipairs({ gen_number_int, gen_number_float }) do
+               for _ = 1, 100 do
+                       test_de_serialization_autodesc(generator())
+               end
+       end
+end
+
+-- Round-trip the empty string and 100 random strings up to 10 KiB long.
+local function test_string()
+       test_de_serialization('', 'empty string')
+       local max_len = 1024 * 10
+       for _ = 1, 100 do
+               local sample = gen_string(max_len)
+               test_de_serialization(sample, 'random string length ' .. #sample)
+       end
+end
+
+-- Generate a number: zero, either infinity, a random integer or a random
+-- float, each kind with the same probability.
+local function gen_number()
+       -- pure random would not produce special cases often enough
+       local choice = math.random(1, 5)
+       if choice == 1 then
+               return 0
+       elseif choice == 2 then
+               return -math.huge
+       elseif choice == 3 then
+               return math.huge
+       elseif choice == 4 then
+               return gen_number_int()
+       end
+       return gen_number_float()
+end
+
+-- random boolean
+local function gen_boolean()
+       return math.random(1, 2) == 1
+end
+
+-- Generate a random value usable as table key or value:
+-- a number, a string or a boolean.
+local function gen_table_atomic()
+       -- nil keys or values are not allowed
+       -- nested tables are handled elsewhere
+       local supported_types = {
+               gen_number,
+               gen_string,
+               gen_boolean,
+       }
+       -- local: do not leak the result into the global environment
+       local val = supported_types[math.random(1, #supported_types)]()
+       return val
+end
+
+-- Generate a random table with up to 30 entries made of supported types.
+-- About one value in ten is a nested table, as long as the level
+-- (default 1) does not exceed 10.
+local function gen_test_tables_supported(level)
+       local depth = level or 1
+       local max_level = 10
+       local max_items_per_table = 30
+       local result = {}
+       local count = math.random(0, max_items_per_table)
+       for _ = 1, count do
+               local nest = false
+               if depth <= max_level then
+                       nest = math.random() < 0.1
+               end
+               -- tapered.same method cannot compare keys with type table
+               local key = gen_table_atomic()
+               local val
+               if nest then
+                       val = gen_test_tables_supported(depth + 1)
+               else
+                       val = gen_table_atomic()
+               end
+               result[key] = val
+       end
+       return result
+end
+
+-- Round-trip 100 random tables consisting of supported types only.
+-- Declared local like the other tests so it does not leak into the
+-- global environment.
+local function test_table_supported()
+       for i=1,100 do
+               local t = gen_test_tables_supported()
+               test_de_serialization(t, 'random table no. ' .. i)
+       end
+end
+
+local ffi = require('ffi')
+-- sample values used as input which the serializer cannot represent
+local const_func = tostring
+local const_thread = coroutine.create(tostring)
+local const_userdata = ffi.C  -- presumably type() reports userdata here; TODO confirm
+local const_cdata = ffi.new('int')
+
+-- Pick one of the sample values of unsupported types at random.
+local function gen_unsupported_atomic()
+       -- nested tables are handled elsewhere
+       local unsupported_types = {
+               const_func,
+               const_thread,
+               const_userdata,
+               const_cdata
+       }
+       -- local: do not leak the result into the global environment
+       local val = unsupported_types[math.random(1, #unsupported_types)]
+       return val
+end
+
+-- Build a test function for a value which cannot be serialized:
+-- 'error' mode must raise an error, 'comment' mode must return a string
+-- containing a comment.
+-- desc: optional description, derived from the type of val by default.
+local function test_unsupported(val, desc)
+       desc = desc or string.format('unsupported %s', type(val))
+       return function()
+               boom(serialize_lua, { val, 'error' }, string.format(
+                       'attempt to serialize %s in error mode '
+                       .. 'causes error', desc))
+               local output = serialize_lua(val, 'comment')
+               same('string', type(output),
+                       string.format('attempt to serialize %s in '
+                               .. 'comment mode provides returned a string',
+                               desc))
+               -- plain search: taken as a pattern, '--' matches the empty
+               -- string, so the check would pass for any output
+               ok(string.find(output, '--', 1, true) ~= nil,
+                       'returned string contains a comment')
+       end
+end
+
+local kluautil = require('kluautil')
+-- Replace random values in t (and in nested tables) by values of
+-- unsupported types.
+-- always: if true and nothing was replaced, an unsupported key is added
+-- so the table is modified in any case.
+-- Returns true if t or a nested table was modified.
+local function make_table_unsupported(t, always)
+       -- presumably the number of entries in t; verify against kluautil
+       local tab_len = kluautil.kr_table_len(t)
+       local modified = false
+       -- modify some values
+       for key, val in pairs(t) do
+               if math.random(1, tab_len) == 1 then
+                       if type(val) == 'table' then
+                               -- always descend, but never reset a flag which
+                               -- was already set by a previous entry
+                               local nested = make_table_unsupported(val, false)
+                               modified = modified or nested
+                       else
+                               t[key] = gen_unsupported_atomic()
+                               modified = true
+                       end
+               end
+       end
+       if always and not modified then
+               -- fallback, add an unsupported key
+               t[gen_unsupported_atomic()] = true
+               modified = true
+       end
+       return modified
+end
+
+-- Generate a random table and force unsupported content into it.
+local function gen_test_tables_unsupported()
+       local tab = gen_test_tables_supported()
+       local force = true
+       make_table_unsupported(tab, force)
+       return tab
+end
+
+-- Run the unsupported-value checks on 100 random tables.
+local function test_unsupported_table()
+       for round = 1, 100 do
+               local check = test_unsupported(gen_test_tables_unsupported(),
+                       'random unsupported table no. ' .. round)
+               check()
+       end
+end
+
+-- test functions returned to the caller, in this order
+local tests = {
+       test_bool,
+       test_nil,
+       test_number,
+       test_string,
+       test_table_supported,
+       -- one generated test per sample value of an unsupported type
+       test_unsupported(const_func),
+       test_unsupported(const_thread),
+       test_unsupported(const_userdata),
+       test_unsupported(const_cdata),
+       test_unsupported_table,
+}
+return tests
index b0e4ace6ca3492358521208c658fc650c2d699c3..968c954c79f9999055cc83bbd60fecb44cfe97ce 100644 (file)
@@ -3,6 +3,7 @@
 
 config_tests += [
   ['controlsock', files('controlsock.test.lua')],
+  ['krprint', files('krprint.test.lua')],
   ['ta', files('trust_anchors.test/ta.test.lua')],
   ['ta_bootstrap', files('trust_anchors.test/bootstrap.test.lua')],
 ]
@@ -45,6 +46,7 @@ lua_src = [
   trust_anchors,
   files('zonefile.lua'),
   files('kluautil.lua'),
+  files('krprint.lua'),
   distro_preconfig,
 ]
 
index a02584d22cde266296ab18788aa77a258f9ce8f8..1fa31906efd074eab2af709c14cc4166f1e73438 100644 (file)
@@ -645,6 +645,7 @@ end
 
 -- Global commands for map()
 
+local krprint = require("krprint")
 function map(cmd, format)
        local socket = require('cqueues.socket')
        local kluautil = require('kluautil')
@@ -666,7 +667,7 @@ function map(cmd, format)
 
        local filetab = kluautil.list_dir(worker.control_path)
        if next(filetab) == nil then
-               local ret = eval_cmd(cmd, true)
+               local ret = eval_cmd(cmd, format == 'luaobj')
                if ret == nil then
                        results = {}
                else
@@ -684,7 +685,9 @@ function map(cmd, format)
                end
 
                if local_exec then
-                       table.insert(results, eval_cmd(cmd, true))
+                       local ret = eval_cmd(cmd, format == 'luaobj')
+                       -- crop to a single return value similarly to original map()
+                       table.insert(results, ret)
                else
                        local s = socket.connect({ path = worker.control_path..file })
                        s:setmode('bn', 'bn')
@@ -695,9 +698,10 @@ function map(cmd, format)
                                s:write('__binary\n')
                                recv = s:read(2)
                                if format == 'luaobj' then
-                                       cmd = 'tojson('..cmd..')'
+                                       s:write('require("krprint").serialize_lua('..cmd..')\n')
+                               else
+                                       s:write(cmd..'\n')
                                end
-                               s:write(cmd..'\n')
                                local recv = s:read(4)
                                local len = tonumber(recv:byte(1))
                                for i=2,4 do
@@ -707,7 +711,7 @@ function map(cmd, format)
                                if format == 'strings' then
                                        table.insert(results, recv)
                                else
-                                       table.insert(results, fromjson(recv))
+                                       table.insert(results, krprint.deserialize_lua(recv))
                                end
 
                                s:close()