]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Rework fann learning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 6 Apr 2016 13:22:01 +0000 (14:22 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 6 Apr 2016 13:22:01 +0000 (14:22 +0100)
src/plugins/lua/fann_scores.lua

index f430d915085bbe07f09f037c7f9f602db8be42f3..63dcb573321028b65e2d28cdc0c32ce07fecb988 100644 (file)
@@ -21,11 +21,12 @@ local rspamd_logger = require "rspamd_logger"
 local rspamd_fann = require "rspamd_fann"
 local rspamd_util = require "rspamd_util"
 local fann_symbol = 'FANN_SCORE'
+require "fun" ()
 local ucl = require "ucl"
 
 -- Module vars
-local fann
-local fann_train
+local fann = nil
+local fann_train = nil
 local fann_file
 local ntrains = 0
 local max_trains = 1000
@@ -34,6 +35,26 @@ local max_epoch = 100
 local fann_mtime = 0
 local opts = rspamd_config:get_all_opt("fann_scores")
 
+local function symbols_to_fann_vector(syms)
+  local learn_data = {}
+  local matched_symbols = {}
+  local n = rspamd_config:get_symbols_count()
+
+  each(function(s)
+    matched_symbols[s + 1] = 1
+  end, syms)
+
+  for i=1,n do
+    if matched_symbols[i] then
+      learn_data[i] = 1
+    else
+      learn_data[i] = 0
+    end
+  end
+
+  return learn_data
+end
+
 local function load_fann()
   local err,st = rspamd_util.stat(fann_file)
 
@@ -60,8 +81,6 @@ local function load_fann()
 end
 
 local function check_fann()
-  local n = rspamd_config:get_symbols_count()
-
   if fann then
     local n = rspamd_config:get_symbols_count()
 
@@ -88,17 +107,10 @@ local function fann_scores_filter(task)
   check_fann()
 
   if fann then
-    local fann_input = {}
-
-    for sym,idx in pairs(symbols) do
-      if task:has_symbol(sym) then
-        fann_input[idx + 1] = 1
-      else
-        fann_input[idx + 1] = 0
-      end
-    end
+    local symbols = task:get_symbols_numeric()
+    local fann_data = symbols_to_fann_vector(symbols)
 
-    local out = fann:test(nsymbols, fann_input)
+    local out = fann:test(fann_data)
     local result = rspamd_util.tanh(2 * (out[1] - 0.5))
     local symscore = string.format('%.3f', out[1])
     rspamd_logger.infox(task, 'fann score: %s', symscore)
@@ -117,7 +129,7 @@ local function create_train_fann(n)
   epoch = 0
 end
 
-local function fann_train(score, required_score,results, cf, opts)
+local function fann_train_callback(score, required_score,results, cf, opts)
   local n = cf:get_symbols_count()
 
   if not fann_train then
@@ -162,28 +174,17 @@ local function fann_train(score, required_score,results, cf, opts)
   end
 
   if learn_spam or learn_ham then
-    local learn_data = {}
-    local matched_symbols = {}
-
-    for _,sym in ipairs(results) do
-      matched_symbols[sym[1] + 1] = 1
-    end
-
-    for i=1,(n + 1) do
-      if matched_symbols[i] then
-        learn_data[i] = 1
-      else
-        learn_data[i] = 0
-      end
-    end
+    local learn_data = symbols_to_fann_vector(
+      map(function(r) return r[1] end, results)
+    )
 
     if learn_spam then
-      fann_train:train(learn_data, 1.0)
+      fann_train:train(learn_data, {1.0})
     else
-      fann_train:train(learn_data, 0.0)
+      fann_train:train(learn_data, {0.0})
     end
 
-    trains = trains + 1
+    ntrains = ntrains + 1
   end
 end
 
@@ -208,7 +209,7 @@ else
           max_trains = opts['train']['max_epoch']
         end
         cfg:register_worker_script("log_helper", function(score, req_score, results, cf)
-          fann_train(score, req_score, results, cf, opts['train'])
+          fann_train_callback(score, req_score, results, cf, opts['train'])
         end)
       end)
     end