' is found in the cache', data[id].fann:get_inputs(), n)
data[id].fann = nil
end
+ local layers = data[id].fann:get_layers()
+
+ if not layers or #layers ~= 5 then
+ rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
+ #layers)
+ data[id].fann = nil
+ end
end
local fname = gen_fann_file(id)
end
local out = data[id].fann:test(fann_data)
- local result = rspamd_util.tanh(2 * (out[1] - 0.5))
local symscore = string.format('%.3f', out[1])
rspamd_logger.infox(task, 'fann score: %s', symscore)
- if result > 0 then
+ if out[1] > 0 then
+ local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
task:insert_result(fann_symbol_spam, result, symscore, id)
else
- task:insert_result(fann_symbol_ham, -(result), symscore, id)
+ local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+ task:insert_result(fann_symbol_ham, result, symscore, id)
end
else
if load_fann(id) then
end
local function create_train_fann(n, id)
- data[id].fann_train = rspamd_fann.create(3, n, n / 2, 1)
+ data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
data[id].ntrains = 0
data[id].epoch = 0
end
if learn_spam then
data[id].fann_train:train(learn_data, {1.0})
else
- data[id].fann_train:train(learn_data, {0.0})
+ data[id].fann_train:train(learn_data, {-1.0})
end
data[id].ntrains = data[id].ntrains + 1