if not rspamd_fann.is_enabled() then
rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
'module is eventually disabled')
+
+ return
else
if not opts['fann_file'] then
- rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' ..
+ rspamd_logger.warnx(rspamd_config, 'fann_scores module requires ' ..
'`fann_file` to be specified')
else
fann_file = opts['fann_file']
end
end
end
+
+local redis_params
+local classifier_config = {
+ key = 'neural_net',
+ neurons = 200,
+ layers = 3,
+}
+
+local current_classify_ann = {
+ loaded = false,
+ version = 0,
+ spam_learned = 0,
+ ham_learned = 0
+}
+
+redis_params = rspamd_parse_redis_server('fann_scores')
+
+local function maybe_load_fann(task, continue_cb, call_if_fail)
+ local function load_fann()
+ local function redis_fann_load_cb(task, err, data)
+ if not err and type(data) == 'table' and type(data[2]) == 'string' then
+ local version = tonumber(data[1])
+ local ann_data = data[2]
+ local ann = rspamd_fann.load_data(ann_data)
+
+ if ann then
+ current_classify_ann.loaded = true
+ current_classify_ann.version = version
+ current_classify_ann.ann = ann
+ current_classify_ann.spam_learned = tonumber(data[3])
+ current_classify_ann.ham_learned = tonumber(data[4])
+ rspamd_logger.infox(task, "loaded fann classifier version %s", version)
+ continue_cb(task, true)
+ elseif call_if_fail then
+ continue_cb(task, false)
+ end
+ elseif call_if_fail then
+ continue_cb(task, false)
+ end
+ end
+
+ local key = classifier_config.key
+ local ret,_,_ = rspamd_redis_make_request(task,
+ redis_params, -- connect params
+ key, -- hash key
+ false, -- is write
+ redis_fann_load_cb, --callback
+ 'HMGET', -- command
+ {key, 'version', 'data', 'spam', 'ham'} -- arguments
+ )
+ end
+
+ local function check_fann()
+ local function redis_fann_check_cb(task, err, data)
+ if not err and type(data) == 'string' then
+ local version = tonumber(data)
+
+ if version == current_classify_ann.version then
+ continue_cb(task, true)
+ else
+ load_fann()
+ end
+ end
+ end
+
+ local key = classifier_config.key
+ local ret,_,_ = rspamd_redis_make_request(task,
+ redis_params, -- connect params
+ key, -- hash key
+ false, -- is write
+ redis_fann_check_cb, --callback
+ 'HGET', -- command
+ {key, 'version'} -- arguments
+ )
+ end
+
+ if not current_classify_ann.loaded then
+ load_fann()
+ else
+ check_fann()
+ end
+end
+
+local function tokens_to_vector(tokens)
+ local vec = map(function(tok) return tok[1] end, tokens)
+ local ret = {}
+ local neurons = classifier_config.neurons
+ for i = 1,neurons do
+ ret[i] = 0
+ end
+ each(function(e)
+ local n = (e % neurons) + 1
+ ret[n] = ret[n] + 1
+ end, vec)
+ for i = 1,neurons do
+ if ret[i] ~= 0 then
+ ret[i] = 1.0 / ret[i]
+ end
+ end
+
+ return ret
+end
+
+local function add_metatokens(task, vec)
+ local mt = gen_metatokens(task)
+ for _,tok in ipairs(mt) do
+ table.insert(vec, tok)
+ end
+end
+
+local function create_fann()
+ local layers = {}
+ local mt_size = count_metatokens()
+ local neurons = classifier_config.neurons + mt_size
+
+ for i = 1,classifier_config.layers - 1 do
+ layers[i] = math.floor(neurons / i)
+ end
+
+ table.insert(layers, 1)
+
+ local ann = rspamd_fann.create(classifier_config.layers, layers)
+ current_classify_ann.loaded = true
+ current_classify_ann.version = 0
+ current_classify_ann.ann = ann
+ current_classify_ann.spam_learned = 0
+ current_classify_ann.ham_learned = 0
+end
+
+local function save_fann(task, is_spam)
+ local function redis_fann_save_cb(task, err, data)
+ if err then
+ rspamd_logger.errx(task, "cannot save neural net to redis: %s", err)
+ end
+ end
+
+ local data = current_classify_ann.ann:data()
+ local key = classifier_config.key
+ current_classify_ann.version = current_classify_ann.version + 1
+
+ if is_spam then
+ current_classify_ann.spam_learned = current_classify_ann.spam_learned + 1
+ else
+ current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1
+ end
+ local ret,_,_ = rspamd_redis_make_request(task,
+ redis_params, -- connect params
+ key, -- hash key
+ true, -- is write
+ redis_fann_save_cb, --callback
+ 'HMSET', -- command
+ {
+ key,
+ 'version', tostring(current_classify_ann.version),
+ 'data', tostring(data),
+ 'spam', tostring(current_classify_ann.spam_learned),
+ 'ham', tostring(current_classify_ann.ham_learned),
+ } -- arguments
+ )
+end
+
+if redis_params then
+ rspamd_classifiers['neural'] = {
+ classify = function(task, classifier, tokens)
+ local function classify_cb(task)
+ local vec = tokens_to_vector(tokens)
+ add_metatokens(task, vec)
+ local out = current_classify_ann.ann:test(vec)
+ local result = rspamd_util.tanh(2 * (out[1] - 0.5))
+ local symscore = string.format('%.3f', out[1])
+ rspamd_logger.infox(task, 'fann classifier score: %s', symscore)
+
+ if result > 0 then
+ each(function(st)
+ task:insert_result(st:get_symbol(), result, symscore)
+ end,
+ filter(function(st)
+ return st:is_spam()
+ end, classifier:get_statfiles())
+ )
+ else
+ each(function(st)
+ task:insert_result(st:get_symbol(), -result, symscore)
+ end,
+ filter(function(st)
+ return not st:is_spam()
+ end, classifier:get_statfiles())
+ )
+ end
+ end
+ maybe_load_fann(task, classify_cb, false)
+ end,
+
+ learn = function(task, classifier, tokens, is_spam, is_unlearn)
+ local function learn_cb(task, is_loaded)
+ if not is_loaded then
+ create_fann()
+ end
+ local vec = tokens_to_vector(tokens)
+ add_metatokens(task, vec)
+ rspamd_logger.infox(task, "vector: %s", vec)
+ if is_spam then
+ current_classify_ann.ann:train(vec, {1.0})
+ else
+ current_classify_ann.ann:train(vec, {0.0})
+ end
+ save_fann(task, is_spam)
+ end
+ maybe_load_fann(task, learn_cb, true)
+ end,
+ }
+end