]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Allow to disable torch and skip train samples for ANN
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Sep 2017 20:11:44 +0000 (21:11 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Sep 2017 20:11:44 +0000 (21:11 +0100)
src/plugins/lua/fann_redis.lua

index e2a7eb4f53e3e7b45cfc33b639f13bf93b6599bb..f07a84033fdba77c102ba124ce77c499932b2a59 100644 (file)
@@ -46,6 +46,7 @@ local default_options = {
     max_iterations = 25, -- Torch style
     mse = 0.001,
     autotrain = true,
+    train_prob = 1.0,
   },
   use_settings = false,
   per_user = false,
@@ -431,7 +432,7 @@ local function create_train_fann(rule, n, id)
       fanns[id].fann_train = create_fann(n, rule.nlayers)
       fanns[id].fann = nil
       rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
-    elseif fanns[id].version % rule.train.max_usages == 0 then
+    elseif rule.train.max_usages > 0 and fanns[id].version % rule.train.max_usages == 0 then
       -- Forget last fann
       rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
         fanns[id].version)
@@ -540,7 +541,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
 
     local function learn_vec_cb(err)
       if err then
-        rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err)
+        rspamd_logger.errx(task, 'cannot store train vector for %s: %s', fname, err)
       else
         rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes",
           rule['name'], k, vec_len)
@@ -549,6 +550,11 @@ local function fann_train_callback(rule, task, score, required_score, id)
 
     local function can_train_cb(err, data)
       if not err and tonumber(data) > 0 then
+        local coin = math.random()
+        if coin < 1.0 - train_opts.train_prob then
+          rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+          return
+        end
         local fann_data = task:get_symbols_tokens()
         local mt = meta_functions.rspamd_gen_metatokens(task)
         -- Add filtered meta tokens
@@ -1069,6 +1075,10 @@ else
     rules['RFANN'] = opts
   end
 
+  if opts.disable_torch then
+    use_torch = false
+  end
+
   local id = rspamd_config:register_symbol({
     name = 'FANN_CHECK',
     type = 'postfilter,nostat',