]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Multiple fixes to neural net classifier
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 10 Oct 2016 19:36:48 +0000 (20:36 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 10 Oct 2016 19:36:48 +0000 (20:36 +0100)
src/plugins/lua/fann_scores.lua

index 8123a92bbd71e24874122020b04cc7663f9bc80e..c1c3d80c02dd64cf15090b173c14a42f13fecb00 100644 (file)
@@ -597,9 +597,20 @@ local function maybe_load_fann(task, continue_cb, call_if_fail)
           current_classify_ann.loaded = true
           current_classify_ann.version = version
           current_classify_ann.ann = ann
-          current_classify_ann.spam_learned = tonumber(data[3])
-          current_classify_ann.ham_learned = tonumber(data[4])
-          rspamd_logger.infox(task, "loaded fann classifier version %s", version)
+          if type(data[3]) == 'string' then
+            current_classify_ann.spam_learned = tonumber(data[3])
+          else
+            current_classify_ann.spam_learned = 0
+          end
+          if type(data[4]) == 'string' then
+            current_classify_ann.ham_learned = tonumber(data[4])
+          else
+            current_classify_ann.ham_learned = 0
+          end
+          rspamd_logger.infox(task, "loaded fann classifier version %s (%s spam, %s ham), %s MSE",
+            version, current_classify_ann.spam_learned,
+            current_classify_ann.ham_learned,
+            ann:get_mse())
           continue_cb(task, true)
         elseif call_if_fail then
           continue_cb(task, false)
@@ -625,7 +636,7 @@ local function maybe_load_fann(task, continue_cb, call_if_fail)
       if not err and type(data) == 'string' then
         local version = tonumber(data)
 
-        if version == current_classify_ann.version then
+        if version <= current_classify_ann.version then
           continue_cb(task, true)
         else
           load_fann()
@@ -652,8 +663,9 @@ local function maybe_load_fann(task, continue_cb, call_if_fail)
 end
 
 local function tokens_to_vector(tokens)
-  local vec = map(function(tok) return tok[1] end, tokens)
+  local vec = totable(map(function(tok) return tok[1] end, tokens))
   local ret = {}
+  local ntok = #vec
   local neurons = classifier_config.neurons
   for i = 1,neurons do
     ret[i] = 0
@@ -662,9 +674,15 @@ local function tokens_to_vector(tokens)
     local n = (e % neurons) + 1
     ret[n] = ret[n] + 1
   end, vec)
+  local norm = 0
+  for i = 1,neurons do
+    if ret[i] > norm then
+      norm = ret[i]
+    end
+  end
   for i = 1,neurons do
-    if ret[i] ~= 0 then
-      ret[i] = 1.0 / ret[i]
+    if ret[i] ~= 0 and norm > 0 then
+      ret[i] = ret[i] / norm
     end
   end
 
@@ -713,7 +731,7 @@ local function save_fann(task, is_spam)
   else
     current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1
   end
-  local ret,_,_ = rspamd_redis_make_request(task,
+  local ret,conn,_ = rspamd_redis_make_request(task,
     redis_params, -- connect params
     key, -- hash key
     true, -- is write
@@ -721,12 +739,19 @@ local function save_fann(task, is_spam)
     'HMSET', -- command
     {
       key,
-      'version', tostring(current_classify_ann.version),
       'data', rspamd_util.zstd_compress(data),
-      'spam', tostring(current_classify_ann.spam_learned),
-      'ham', tostring(current_classify_ann.ham_learned),
-    } -- arguments
-  )
+    }) -- arguments
+
+  if conn then
+    conn:add_cmd('HINCRBY', {key, 'version', 1})
+    if is_spam then
+      conn:add_cmd('HINCRBY', {key, 'spam', 1})
+      rspamd_logger.errx(task, 'hui')
+    else
+      conn:add_cmd('HINCRBY', {key, 'ham', 1})
+      rspamd_logger.errx(task, 'pezda')
+    end
+  end
 end
 
 if redis_params then
@@ -754,7 +779,7 @@ if redis_params then
         local vec = tokens_to_vector(tokens)
         add_metatokens(task, vec)
         local out = current_classify_ann.ann:test(vec)
-        local result = rspamd_util.tanh(2 * (out[1] - 0.5))
+        local result = rspamd_util.tanh(2 * (out[1]))
         local symscore = string.format('%.3f', out[1])
         rspamd_logger.infox(task, 'fann classifier score: %s', symscore)
 
@@ -786,12 +811,17 @@ if redis_params then
         end
         local vec = tokens_to_vector(tokens)
         add_metatokens(task, vec)
-        rspamd_logger.infox(task, "vector: %s", vec)
+
         if is_spam then
           current_classify_ann.ann:train(vec, {1.0})
+          rspamd_logger.infox(task, "learned ANN spam, MSE: %s",
+            current_classify_ann.ann:get_mse())
         else
-          current_classify_ann.ann:train(vec, {0.0})
+          current_classify_ann.ann:train(vec, {-1.0})
+          rspamd_logger.infox(task, "learned ANN ham, MSE: %s",
+            current_classify_ann.ann:get_mse())
         end
+
         save_fann(task, is_spam)
       end
       maybe_load_fann(task, learn_cb, true)