]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Multiple fixes in torch based ANN plugins
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 17 Sep 2017 08:01:09 +0000 (09:01 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 17 Sep 2017 08:01:09 +0000 (09:01 +0100)
- Fix ANNs load
- Fix disabling torch
- Remove normalisation as we have tanh on output

src/plugins/lua/fann_redis.lua

index f07a84033fdba77c102ba124ce77c499932b2a59..2751b5d79198e5ea3b68ec116e11f582f3b63c01 100644 (file)
@@ -127,7 +127,7 @@ local redis_lua_script_maybe_load = [[
     return {redis.call('GET', KEYS[1] .. '_data'), ret}
   end
 
-  return false
+  return tonumber(ret)
 ]]
 local redis_maybe_load_sha = nil
 
@@ -332,7 +332,7 @@ local function is_fann_valid(rule, prefix, ann)
     local n = rspamd_config:get_symbols_count() +
         meta_functions.rspamd_count_metatokens()
 
-    if torch then
+    if use_torch then
       return true
     else
       if n ~= ann:get_inputs() then
@@ -375,7 +375,7 @@ local function fann_scores_filter(task)
       fun.each(function(e) table.insert(fann_data, e) end, mt)
 
       local score
-      if torch then
+      if use_torch then
         local out = fanns[id].fann:forward(torch.Tensor(fann_data))
         score = out[1]
       else
@@ -387,10 +387,16 @@ local function fann_scores_filter(task)
       rspamd_logger.infox(task, 'fann score: %s', symscore)
 
       if score > 0 then
-        local result = rspamd_util.normalize_prob(score / 2.0, 0)
+        local result = score
+        if not use_torch then
+          result = rspamd_util.normalize_prob(score / 2.0, 0)
+        end
         task:insert_result(rule.symbol_spam, result, symscore, id)
       else
-        local result = rspamd_util.normalize_prob((-score) / 2.0, 0)
+        local result = -(score)
+        if not use_torch then
+          result = rspamd_util.normalize_prob(-(score) / 2.0, 0)
+        end
         task:insert_result(rule.symbol_ham, result, symscore, id)
       end
     end
@@ -398,7 +404,7 @@ local function fann_scores_filter(task)
 end
 
 local function create_fann(n, nlayers)
-  if torch then
+  if use_torch then
     -- We ignore number of layers so far when using torch
     local ann = nn.Sequential()
     local nhidden = math.floor((n + 1) / 2)
@@ -464,7 +470,7 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
     rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
     return
   else
-    if torch then
+    if use_torch then
       ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject()
     else
       ann = rspamd_fann.load_data(ann_data)
@@ -647,7 +653,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
       rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
         prefix, train_mse)
       local ann_data
-      if torch then
+      if use_torch then
         local f = torch.MemoryFile()
         f:writeObject(fanns[elt].fann_train)
         ann_data = rspamd_util.zstd_compress(f:storage():string())
@@ -762,7 +768,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
           {redis_locked_invalidate_sha, 1, prefix}
         )
       else
-        if torch then
+        if use_torch then
           -- For torch we do not need to mix samples as they would be flushed
           local dataset = {}
           fun.each(function(s)
@@ -996,8 +1002,12 @@ local function check_fanns(rule, _, ev_base)
           elseif _data and type(_data) == 'table' then
             load_or_invalidate_fann(rule, _data, elt, ev_base)
           else
-            rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis %s for prefix: %s',
-              type(_data), elt)
+            if type(_data) == 'number' then
+              -- no new version
+            else
+              rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s',
+                type(_data), elt)
+            end
           end
         end
 
@@ -1161,7 +1171,7 @@ else
   for _,rule in pairs(settings.rules) do
     rspamd_config:add_on_load(function(cfg, ev_base, worker)
       load_scripts(cfg, ev_base, function(_, _)
-          check_fanns(rule, cfg, ev_base)
+          return check_fanns(rule, cfg, ev_base)
       end)
 
       if worker:get_name() == 'controller' and worker:get_index() == 0 then