local rspamd_fann = require "rspamd_fann"
local rspamd_util = require "rspamd_util"
local fann_symbol = 'FANN_SCORE'
+require "fun" ()
local ucl = require "ucl"
-- Module vars
-local fann
-local fann_train
+local fann = nil
+local fann_train = nil
local fann_file
local ntrains = 0
local max_trains = 1000
local fann_mtime = 0
local opts = rspamd_config:get_all_opt("fann_scores")
+local function symbols_to_fann_vector(syms) -- turn the matched ids in `syms` into a dense 0/1 vector of length n
+ local learn_data = {} -- result: learn_data[i] is 1 or 0 for every i in 1..n
+ local matched_symbols = {} -- sparse set of 1-based positions that were hit
+ local n = rspamd_config:get_symbols_count() -- vector length is the symbol count reported by the config
+
+ each(function(s) -- `each` is presumably exported into globals by the "fun" library required at file top; confirm
+ matched_symbols[s + 1] = 1 -- shift by one: ids are assumed to be zero-based numbers (TODO confirm); Lua tables are 1-based
+ end, syms)
+
+ for i=1,n do -- fill every slot so the vector has no holes; marked positions above n are not copied
+ if matched_symbols[i] then
+ learn_data[i] = 1
+ else
+ learn_data[i] = 0
+ end
+ end
+
+ return learn_data
+end
+
local function load_fann()
local err,st = rspamd_util.stat(fann_file)
end
local function check_fann()
- local n = rspamd_config:get_symbols_count()
-
if fann then
local n = rspamd_config:get_symbols_count()
check_fann()
if fann then
- local fann_input = {}
-
- for sym,idx in pairs(symbols) do
- if task:has_symbol(sym) then
- fann_input[idx + 1] = 1
- else
- fann_input[idx + 1] = 0
- end
- end
+ local symbols = task:get_symbols_numeric()
+ local fann_data = symbols_to_fann_vector(symbols)
- local out = fann:test(nsymbols, fann_input)
+ local out = fann:test(fann_data)
local result = rspamd_util.tanh(2 * (out[1] - 0.5))
local symscore = string.format('%.3f', out[1])
rspamd_logger.infox(task, 'fann score: %s', symscore)
epoch = 0
end
-local function fann_train(score, required_score,results, cf, opts)
+local function fann_train_callback(score, required_score,results, cf, opts)
local n = cf:get_symbols_count()
if not fann_train then
end
if learn_spam or learn_ham then
- local learn_data = {}
- local matched_symbols = {}
-
- for _,sym in ipairs(results) do
- matched_symbols[sym[1] + 1] = 1
- end
-
- for i=1,(n + 1) do
- if matched_symbols[i] then
- learn_data[i] = 1
- else
- learn_data[i] = 0
- end
- end
+ local learn_data = symbols_to_fann_vector(
+ map(function(r) return r[1] end, results)
+ )
if learn_spam then
- fann_train:train(learn_data, 1.0)
+ fann_train:train(learn_data, {1.0})
else
- fann_train:train(learn_data, 0.0)
+ fann_train:train(learn_data, {0.0})
end
- trains = trains + 1
+ ntrains = ntrains + 1
end
end
max_trains = opts['train']['max_epoch']
end
cfg:register_worker_script("log_helper", function(score, req_score, results, cf)
- fann_train(score, req_score, results, cf, opts['train'])
+ fann_train_callback(score, req_score, results, cf, opts['train'])
end)
end)
end