max_trains = 1000,
max_epoch = 1000,
max_usages = 10,
+ max_iterations = 25, -- Torch style
mse = 0.001,
autotrain = true,
},
meta_functions.rspamd_count_metatokens()
if torch then
- local nlayers = #ann
- if nlayers ~= rule.nlayers then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
- prefix, nlayers)
- return false
- end
-
- local inp = ann:get(1):nElement()
- if n ~= inp then
- rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', prefix, inp, n)
- return false
- end
+ return true
else
if n ~= ann:get_inputs() then
rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
end
local function fann_scores_filter(task)
- for _,rule in ipairs(settings.rules) do
- local id = rule.prefix .. '0'
+
+ for _,rule in pairs(settings.rules) do
+ local id = '0'
if rule.use_settings then
local sid = task:get_settings_id()
if sid then
- id = rule.prefix .. tostring(sid)
+ id = tostring(sid)
end
end
if rule.per_user then
end
if is_fann_valid(rule, prefix, ann) then
+ if not fanns[id] then fanns[id] = {} end
fanns[id].fann = ann
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
prefix, ver)
if string.match(err, 'NOSCRIPT') then
load_scripts(rspamd_config, ev_base, nil)
end
+ else
+ rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix)
end
end
true, -- is write
redis_save_cb, --callback
'EVALSHA', -- command
- {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)}
+ {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)}
)
end
end
{prefix .. '_locked'}
)
else
- rspamd_logger.infox(rspamd_config, 'trained ANN %s',
- prefix)
+ rspamd_logger.infox(rspamd_config, 'trained ANN %s, %s bytes',
+ prefix, #data)
local ann_data
local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
ann_data = rspamd_util.zstd_compress(f:storage():string())
true, -- is write
redis_save_cb, --callback
'EVALSHA', -- command
- {redis_save_unlock_sha, '2', prefix, ann_data, tostring(rule.ann_expire)}
+ {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)}
)
end
end
local trainer = nn.StochasticGradient(fanns[elt].fann_train,
criterion)
trainer.learning_rate = 0.01
+ trainer.verbose = false
+ trainer.maxIteration = rule.train.max_iterations
trainer.hookIteration = function(self, iteration, currentError)
rspamd_logger.infox(rspamd_config, "learned %s iterations, error: %s",
iteration, currentError)
local function check_fanns(rule, _, ev_base)
local function members_cb(err, data)
if err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
+ rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s',
+ err)
elseif type(data) == 'table' then
fun.each(function(elt)
elt = tostring(elt)
local redis_update_cb = function(_err, _data)
if _err then
- rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, _err)
+ rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
+ elt, _err)
if string.match(_err, 'NOSCRIPT') then
load_scripts(rspamd_config, ev_base, nil)
end
elseif _data and type(_data) == 'table' then
load_or_invalidate_fann(rule, _data, elt, ev_base)
+ else
+ rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s',
+ type(_data), elt)
end
end