]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Rework fann module to understand settings
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 27 Apr 2016 13:07:33 +0000 (14:07 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 27 Apr 2016 13:07:33 +0000 (14:07 +0100)
src/plugins/lua/fann_scores.lua

index b20338fb801671c761354f348578e7e3dd6836c0..ea82974a101cf9e29c088387eb20115c5e488c83 100644 (file)
@@ -25,14 +25,18 @@ require "fun" ()
 local ucl = require "ucl"
 
 -- Module vars
-local fann = nil
-local fann_train = nil
+-- ANNs indexed by settings id
+local data = {
+  ['0'] = {
+    fann_mtime = 0,
+    ntrains = 0,
+    epoch = 0,
+  }
+}
 local fann_file
-local ntrains = 0
 local max_trains = 1000
-local epoch = 0
 local max_epoch = 100
-local fann_mtime = 0
+local use_settings = false
 local opts = rspamd_config:get_all_opt("fann_scores")
 
 local function symbols_to_fann_vector(syms)
@@ -55,125 +59,144 @@ local function symbols_to_fann_vector(syms)
   return learn_data
 end
 
-local function load_fann()
-  local err,st = rspamd_util.stat(fann_file)
+local function gen_fann_file(id)
+  if use_settings then
+    return fann_file .. id
+  else
+    return fann_file
+  end
+end
+
+local function load_fann(id)
+  local fname = gen_fann_file(id)
+  local err,st = rspamd_util.stat(fname)
 
   if err then
     return false
   end
 
-  fann = rspamd_fann.load(fann_file)
+  data[id].fann = rspamd_fann.load(fname)
 
-  if fann then
+  if data[id].fann then
     local n = rspamd_config:get_symbols_count()
 
-    if n ~= fann:get_inputs() then
+    if n ~= data[id].fann:get_inputs() then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
-      ' is found in the cache; removing', fann:get_inputs(), n)
-      fann = nil
+      ' is found in the cache; removing', data[id].fann:get_inputs(), n)
+      data[id].fann = nil
 
-      local ret,err = rspamd_util.unlink(fann_file)
+      local ret,err = rspamd_util.unlink(fname)
       if not ret then
         rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
-          fann_file, err)
+          fname, err)
       end
     else
-      rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fann_file)
+      rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname)
       return true
     end
   else
-    rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fann_file)
-    local ret,err = rspamd_util.unlink(fann_file)
+    rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname)
+    local ret,err = rspamd_util.unlink(fname)
     if not ret then
       rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
-        fann_file, err)
+        fname, err)
     end
   end
 
   return false
 end
 
-local function check_fann()
-  if fann then
+local function check_fann(id)
+  if data[id].fann then
     local n = rspamd_config:get_symbols_count()
 
-    if n ~= fann:get_inputs() then
+    if n ~= data[id].fann:get_inputs() then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
-      ' is found in the cache', fann:get_inputs(), n)
-      fann = nil
+      ' is found in the cache', data[id].fann:get_inputs(), n)
+      data[id].fann = nil
     end
   end
 
-  local err,st = rspamd_util.stat(fann_file)
+  local fname = gen_fann_file(id)
+  local err,st = rspamd_util.stat(fname)
 
   if not err then
     local mtime = st['mtime']
 
-    if mtime > fann_mtime then
+    if mtime > data[id].fann_mtime then
       rspamd_logger.infox(rspamd_config, 'have more fresh version of fann ' ..
-        'file: %s -> %s, need to reload %s', fann_mtime, mtime, fann_file)
-      fann_mtime = mtime
-      fann = nil
+        'file: %s -> %s, need to reload %s', data[id].fann_mtime, mtime, fname)
+      data[id].fann_mtime = mtime
+      data[id].fann = nil
     end
   end
 end
 
 local function fann_scores_filter(task)
-  check_fann()
+  local id = '0'
+  if use_settings then
+   local sid = task:get_settings_id()
+   if sid then
+    id = tostring(sid)
+   end
+  end
+
+  check_fann(id)
 
-  if fann then
+  if data[id].fann then
     local symbols = task:get_symbols_numeric()
     local fann_data = symbols_to_fann_vector(symbols)
 
-    local out = fann:test(fann_data)
+    local out = data[id].fann:test(fann_data)
     local result = rspamd_util.tanh(2 * (out[1] - 0.5))
     local symscore = string.format('%.3f', out[1])
     rspamd_logger.infox(task, 'fann score: %s', symscore)
 
-    task:insert_result(fann_symbol, result, symscore)
+    task:insert_result(fann_symbol, result, symscore, id)
   else
-    if load_fann() then
+    if load_fann(id) then
       fann_scores_filter(task)
     end
   end
 end
 
-local function create_train_fann(n)
-  fann_train = rspamd_fann.create(3, n, n / 2, 1)
-  ntrains = 0
-  epoch = 0
+local function create_train_fann(n, id)
+  data[id].fann_train = rspamd_fann.create(3, n, n / 2, 1)
+  data[id].ntrains = 0
+  data[id].epoch = 0
 end
 
-local function fann_train_callback(score, required_score,results, cf, opts)
+local function fann_train_callback(score, required_score,results, cf, id, opts)
   local n = cf:get_symbols_count()
+  local fname = gen_fann_file(id)
 
-  if not fann_train then
-    create_train_fann(n)
+  if not data[id].fann_train then
+    create_train_fann(n, id)
   end
 
-  if fann_train:get_inputs() ~= n then
+  if data[id].fann_train:get_inputs() ~= n then
     rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' ..
-      ' is found in the cache', fann_train:get_inputs(), n)
-    create_train_fann(n)
+      ' is found in the cache', data[id].fann_train:get_inputs(), n)
+    create_train_fann(n, id)
   end
 
-  if ntrains > max_trains then
+  if data[id].ntrains > max_trains then
     -- Store fann on disk
-    res = fann_train:save(fann_file)
+    local res = data[id].fann_train:save(fname)
 
     if not res then
-      rspamd_logger.errx(cf, 'cannot save fann in %s', fann_file)
+      rspamd_logger.errx(cf, 'cannot save fann in %s', fname)
     else
-      ntrains = 0
-      epoch = epoch + 1
+      data[id].ntrains = 0
+      data[id].epoch = data[id].epoch + 1
     end
   end
 
-  if epoch > max_epoch then
+  if data[id].epoch > max_epoch then
     -- Re-create fann
-    rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fann_file,
+    rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fname,
       max_epoch)
-    create_train_fann(n)
+    create_train_fann(n, id)
   end
 
   local learn_spam, learn_ham = false, false
@@ -194,12 +217,12 @@ local function fann_train_callback(score, required_score,results, cf, opts)
     )
 
     if learn_spam then
-      fann_train:train(learn_data, {1.0})
+      data[id].fann_train:train(learn_data, {1.0})
     else
-      fann_train:train(learn_data, {0.0})
+      data[id].fann_train:train(learn_data, {0.0})
     end
 
-    ntrains = ntrains + 1
+    data[id].ntrains = data[id].ntrains + 1
   end
 end
 
@@ -212,6 +235,7 @@ else
       '`fann_file` to be specified')
   else
     fann_file = opts['fann_file']
+    use_settings = opts['use_settings']
     rspamd_config:set_metric_symbol(fann_symbol, 3.0, 'Experimental FANN adjustment')
     rspamd_config:register_post_filter(fann_scores_filter)
 
@@ -223,8 +247,14 @@ else
         if opts['train']['max_epoch'] then
           max_trains = opts['train']['max_epoch']
         end
-        cfg:register_worker_script("log_helper", function(score, req_score, results, cf)
-          fann_train_callback(score, req_score, results, cf, opts['train'])
+        cfg:register_worker_script("log_helper",
+          function(score, req_score, results, cf, id)
+            if use_settings then
+              fann_train_callback(score, req_score, results, cf,
+                tostring(id), opts['train'])
+            else
+              fann_train_callback(score, req_score, results, cf, '0', opts['train'])
+            end
         end)
       end)
     end