]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Multiple issues in fann_redis
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 15 Nov 2016 14:13:00 +0000 (14:13 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 15 Nov 2016 14:13:00 +0000 (14:13 +0000)
src/plugins/lua/fann_redis.lua

index c55f376def746e5d74b785b3348b5e0e62e48fb8..361d82303efc4f72bad6dde8c44176e4c73d8f30 100644 (file)
@@ -28,7 +28,7 @@ local ucl = require "ucl"
 local module_log_id = 0x200
 -- Module vars
 -- ANNs indexed by settings id
-local data = {
+local fanns = {
   ['0'] = {
     version = 0,
   }
@@ -80,7 +80,7 @@ local redis_lua_script_maybe_load = [[
   local ver = 0
   local ret = redis.call('GET', KEYS[1] .. '_version')
   if ret then ver = tonumber(ret) end
-  if ver > KEYS[2] then return redis.call('GET', KEYS[1] .. '_ann') end
+  if ver > tonumber(KEYS[2]) then return redis.call('GET', KEYS[1] .. '_ann') end
 
   return false
 ]]
@@ -135,6 +135,7 @@ local max_epoch = 100
 local use_settings = false
 local watch_interval = 60.0
 local mse = 0.0001
+local nlayers = 4
 
 local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
   if not ev_base or not redis_params or not callback or not command then
@@ -222,7 +223,7 @@ local function is_fann_valid(ann)
     end
     local layers = ann:get_layers()
 
-    if not layers or #layers ~= 5 then
+    if not layers or #layers ~= nlayers then
       rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s',
         #layers)
       return false
@@ -241,7 +242,7 @@ local function fann_scores_filter(task)
    end
   end
 
-  if data[id].fann then
+  if fanns[id].fann then
     local symbols,scores = task:get_symbols_numeric()
     local fann_data = symbols_to_fann_vector(symbols, scores)
     local mt = rspamd_gen_metatokens(task)
@@ -250,7 +251,7 @@ local function fann_scores_filter(task)
       table.insert(fann_data, tok)
     end
 
-    local out = data[id].fann:test(fann_data)
+    local out = fanns[id].fann:test(fann_data)
     local symscore = string.format('%.3f', out[1])
     rspamd_logger.infox(task, 'fann score: %s', symscore)
 
@@ -265,8 +266,18 @@ local function fann_scores_filter(task)
 end
 
 local function create_train_fann(n, id)
-  data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1)
-  data[id].version = 0
+  id = tostring(id)
+  if not fanns[id] then
+    fanns[id] = {}
+  end
+
+  if fanns[id].fann then
+    fanns[id].fann_train = fanns[id].fann
+    fanns[id].fann = nil
+  else
+    fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+    fanns[id].version = 0
+  end
 end
 
 local function load_or_invalidate_fann(data, id, ev_base)
@@ -280,7 +291,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
   end
 
   if is_fann_valid(ann) then
-    data[id].fann = ann
+    fanns[id].fann = ann
   else
     local function redis_invalidate_cb(err, data)
       if err then
@@ -367,6 +378,7 @@ end
 local function train_fann(cfg, ev_base, elt)
   local spam_elts = {}
   local ham_elts = {}
+  elt = tostring(elt)
 
   local function redis_unlock_cb(err, data)
     if err then
@@ -398,7 +410,9 @@ local function train_fann(cfg, ev_base, elt)
       rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
         fann_prefix .. elt, train_mse)
       local ann_data = rspamd_util.zstd_compress(data[elt].fann:data())
-      data[elt].version = data[elt].version + 1
+      fanns[elt].version = fanns[elt].version + 1
+      fanns[elt].fann = fanns[elt].fann_train
+      fanns[elt].fann_train = nil
       redis_make_request(ev_base,
         rspamd_config,
         nil,
@@ -424,32 +438,32 @@ local function train_fann(cfg, ev_base, elt)
       )
     else
       -- Decompress and convert to numbers each training vector
-      ham_elts = map(function(i, tok)
-        local str = tostring(rspamd_util.zstd_decompress(tok))
-        return map(tonumber, rspamd_str_split(str, ';'))
+      ham_elts = map(function(tok)
+        local _,str = rspamd_util.zstd_decompress(tok)
+        return map(tonumber, rspamd_str_split(tostring(str), ';'))
       end, data)
 
       -- Now we need to join inputs and create the appropriate test vectors
       local inputs = {}
       local outputs = {}
 
-      each(function(i, sample)
+      each(function(sample)
         table.insert(inputs, totable(sample))
-        table.insert(outputs, 1.0)
+        table.insert(outputs, {1.0})
       end, spam_elts)
-      each(function(i, sample)
+      each(function(sample)
         table.insert(inputs, totable(sample))
-        table.insert(outputs, -1.0)
-      end, spam_elts)
+        table.insert(outputs, {-1.0})
+      end, ham_elts)
 
       -- Now we can train fann
       local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
-      if not data[elt].fann then
+      if not fanns[elt] or not fanns[elt].fann_train then
         -- Create fann if it does not exist
         create_train_fann(n, elt)
       end
 
-      data[elt].fann:train_threaded(inputs, outputs, ann_trained, ev_base,
+      fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
         {max_epochs = max_epoch, desired_mse = mse})
     end
   end
@@ -468,9 +482,9 @@ local function train_fann(cfg, ev_base, elt)
       )
     else
       -- Decompress and convert to numbers each training vector
-      spam_elts = map(function(i, tok)
-        local str = tostring(rspamd_util.zstd_decompress(tok))
-        return map(tonumber, rspamd_str_split(str, ';'))
+      spam_elts = map(function(tok)
+        local _,str = rspamd_util.zstd_decompress(tok)
+        return map(tonumber, rspamd_str_split(tostring(str), ';'))
       end, data)
       redis_make_request(ev_base,
         rspamd_config,
@@ -514,7 +528,8 @@ local function maybe_train_fanns(cfg, ev_base)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
     elseif type(data) == 'table' then
-      each(function(i, elt)
+      each(function(elt)
+        elt = tostring(elt)
         local redis_len_cb = function(err, data)
           if err then
             rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err)
@@ -527,9 +542,9 @@ local function maybe_train_fanns(cfg, ev_base)
 
         local local_ver = 0
         local numelt = tonumber(elt)
-        if data[numelt] then
-          if data[numelt].version then
-            local_ver = data[numelt].version
+        if fanns[numelt] then
+          if fanns[numelt].version then
+            local_ver = fanns[numelt].version
           end
         end
         redis_make_request(ev_base,
@@ -567,7 +582,8 @@ local function check_fanns(cfg, ev_base)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
     elseif type(data) == 'table' then
-      each(function(i, elt)
+      each(function(elt)
+        elt = tostring(elt)
         local redis_update_cb = function(err, data)
           if err then
             rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
@@ -578,9 +594,9 @@ local function check_fanns(cfg, ev_base)
 
         local local_ver = 0
         local numelt = tonumber(elt)
-        if data[numelt] then
-          if data[numelt].version then
-            local_ver = data[numelt].version
+        if fanns[numelt] then
+          if fanns[numelt].version then
+            local_ver = fanns[numelt].version
           end
         end
         redis_make_request(ev_base,
@@ -683,7 +699,7 @@ else
       end
     end)
     -- This is needed to pass extra tokens from worker to log_helper
-    rspamd_plugins["fann_score"] = {
+    rspamd_plugins["fann_redis"] = {
       log_callback = function(task)
         return totable(map(
           function(tok) return {module_log_id, tok} end,