]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add neural net classifier to fann_scores module
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 8 Oct 2016 15:35:23 +0000 (16:35 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 8 Oct 2016 15:35:42 +0000 (16:35 +0100)
src/plugins/lua/fann_scores.lua

index 9ddb79fc38490f9082c7be156fbce3b224766cc9..c67eb597dbb45fd50b2e0a66a936fe16f3001154 100644 (file)
@@ -498,9 +498,11 @@ end
 if not rspamd_fann.is_enabled() then
   rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
     'module is eventually disabled')
+
+  return
 else
   if not opts['fann_file'] then
-    rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' ..
+    rspamd_logger.warnx(rspamd_config, 'fann_scores module requires ' ..
       '`fann_file` to be specified')
   else
     fann_file = opts['fann_file']
@@ -560,3 +562,215 @@ else
     end
   end
 end
+
+local redis_params
+local classifier_config = {
+  key = 'neural_net',
+  neurons = 200,
+  layers = 3,
+}
+
+local current_classify_ann = {
+  loaded = false,
+  version = 0,
+  spam_learned = 0,
+  ham_learned = 0
+}
+
+redis_params = rspamd_parse_redis_server('fann_scores')
+
+local function maybe_load_fann(task, continue_cb, call_if_fail)
+  local function load_fann()
+    local function redis_fann_load_cb(task, err, data)
+      if not err and type(data) == 'table' and type(data[2]) == 'string' then
+        local version = tonumber(data[1])
+        local ann_data = data[2]
+        local ann = rspamd_fann.load_data(ann_data)
+
+        if ann then
+          current_classify_ann.loaded = true
+          current_classify_ann.version = version
+          current_classify_ann.ann = ann
+          current_classify_ann.spam_learned = tonumber(data[3])
+          current_classify_ann.ham_learned = tonumber(data[4])
+          rspamd_logger.infox(task, "loaded fann classifier version %s", version)
+          continue_cb(task, true)
+        elseif call_if_fail then
+          continue_cb(task, false)
+        end
+      elseif call_if_fail then
+        continue_cb(task, false)
+      end
+    end
+
+    local key = classifier_config.key
+    local ret,_,_ = rspamd_redis_make_request(task,
+      redis_params, -- connect params
+      key, -- hash key
+      false, -- is write
+      redis_fann_load_cb, --callback
+      'HMGET', -- command
+      {key, 'version', 'data', 'spam', 'ham'} -- arguments
+    )
+  end
+
+  local function check_fann()
+    local function redis_fann_check_cb(task, err, data)
+      if not err and type(data) == 'string' then
+        local version = tonumber(data)
+
+        if version == current_classify_ann.version then
+          continue_cb(task, true)
+        else
+          load_fann()
+        end
+      end
+    end
+
+    local key = classifier_config.key
+    local ret,_,_ = rspamd_redis_make_request(task,
+      redis_params, -- connect params
+      key, -- hash key
+      false, -- is write
+      redis_fann_check_cb, --callback
+      'HGET', -- command
+      {key, 'version'} -- arguments
+    )
+  end
+
+  if not current_classify_ann.loaded then
+    load_fann()
+  else
+    check_fann()
+  end
+end
+
+local function tokens_to_vector(tokens)
+  local vec = map(function(tok) return tok[1] end, tokens)
+  local ret = {}
+  local neurons = classifier_config.neurons
+  for i = 1,neurons do
+    ret[i] = 0
+  end
+  each(function(e)
+    local n = (e % neurons) + 1
+    ret[n] = ret[n] + 1
+  end, vec)
+  for i = 1,neurons do
+    if ret[i] ~= 0 then
+      ret[i] = 1.0 / ret[i]
+    end
+  end
+
+  return ret
+end
+
+local function add_metatokens(task, vec)
+    local mt = gen_metatokens(task)
+    for _,tok in ipairs(mt) do
+      table.insert(vec, tok)
+    end
+end
+
+local function create_fann()
+  local layers = {}
+  local mt_size = count_metatokens()
+  local neurons = classifier_config.neurons + mt_size
+
+  for i = 1,classifier_config.layers - 1 do
+    layers[i] = math.floor(neurons / i)
+  end
+
+  table.insert(layers, 1)
+
+  local ann = rspamd_fann.create(classifier_config.layers, layers)
+  current_classify_ann.loaded = true
+  current_classify_ann.version = 0
+  current_classify_ann.ann = ann
+  current_classify_ann.spam_learned = 0
+  current_classify_ann.ham_learned = 0
+end
+
+local function save_fann(task, is_spam)
+  local function redis_fann_save_cb(task, err, data)
+    if err then
+      rspamd_logger.errx(task, "cannot save neural net to redis: %s", err)
+    end
+  end
+
+  local data = current_classify_ann.ann:data()
+  local key = classifier_config.key
+  current_classify_ann.version = current_classify_ann.version + 1
+
+  if is_spam then
+    current_classify_ann.spam_learned = current_classify_ann.spam_learned + 1
+  else
+    current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1
+  end
+  local ret,_,_ = rspamd_redis_make_request(task,
+    redis_params, -- connect params
+    key, -- hash key
+    true, -- is write
+    redis_fann_save_cb, --callback
+    'HMSET', -- command
+    {
+      key,
+      'version', tostring(current_classify_ann.version),
+      'data', tostring(data),
+      'spam', tostring(current_classify_ann.spam_learned),
+      'ham', tostring(current_classify_ann.ham_learned),
+    } -- arguments
+  )
+end
+
+if redis_params then
+  rspamd_classifiers['neural'] = {
+    classify = function(task, classifier, tokens)
+      local function classify_cb(task)
+        local vec = tokens_to_vector(tokens)
+        add_metatokens(task, vec)
+        local out = current_classify_ann.ann:test(vec)
+        local result = rspamd_util.tanh(2 * (out[1] - 0.5))
+        local symscore = string.format('%.3f', out[1])
+        rspamd_logger.infox(task, 'fann classifier score: %s', symscore)
+
+        if result > 0 then
+          each(function(st)
+              task:insert_result(st:get_symbol(), result, symscore)
+            end,
+            filter(function(st)
+              return st:is_spam()
+            end, classifier:get_statfiles())
+          )
+        else
+          each(function(st)
+              task:insert_result(st:get_symbol(), -result, symscore)
+            end,
+            filter(function(st)
+              return not st:is_spam()
+            end, classifier:get_statfiles())
+          )
+        end
+      end
+      maybe_load_fann(task, classify_cb, false)
+    end,
+
+    learn = function(task, classifier, tokens, is_spam, is_unlearn)
+      local function learn_cb(task, is_loaded)
+        if not is_loaded then
+          create_fann()
+        end
+        local vec = tokens_to_vector(tokens)
+        add_metatokens(task, vec)
+        rspamd_logger.infox(task, "vector: %s", vec)
+        if is_spam then
+          current_classify_ann.ann:train(vec, {1.0})
+        else
+          current_classify_ann.ann:train(vec, {0.0})
+        end
+        save_fann(task, is_spam)
+      end
+      maybe_load_fann(task, learn_cb, true)
+    end,
+  }
+end