if nham <= lim and nham + 1 >= nspam then
return tostring(nspam + 1)
else
- return tostring(-(nham + 1))
+ return tostring(-(nspam))
end
else
if nspam <= lim and nspam + 1 >= nham then
return tostring(nham + 1)
else
- return tostring(-(nspam + 1))
+ return tostring(-(nham))
end
end
-- We ignore number of layers so far when using torch
local ann = nn.Sequential()
local nhidden = math.floor((n + 1) / 2)
+ ann:add(nn.NaN(nn.Identity()))
ann:add(nn.Linear(n, nhidden))
ann:add(nn.PReLU())
ann:add(nn.Linear(nhidden, 1))
+ ann:add(nn.Tanh())
return ann
else
end
-- Create or reuse the training ANN for a given rule/id pair.
-- @param rule table: rule settings (provides `nlayers` and `train.max_usages`)
-- @param n number: number of input features for the network
-- @param id string: ANN identifier (caller has already stringified it)
-- Side effects: populates `fanns[id].fann_train` (and may reset
-- `fanns[id].fann` / `fanns[id].version`); logs what it decided.
local function create_train_fann(rule, n, id)
  local prefix = gen_fann_prefix(rule, id)
  if not fanns[id] then
    fanns[id] = {}
    if not is_fann_valid(rule, prefix, fanns[id].fann) then
      -- No usable ANN loaded: drop it and train a fresh one
      fanns[id].fann_train = create_fann(n, rule.nlayers)
      fanns[id].fann = nil
      rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
    elseif fanns[id].version % rule.train.max_usages == 0 then
      -- Forget last fann: it has been used too many times
      -- NOTE(review): the original call was truncated after `prefix,`;
      -- the second argument is reconstructed from the two %s placeholders.
      rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
        fanns[id].version)
      fanns[id].fann_train = create_fann(n, rule.nlayers)
    else
      -- Existing ANN is still valid and not over-used: train it in place
      fanns[id].fann_train = fanns[id].fann
      rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
    end
  else
    -- Entry already exists: start a brand new training network
    fanns[id].fann_train = create_fann(n, rule.nlayers)
    rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
    fanns[id].version = 0
  end
end
local dataset = {}
fun.each(function(s)
table.insert(dataset, {torch.Tensor(s), torch.Tensor({1.0})})
- end, spam_elts)
+ end, fun.filter(filt, spam_elts))
fun.each(function(s)
table.insert(dataset, {torch.Tensor(s), torch.Tensor({-1.0})})
- end, ham_elts)
+ end, fun.filter(filt, ham_elts))
-- Needed for torch
- dataset.size = function(tbl) return #tbl end
+ dataset.size = function() return #dataset end
local function train_torch()
local criterion = nn.MSECriterion()
fun.each(function(elt)
elt = tostring(elt)
local prefix = gen_fann_prefix(rule, elt)
+ rspamd_logger.infox(cfg, "check ANN %s", prefix)
local redis_len_cb = function(_err, _data)
if _err then
rspamd_logger.errx(rspamd_config,
'need to learn ANN %s after %s learn vectors (%s required)',
prefix, tonumber(_data), rule.train.max_trains)
train_fann(rule, cfg, ev_base, elt, worker)
+ else
+ rspamd_logger.infox(rspamd_config,
+ 'no need to learn ANN %s %s learn vectors (%s required)',
+ prefix, tonumber(_data), rule.train.max_trains)
end
end
end
return copy
end
-- Recursively merge `override` into `def` in place.
-- Keys present in `override` replace the corresponding keys in `def`;
-- when both sides hold tables the merge recurses instead of replacing
-- the whole subtable, so unrelated default keys survive.
-- @param def table: defaults, mutated in place
-- @param override table: user-supplied overrides (never mutated)
local function override_defaults(def, override)
  for k, v in pairs(override) do
    if def[k] then
      -- Recurse only when BOTH values are tables; the original checked
      -- just override[k], so a table override of a scalar default would
      -- call pairs() on a non-table and raise a runtime error.
      if type(v) == 'table' and type(def[k]) == 'table' then
        override_defaults(def[k], v)
      else
        def[k] = v
      end
    else
      -- Key absent (or falsy) in defaults: take the override verbatim
      def[k] = v
    end
  end
end
if not def_rules.name then
def_rules.name = k
end
+ if def_rules.train.max_train then
+ def_rules.train.max_trains = def_rules.train.max_train
+ end
rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
settings.rules[k] = def_rules
rspamd_config:set_metric_symbol({
check_fanns(rule, cfg, ev_base)
end)
- if worker:get_name() == 'normal' then
+ if worker:get_name() == 'controller' and worker:get_index() == 0 then
-- We also want to train neural nets when they have enough data
rspamd_config:add_periodic(ev_base, 0.0,
function(_, _)