]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Allow multiple fann rules
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 30 Jul 2017 08:56:52 +0000 (09:56 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 30 Jul 2017 08:56:52 +0000 (09:56 +0100)
src/plugins/lua/fann_redis.lua

index 378000a24c8d44f5db383eaeba0af689b02e0f1c..f2b6c44f69155bea48afcf6dee21473e7a279090 100644 (file)
@@ -24,17 +24,36 @@ end
 local rspamd_logger = require "rspamd_logger"
 local rspamd_fann = require "rspamd_fann"
 local rspamd_util = require "rspamd_util"
-local fann_symbol_spam = 'FANNR_SPAM'
-local fann_symbol_ham = 'FANNR_HAM'
 local rspamd_redis = require "lua_redis"
 local fun = require "fun"
 local meta_functions = require "meta_functions"
+
 -- Module vars
+local default_options = {
+  train = {
+    max_trains = 1000,
+    max_epoch = 1000,
+    max_usages = 10,
+    use_settings = false,
+    watch_interval = 60.0,
+    mse = 0.001,
+    autotrain = true,
+  },
+  nlayers = 4,
+  lock_expire = 600,
+  learning_spawned = false,
+  ann_expire = 60 * 60 * 24 * 2, -- 2 days
+  symbol_spam = 'FANNR_SPAM',
+  symbol_ham = 'FANNR_HAM',
+}
+
+local settings = {
+  rules = {
+  }
+}
+
 -- ANNs indexed by settings id
 local fanns = {
-  ['0'] = {
-    version = 0,
-  }
 }
 
 local opts = rspamd_config:get_all_opt("fann_redis")
@@ -162,19 +181,6 @@ local redis_lua_script_save_unlock = [[
 local redis_save_unlock_sha = nil
 
 local redis_params
-redis_params = rspamd_parse_redis_server('fann_redis')
-
-local fann_prefix = 'RFANN'
-local max_trains = 1000
-local max_epoch = 1000
-local max_usages = 10
-local use_settings = false
-local watch_interval = 60.0
-local mse = 0.0001
-local nlayers = 4
-local lock_expire = 600
-local learning_spawned = false
-local ann_expire = 60 * 60 * 24 * 2 -- 2 days
 
 local function load_scripts(cfg, ev_base, on_load_cb)
   local function can_train_sha_cb(err, data)
@@ -287,15 +293,16 @@ local function load_scripts(cfg, ev_base, on_load_cb)
   )
 end
 
-local function gen_fann_prefix(id)
+local function gen_fann_prefix(rule, id)
   if id then
-    return fann_prefix .. rspamd_config:get_symbols_cksum():hex() .. id,id
+    return rule.prefix .. rspamd_config:get_symbols_cksum():hex() .. id,
+      rule.prefix .. id
   else
-    return fann_prefix .. rspamd_config:get_symbols_cksum():hex(), nil
+    return rule.prefix .. rspamd_config:get_symbols_cksum():hex(), nil
   end
 end
 
-local function is_fann_valid(prefix, ann)
+local function is_fann_valid(rule, prefix, ann)
   if ann then
     local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens()
 
@@ -306,7 +313,7 @@ local function is_fann_valid(prefix, ann)
     end
     local layers = ann:get_layers()
 
-    if not layers or #layers ~= nlayers then
+    if not layers or #layers ~= rule.nlayers then
       rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
         prefix, #layers)
       return false
@@ -317,65 +324,69 @@ local function is_fann_valid(prefix, ann)
 end
 
 local function fann_scores_filter(task)
-  local id = '0'
-  if use_settings then
-   local sid = task:get_settings_id()
-   if sid then
-    id = tostring(sid)
-   end
-  end
+  for _,rule in settings.rules do
+    local id = rule.prefix .. '0'
+    if rule.use_settings then
+     local sid = task:get_settings_id()
+     if sid then
+      id = rule.prefix .. tostring(sid)
+     end
+    end
 
-  if fanns[id].fann then
-    local fann_data = task:get_symbols_tokens()
-    local mt = meta_functions.rspamd_gen_metatokens(task)
-    -- Add filtered meta tokens
-    fun.each(function(e) table.insert(fann_data, e) end, mt)
-
-    local out = fanns[id].fann:test(fann_data)
-    local symscore = string.format('%.3f', out[1])
-    rspamd_logger.infox(task, 'fann score: %s', symscore)
-
-    if out[1] > 0 then
-      local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
-      task:insert_result(fann_symbol_spam, result, symscore, id)
-    else
-      local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
-      task:insert_result(fann_symbol_ham, result, symscore, id)
+    if fanns[id].fann then
+      local fann_data = task:get_symbols_tokens()
+      local mt = meta_functions.rspamd_gen_metatokens(task)
+      -- Add filtered meta tokens
+      fun.each(function(e) table.insert(fann_data, e) end, mt)
+
+      local out = fanns[id].fann:test(fann_data)
+      local symscore = string.format('%.3f', out[1])
+      rspamd_logger.infox(task, 'fann score: %s', symscore)
+
+      if out[1] > 0 then
+        local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
+        task:insert_result(rule.symbol_spam, result, symscore, id)
+      else
+        local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+        task:insert_result(rule.symbol_ham, result, symscore, id)
+      end
     end
   end
 end
 
-local function create_train_fann(n, id)
-  id = tostring(id)
-  local prefix = gen_fann_prefix(id)
+local function create_train_fann(rule, n, id)
+  id = rule.prefix .. tostring(id)
+  local prefix = gen_fann_prefix(rule, id)
   if not fanns[id] then
     fanns[id] = {}
   end
-
+  -- Fix that for flexibe layers number
   if fanns[id].fann then
-    if n ~= fanns[id].fann:get_inputs() or
+    if n ~= fanns[id].fann:get_inputs() or --
       (fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then
-      rspamd_logger.infox(rspamd_config, 'recreate ANN %s as it has a wrong number of inputs, version %s', prefix,
+      rspamd_logger.infox(rspamd_config,
+        'recreate ANN %s as it has a wrong number of inputs, version %s',
+        prefix,
         fanns[id].version)
-      fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+      fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
       fanns[id].fann = nil
-    elseif fanns[id].version % max_usages == 0 then
+    elseif fanns[id].version % rule.max_usages == 0 then
       -- Forget last fann
       rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
         fanns[id].version)
-      fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+      fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
     else
       fanns[id].fann_train = fanns[id].fann
     end
   else
-    fanns[id].fann_train = rspamd_fann.create(nlayers, n, n / 2, n / 4, 1)
+    fanns[id].fann_train = rspamd_fann.create(rule.nlayers, n, n / 2, n / 4, 1)
     fanns[id].version = 0
   end
 end
 
-local function load_or_invalidate_fann(data, id, ev_base)
+local function load_or_invalidate_fann(rule, data, id, ev_base)
   local ver = data[2]
-  local prefix = gen_fann_prefix(id)
+  local prefix = gen_fann_prefix(rule, id)
 
   if not ver or not tonumber(ver) then
     rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix)
@@ -392,7 +403,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
     ann = rspamd_fann.load_data(ann_data)
   end
 
-  if is_fann_valid(prefix, ann) then
+  if is_fann_valid(rule, prefix, ann) then
     fanns[id].fann = ann
     rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
       prefix, ver)
@@ -413,7 +424,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
     rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
     rspamd_redis.redis_make_request_taskless(ev_base,
       rspamd_config,
-      redis_params,
+      rule.redis,
       nil,
       true, -- is write
       redis_invalidate_cb, --callback
@@ -423,9 +434,9 @@ local function load_or_invalidate_fann(data, id, ev_base)
   end
 end
 
-local function fann_train_callback(task, score, required_score, id)
-  local train_opts = opts['train']
-  local fname,suffix = gen_fann_prefix(id)
+local function fann_train_callback(rule, task, score, required_score, id)
+  local train_opts = rule['train']
+  local fname,suffix = gen_fann_prefix(rule, id)
 
   local learn_spam, learn_ham
 
@@ -459,7 +470,7 @@ local function fann_train_callback(task, score, required_score, id)
         local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))
 
         rspamd_redis.redis_make_request(task,
-          redis_params,
+          rule.redis,
           nil,
           true, -- is write
           learn_vec_cb, --callback
@@ -477,22 +488,22 @@ local function fann_train_callback(task, score, required_score, id)
     end
 
     rspamd_redis.rspamd_redis_make_request(task,
-      redis_params,
+      rule.redis,
       nil,
       true, -- is write
       can_train_cb, --callback
       'EVALSHA', -- command
-      {redis_can_train_sha, '4', gen_fann_prefix(nil),
+      {redis_can_train_sha, '4', gen_fann_prefix(rule, nil),
         suffix, k, tostring(max_trains)} -- arguments
     )
   end
 end
 
-local function train_fann(_, ev_base, elt)
+local function train_fann(rule, _, ev_base, elt)
   local spam_elts = {}
   local ham_elts = {}
   elt = tostring(elt)
-  local prefix = gen_fann_prefix(elt)
+  local prefix = gen_fann_prefix(rule, elt)
 
   local function redis_unlock_cb(err)
     if err then
@@ -507,7 +518,7 @@ local function train_fann(_, ev_base, elt)
         prefix, err)
       rspamd_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
-        redis_params,
+        rule.redis,
         nil,
         false, -- is write
         redis_unlock_cb, --callback
@@ -521,13 +532,13 @@ local function train_fann(_, ev_base, elt)
   end
 
   local function ann_trained(errcode, errmsg, train_mse)
-    learning_spawned = false
+    rule.learning_spawned = false
     if errcode ~= 0 then
       rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s',
         prefix, errmsg)
       rspamd_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
-        redis_params,
+        rule.redis,
         nil,
         true, -- is write
         redis_unlock_cb, --callback
@@ -543,7 +554,7 @@ local function train_fann(_, ev_base, elt)
       fanns[elt].fann_train = nil
       rspamd_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
-        redis_params,
+        rule.redis,
         nil,
         true, -- is write
         redis_save_cb, --callback
@@ -559,7 +570,7 @@ local function train_fann(_, ev_base, elt)
         prefix, err)
       rspamd_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
-        redis_params,
+        rule.redis,
         nil,
         true, -- is write
         redis_unlock_cb, --callback
@@ -593,7 +604,7 @@ local function train_fann(_, ev_base, elt)
       if not fanns[elt] or not fanns[elt].fann_train
         or n ~= fanns[elt].fann_train:get_inputs() then
         -- Create fann if it does not exist
-        create_train_fann(n, elt)
+        create_train_fann(rule, n, elt)
       end
 
       if #inputs < max_trains / 2 then
@@ -610,7 +621,7 @@ local function train_fann(_, ev_base, elt)
         rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
         rspamd_redis.redis_make_request_taskless(ev_base,
           rspamd_config,
-          redis_params,
+          rule.redis,
           nil,
           true, -- is write
           redis_invalidate_cb, --callback
@@ -618,10 +629,13 @@ local function train_fann(_, ev_base, elt)
           {redis_locked_invalidate_sha, 1, prefix}
         )
       else
-        learning_spawned = true
+        rule.learning_spawned = true
         rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
-        fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, ev_base,
-          {max_epochs = max_epoch, desired_mse = mse})
+        fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained,
+          ev_base, {
+            max_epochs = rule.train.max_epoch,
+            desired_mse = rule.train.mse
+          })
       end
     end
   end
@@ -632,7 +646,7 @@ local function train_fann(_, ev_base, elt)
         prefix, err)
       rspamd_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
-        redis_params,
+        rule.redis,
         nil,
         true, -- is write
         redis_unlock_cb, --callback
@@ -647,7 +661,7 @@ local function train_fann(_, ev_base, elt)
       end, data))
       rspamd_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
-        redis_params,
+        rule.redis,
         nil,
         false, -- is write
         redis_ham_cb, --callback
@@ -668,7 +682,7 @@ local function train_fann(_, ev_base, elt)
       -- Can train ANN
       rspamd_redis.redis_make_request_taskless(ev_base,
         rspamd_config,
-        redis_params,
+        rule.redis,
         nil,
         false, -- is write
         redis_spam_cb, --callback
@@ -687,10 +701,10 @@ local function train_fann(_, ev_base, elt)
                 prefix)
             end
           end
-          if learning_spawned then
+          if rule.learning_spawned then
             rspamd_redis.redis_make_request_taskless(ev_base,
               rspamd_config,
-              redis_params,
+              rule.redis,
               nil,
               true, -- is write
               redis_lock_extend_cb, --callback
@@ -709,45 +723,47 @@ local function train_fann(_, ev_base, elt)
       rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, locked by another process', prefix)
     end
   end
-  if learning_spawned then
+  if rule.learning_spawned then
     rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
     return
   end
   rspamd_redis.redis_make_request_taskless(ev_base,
     rspamd_config,
-    redis_params,
+    rule.redis,
     nil,
     true, -- is write
     redis_lock_cb, --callback
     'EVALSHA', -- command
     {redis_maybe_lock_sha, '4', prefix, tostring(os.time()),
-      tostring(lock_expire), rspamd_util.get_hostname()}
+      tostring(rule.lock_expire), rspamd_util.get_hostname()}
   )
 end
 
-local function maybe_train_fanns(cfg, ev_base)
+local function maybe_train_fanns(rule, cfg, ev_base)
   local function members_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
     elseif type(data) == 'table' then
       fun.each(function(elt)
         elt = tostring(elt)
-        local prefix = gen_fann_prefix(elt)
+        local prefix = gen_fann_prefix(rule, elt)
         local redis_len_cb = function(_err, _data)
           if _err then
-            rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', prefix, _err)
+            rspamd_logger.errx(rspamd_config,
+              'cannot get FANN trains %s from redis: %s', prefix, _err)
           elseif _data and type(_data) == 'number' or type(_data) == 'string' then
             if tonumber(_data) and tonumber(_data) >= max_trains then
-              rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)',
+              rspamd_logger.infox(rspamd_config,
+                'need to learn ANN %s after %s learn vectors (%s required)',
                 prefix, tonumber(_data), max_trains)
-              train_fann(cfg, ev_base, elt)
+              train_fann(rule, cfg, ev_base, elt)
             end
           end
         end
 
         rspamd_redis.redis_make_request_taskless(ev_base,
           rspamd_config,
-          redis_params,
+          rule.redis,
           nil,
           false, -- is write
           redis_len_cb, --callback
@@ -766,18 +782,18 @@ local function maybe_train_fanns(cfg, ev_base)
   -- First we need to get all fanns stored in our Redis
   rspamd_redis.redis_make_request_taskless(ev_base,
     rspamd_config,
-    redis_params,
+    rule.redis,
     nil,
     false, -- is write
     members_cb, --callback
     'SMEMBERS', -- command
-    {gen_fann_prefix(nil)} -- arguments
+    {gen_fann_prefix(rule, nil)} -- arguments
   )
 
-  return watch_interval
+  return rule.watch_interval
 end
 
-local function check_fanns(_, ev_base)
+local function check_fanns(rule, _, ev_base)
   local function members_cb(err, data)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
@@ -791,7 +807,7 @@ local function check_fanns(_, ev_base)
               load_scripts(rspamd_config, ev_base, nil)
             end
           elseif _data and type(_data) == 'table' then
-            load_or_invalidate_fann(_data, elt, ev_base)
+            load_or_invalidate_fann(rule, _data, elt, ev_base)
           end
         end
 
@@ -803,12 +819,12 @@ local function check_fanns(_, ev_base)
         end
         rspamd_redis.redis_make_request_taskless(ev_base,
           rspamd_config,
-          redis_params,
+          rule.redis,
           nil,
           false, -- is write
           redis_update_cb, --callback
           'EVALSHA', -- command
-          {redis_maybe_load_sha, 2, gen_fann_prefix(elt), tostring(local_ver)}
+          {redis_maybe_load_sha, 2, gen_fann_prefix(rule, elt), tostring(local_ver)}
         )
       end,
       data)
@@ -822,27 +838,32 @@ local function check_fanns(_, ev_base)
   -- First we need to get all fanns stored in our Redis
   rspamd_redis.redis_make_request_taskless(ev_base,
     rspamd_config,
-    redis_params,
+    rule.redis,
     nil,
     false, -- is write
     members_cb, --callback
     'SMEMBERS', -- command
-    {gen_fann_prefix(nil)} -- arguments
+    {gen_fann_prefix(rule, nil)} -- arguments
   )
 
-  return watch_interval
+  return rule.watch_interval
 end
 
 local function ann_push_vector(task)
   local scores = task:get_metric_score()
   local sid = task:get_settings_id()
-  if use_settings then
-    fann_train_callback(task, scores[1], scores[2], tostring(sid))
-  else
-    fann_train_callback(task, scores[1], scores[2], "1")
+
+  for _,rule in settings.rules do
+    if rule.use_settings then
+      fann_train_callback(rule, task, scores[1], scores[2], tostring(sid))
+    else
+      fann_train_callback(rule, task, scores[1], scores[2], "0")
+    end
   end
 end
 
+redis_params = rspamd_parse_redis_server('fann_redis')
+
 -- Initialization part
 if not (opts and type(opts) == 'table') or not redis_params then
   rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
@@ -854,72 +875,75 @@ if not rspamd_fann.is_enabled() then
     'module is eventually disabled')
   return
 else
-  use_settings = opts['use_settings']
-  if opts['spam_symbol'] then
-    fann_symbol_spam = opts['spam_symbol']
-  end
-  if opts['ham_symbol'] then
-    fann_symbol_ham = opts['ham_symbol']
-  end
-  if opts['prefix'] then
-    fann_prefix = opts['prefix']
-  end
-  if opts['lock_expire'] then
-    lock_expire = tonumber(opts['lock_expire'])
+  local rules = opts['rules']
+
+  if not rules then
+    -- Use legacy configuration
+    rules = {}
+    rules['RFANN'] = opts
   end
-  rspamd_config:set_metric_symbol({
-    name = fann_symbol_spam,
-    score = 3.0,
-    description = 'Neural network SPAM',
-    group = 'fann'
-  })
+
   local id = rspamd_config:register_symbol({
-    name = fann_symbol_spam,
+    name = 'FANN_CHECK',
     type = 'postfilter,nostat',
     priority = 6,
     callback = fann_scores_filter
   })
-  rspamd_config:set_metric_symbol({
-    name = fann_symbol_ham,
-    score = -2.0,
-    description = 'Neural network HAM',
-    group = 'fann'
-  })
-  rspamd_config:register_symbol({
-    name = fann_symbol_ham,
-    type = 'virtual,nostat',
-    parent = id
-  })
-  if opts['train'] then
-    if opts['train']['max_train'] then
-      max_trains = opts['train']['max_train']
-    end
-    if opts['train']['max_epoch'] then
-      max_epoch = opts['train']['max_epoch']
-    end
-    if opts['train']['max_usages'] then
-      max_usages = opts['train']['max_usages']
+
+  for k,r in rules do
+    rules[k] = default_options
+    rules[k]['redis'] = redis_params
+    local cur = rules[k]
+    -- Override defaults
+    for sk,v in r do
+      cur[sk] = v
     end
-    if opts['train']['mse'] then
-      mse = opts['train']['mse']
+    if not cur.prefix then
+      cur.prefix = k
     end
+    rspamd_config:set_metric_symbol({
+      name = cur.symbol_spam,
+      score = 3.0,
+      description = 'Neural network SPAM',
+      group = 'fann'
+    })
+
+    rspamd_config:set_metric_symbol({
+      name = cur.symbol_ham,
+      score = -2.0,
+      description = 'Neural network HAM',
+      group = 'fann'
+    })
     rspamd_config:register_symbol({
-      name = 'FANN_VECTOR_PUSH',
-      type = 'postfilter,nostat',
-      priority = 5,
-      callback = ann_push_vector
+      name = cur.symbol_ham,
+      type = 'virtual,nostat',
+      parent = id
     })
   end
+
+  rspamd_config:register_symbol({
+    name = 'FANN_VECTOR_PUSH',
+    type = 'postfilter,nostat',
+    priority = 5,
+    callback = ann_push_vector
+  })
+
+  settings.rules = rules
+
   -- Add training scripts
-  rspamd_config:add_on_load(function(cfg, ev_base, worker)
-    load_scripts(cfg, ev_base, check_fanns)
-
-    if worker:get_name() == 'normal' then
-      -- We also want to train neural nets when they have enough data
-      rspamd_config:add_periodic(ev_base, 0.0,
-        function(_cfg, _ev_base)
-          return maybe_train_fanns(_cfg, _ev_base)
-        end)
-    end
-  end)
+  for k,rule in settings.rules do
+    rspamd_config:add_on_load(function(cfg, ev_base, worker)
+      load_scripts(cfg, ev_base, function(cfg, ev_base)
+          check_fanns(rule, cfg, ev_base)
+      end)
+
+      if worker:get_name() == 'normal' then
+        -- We also want to train neural nets when they have enough data
+        rspamd_config:add_periodic(ev_base, 0.0,
+          function(_cfg, _ev_base)
+            return maybe_train_fanns(rule, _cfg, _ev_base)
+          end)
+      end
+    end)
+  end
 end