git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Use more layers for fann and another normalization
author Vsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 15 Oct 2016 12:34:22 +0000 (13:34 +0100)
committer Vsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 15 Oct 2016 12:34:42 +0000 (13:34 +0100)
src/plugins/lua/fann_scores.lua

index c1c3d80c02dd64cf15090b173c14a42f13fecb00..9647fd3d320213e14078e9f51221ef39fb541bdc 100644 (file)
@@ -335,6 +335,13 @@ local function check_fann(id)
       ' is found in the cache', data[id].fann:get_inputs(), n)
       data[id].fann = nil
     end
+    local layers = data[id].fann:get_layers()
+
+    if not layers or #layers ~= 5 then
+      rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
+        #layers)
+      data[id].fann = nil
+    end
   end
 
   local fname = gen_fann_file(id)
@@ -373,14 +380,15 @@ local function fann_scores_filter(task)
     end
 
     local out = data[id].fann:test(fann_data)
-    local result = rspamd_util.tanh(2 * (out[1] - 0.5))
     local symscore = string.format('%.3f', out[1])
     rspamd_logger.infox(task, 'fann score: %s', symscore)
 
-    if result > 0 then
+    if out[1] > 0 then
+      local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
       task:insert_result(fann_symbol_spam, result, symscore, id)
     else
-      task:insert_result(fann_symbol_ham, -(result), symscore, id)
+      local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+      task:insert_result(fann_symbol_ham, result, symscore, id)
     end
   else
     if load_fann(id) then
@@ -390,7 +398,7 @@ local function fann_scores_filter(task)
 end
 
 local function create_train_fann(n, id)
-  data[id].fann_train = rspamd_fann.create(3, n, n / 2, 1)
+  data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
   data[id].ntrains = 0
   data[id].epoch = 0
 end
@@ -480,7 +488,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
     if learn_spam then
       data[id].fann_train:train(learn_data, {1.0})
     else
-      data[id].fann_train:train(learn_data, {0.0})
+      data[id].fann_train:train(learn_data, {-1.0})
     end
 
     data[id].ntrains = data[id].ntrains + 1