current_classify_ann.loaded = true
current_classify_ann.version = version
current_classify_ann.ann = ann
- current_classify_ann.spam_learned = tonumber(data[3])
- current_classify_ann.ham_learned = tonumber(data[4])
- rspamd_logger.infox(task, "loaded fann classifier version %s", version)
+ if type(data[3]) == 'string' then
+ current_classify_ann.spam_learned = tonumber(data[3])
+ else
+ current_classify_ann.spam_learned = 0
+ end
+ if type(data[4]) == 'string' then
+ current_classify_ann.ham_learned = tonumber(data[4])
+ else
+ current_classify_ann.ham_learned = 0
+ end
+ rspamd_logger.infox(task, "loaded fann classifier version %s (%s spam, %s ham), %s MSE",
+ version, current_classify_ann.spam_learned,
+ current_classify_ann.ham_learned,
+ ann:get_mse())
continue_cb(task, true)
elseif call_if_fail then
continue_cb(task, false)
if not err and type(data) == 'string' then
local version = tonumber(data)
- if version == current_classify_ann.version then
+ if version <= current_classify_ann.version then
continue_cb(task, true)
else
load_fann()
end
local function tokens_to_vector(tokens)
- local vec = map(function(tok) return tok[1] end, tokens)
+ local vec = totable(map(function(tok) return tok[1] end, tokens))
local ret = {}
+ local ntok = #vec
local neurons = classifier_config.neurons
for i = 1,neurons do
ret[i] = 0
local n = (e % neurons) + 1
ret[n] = ret[n] + 1
end, vec)
+ local norm = 0
+ for i = 1,neurons do
+ if ret[i] > norm then
+ norm = ret[i]
+ end
+ end
for i = 1,neurons do
- if ret[i] ~= 0 then
- ret[i] = 1.0 / ret[i]
+ if ret[i] ~= 0 and norm > 0 then
+ ret[i] = ret[i] / norm
end
end
else
current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1
end
- local ret,_,_ = rspamd_redis_make_request(task,
+ local ret,conn,_ = rspamd_redis_make_request(task,
redis_params, -- connect params
key, -- hash key
true, -- is write
'HMSET', -- command
{
key,
- 'version', tostring(current_classify_ann.version),
'data', rspamd_util.zstd_compress(data),
- 'spam', tostring(current_classify_ann.spam_learned),
- 'ham', tostring(current_classify_ann.ham_learned),
- } -- arguments
- )
+ }) -- arguments
+
+ if conn then
+ conn:add_cmd('HINCRBY', {key, 'version', 1})
+ if is_spam then
+ conn:add_cmd('HINCRBY', {key, 'spam', 1})
+ rspamd_logger.errx(task, 'hui')
+ else
+ conn:add_cmd('HINCRBY', {key, 'ham', 1})
+ rspamd_logger.errx(task, 'pezda')
+ end
+ end
end
if redis_params then
local vec = tokens_to_vector(tokens)
add_metatokens(task, vec)
local out = current_classify_ann.ann:test(vec)
- local result = rspamd_util.tanh(2 * (out[1] - 0.5))
+ local result = rspamd_util.tanh(2 * (out[1]))
local symscore = string.format('%.3f', out[1])
rspamd_logger.infox(task, 'fann classifier score: %s', symscore)
end
local vec = tokens_to_vector(tokens)
add_metatokens(task, vec)
- rspamd_logger.infox(task, "vector: %s", vec)
+
if is_spam then
current_classify_ann.ann:train(vec, {1.0})
+ rspamd_logger.infox(task, "learned ANN spam, MSE: %s",
+ current_classify_ann.ann:get_mse())
else
- current_classify_ann.ann:train(vec, {0.0})
+ current_classify_ann.ann:train(vec, {-1.0})
+ rspamd_logger.infox(task, "learned ANN ham, MSE: %s",
+ current_classify_ann.ann:get_mse())
end
+
save_fann(task, is_spam)
end
maybe_load_fann(task, learn_cb, true)