From: Vsevolod Stakhov Date: Sat, 16 Sep 2017 20:11:44 +0000 (+0100) Subject: [Feature] Allow to disable torch and skip train samples for ANN X-Git-Tag: 1.7.0~634 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3ac673af36fa8cb19ec535cafc29f21013ebe461;p=thirdparty%2Frspamd.git [Feature] Allow to disable torch and skip train samples for ANN --- diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index e2a7eb4f53..f07a84033f 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -46,6 +46,7 @@ local default_options = { max_iterations = 25, -- Torch style mse = 0.001, autotrain = true, + train_prob = 1.0, }, use_settings = false, per_user = false, @@ -431,7 +432,7 @@ local function create_train_fann(rule, n, id) fanns[id].fann_train = create_fann(n, rule.nlayers) fanns[id].fann = nil rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix) - elseif fanns[id].version % rule.train.max_usages == 0 then + elseif rule.train.max_usages > 0 and fanns[id].version % rule.train.max_usages == 0 then -- Forget last fann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, fanns[id].version) @@ -540,7 +541,7 @@ local function fann_train_callback(rule, task, score, required_score, id) local function learn_vec_cb(err) if err then - rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err) + rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err) else rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes", rule['name'], k, vec_len) @@ -549,6 +550,11 @@ local function fann_train_callback(rule, task, score, required_score, id) local function can_train_cb(err, data) if not err and tonumber(data) > 0 then + local coin = math.random() + if coin < 1.0 - train_opts.train_prob then + rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) + return + end local fann_data = task:get_symbols_tokens() local mt = meta_functions.rspamd_gen_metatokens(task) -- Add filtered meta tokens @@ -1069,6 +1075,10 @@ else rules['RFANN'] = opts end + if opts.disable_torch then + use_torch = false + end + local id = rspamd_config:register_symbol({ name = 'FANN_CHECK', type = 'postfilter,nostat',