return {redis.call('GET', KEYS[1] .. '_data'), ret}
end
- return false
+ return tonumber(ret)
]]
local redis_maybe_load_sha = nil
local n = rspamd_config:get_symbols_count() +
meta_functions.rspamd_count_metatokens()
- if torch then
+ if use_torch then
return true
else
if n ~= ann:get_inputs() then
fun.each(function(e) table.insert(fann_data, e) end, mt)
local score
- if torch then
+ if use_torch then
local out = fanns[id].fann:forward(torch.Tensor(fann_data))
score = out[1]
else
rspamd_logger.infox(task, 'fann score: %s', symscore)
if score > 0 then
- local result = rspamd_util.normalize_prob(score / 2.0, 0)
+ local result = score
+ if not use_torch then
+ result = rspamd_util.normalize_prob(score / 2.0, 0)
+ end
task:insert_result(rule.symbol_spam, result, symscore, id)
else
- local result = rspamd_util.normalize_prob((-score) / 2.0, 0)
+ local result = -(score)
+ if not use_torch then
+ result = rspamd_util.normalize_prob(-(score) / 2.0, 0)
+ end
task:insert_result(rule.symbol_ham, result, symscore, id)
end
end
end
local function create_fann(n, nlayers)
- if torch then
+ if use_torch then
-- We ignore number of layers so far when using torch
local ann = nn.Sequential()
local nhidden = math.floor((n + 1) / 2)
rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
return
else
- if torch then
+ if use_torch then
ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject()
else
ann = rspamd_fann.load_data(ann_data)
rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
prefix, train_mse)
local ann_data
- if torch then
+ if use_torch then
local f = torch.MemoryFile()
f:writeObject(fanns[elt].fann_train)
ann_data = rspamd_util.zstd_compress(f:storage():string())
{redis_locked_invalidate_sha, 1, prefix}
)
else
- if torch then
+ if use_torch then
-- For torch we do not need to mix samples as they would be flushed
local dataset = {}
fun.each(function(s)
elseif _data and type(_data) == 'table' then
load_or_invalidate_fann(rule, _data, elt, ev_base)
else
- rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s',
- type(_data), elt)
+ if type(_data) == 'number' then
+ -- no new version
+ else
+ rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s',
+ type(_data), elt)
+ end
end
end
for _,rule in pairs(settings.rules) do
rspamd_config:add_on_load(function(cfg, ev_base, worker)
load_scripts(cfg, ev_base, function(_, _)
- check_fanns(rule, cfg, ev_base)
+ return check_fanns(rule, cfg, ev_base)
end)
if worker:get_name() == 'controller' and worker:get_index() == 0 then