]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Rework fann_redis to use redis scripts framework
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Dec 2017 18:03:52 +0000 (18:03 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Dec 2017 18:03:52 +0000 (18:03 +0000)
src/plugins/lua/fann_redis.lua

index 2b9b06f28dbbdff2e0161e9413486d0ea01b4f04..3120d8b181db07ce73f2d26bc79c998379ec65a2 100644 (file)
@@ -115,7 +115,7 @@ local redis_lua_script_can_train = [[
 
   return tostring(0)
 ]]
-local redis_can_train_sha = nil
+local redis_can_train_id = nil
 
 -- Lua script to load ANN from redis
 -- Uses the following keys
@@ -132,7 +132,7 @@ local redis_lua_script_maybe_load = [[
 
   return tonumber(ret) or 0
 ]]
-local redis_maybe_load_sha = nil
+local redis_maybe_load_id = nil
 
 -- Lua script to invalidate ANN from redis
 -- Uses the following keys
@@ -149,7 +149,7 @@ local redis_lua_script_maybe_invalidate = [[
   redis.call('DEL', KEYS[1] .. '_hostname')
   return 1
 ]]
-local redis_maybe_invalidate_sha = nil
+local redis_maybe_invalidate_id = nil
 
 -- Lua script to invalidate ANN from redis
 -- Uses the following keys
@@ -163,7 +163,7 @@ local redis_lua_script_locked_invalidate = [[
   redis.call('DEL', KEYS[1] .. '_hostname')
   return 1
 ]]
-local redis_locked_invalidate_sha = nil
+local redis_locked_invalidate_id = nil
 
 -- Lua script to invalidate ANN from redis
 -- Uses the following keys
@@ -182,7 +182,7 @@ local redis_lua_script_maybe_lock = [[
   redis.call('SET', KEYS[1] .. '_hostname', KEYS[4])
   return 1
 ]]
-local redis_maybe_lock_sha = nil
+local redis_maybe_lock_id = nil
 
 -- Lua script to save and unlock ANN in redis
 -- Uses the following keys
@@ -200,119 +200,23 @@ local redis_lua_script_save_unlock = [[
   redis.call('EXPIRE', KEYS[1] .. '_version', KEYS[3])
   return 1
 ]]
-local redis_save_unlock_sha = nil
+local redis_save_unlock_id = nil
 
 local redis_params
 
-local function load_scripts(cfg, ev_base, on_load_cb)
-  local function can_train_sha_cb(err, data)
-    if err or not data or type(data) ~= 'string' then
-      rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err)
-    else
-      redis_can_train_sha = tostring(data)
-    end
-  end
-  rspamd_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    redis_params,
-    nil,
-    true, -- is write
-    can_train_sha_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', redis_lua_script_can_train} -- arguments
-  )
-
-  local function maybe_load_sha_cb(err, data)
-    if err or not data or type(data) ~= 'string' then
-      rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err)
-    else
-      redis_maybe_load_sha = tostring(data)
-
-      if on_load_cb then
-        rspamd_config:add_periodic(ev_base, 0.0,
-          function(_cfg, _ev_base)
-            return on_load_cb(_cfg, _ev_base)
-          end)
-      end
-    end
-  end
-  rspamd_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    redis_params,
-    nil,
-    true, -- is write
-    maybe_load_sha_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', redis_lua_script_maybe_load} -- arguments
-  )
-
-  local function maybe_invalidate_sha_cb(err, data)
-    if err or not data or type(data) ~= 'string' then
-      rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err)
-    else
-      redis_maybe_invalidate_sha = tostring(data)
-    end
-  end
-  rspamd_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    redis_params,
-    nil,
-    true, -- is write
-    maybe_invalidate_sha_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', redis_lua_script_maybe_invalidate} -- arguments
-  )
-
-  local function locked_invalidate_sha_cb(err, data)
-    if err or not data or type(data) ~= 'string' then
-      rspamd_logger.errx(cfg, 'cannot save redis locked invalidate script: %s', err)
-    else
-      redis_locked_invalidate_sha = tostring(data)
-    end
-  end
-  rspamd_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    redis_params,
-    nil,
-    true, -- is write
-    locked_invalidate_sha_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', redis_lua_script_locked_invalidate} -- arguments
-  )
-
-  local function maybe_lock_sha_cb(err, data)
-    if err or not data or type(data) ~= 'string' then
-      rspamd_logger.errx(cfg, 'cannot save redis lock script: %s', err)
-    else
-      redis_maybe_lock_sha = tostring(data)
-    end
-  end
-  rspamd_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    redis_params,
-    nil,
-    true, -- is write
-    maybe_lock_sha_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', redis_lua_script_maybe_lock} -- arguments
-  )
-
-  local function save_unlock_sha_cb(err, data)
-    if err or not data or type(data) ~= 'string' then
-      rspamd_logger.errx(cfg, 'cannot save redis save script: %s', err)
-    else
-      redis_save_unlock_sha = tostring(data)
-    end
-  end
-  rspamd_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    redis_params,
-    nil,
-    true, -- is write
-    save_unlock_sha_cb, --callback
-    'SCRIPT', -- command
-    {'LOAD', redis_lua_script_save_unlock} -- arguments
-  )
+local function load_scripts(params)
+  redis_can_train_id = rspamd_redis.add_redis_script(redis_lua_script_can_train,
+    params)
+  redis_maybe_load_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_load,
+    params)
+  redis_maybe_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+    params)
+  redis_locked_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_locked_invalidate,
+    params)
+  redis_maybe_lock_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_lock,
+    params)
+  redis_save_unlock_id = rspamd_redis.add_redis_script(redis_lua_script_save_unlock,
+    params)
 end
 
 local function gen_fann_prefix(rule, id)
@@ -490,9 +394,6 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
     local function redis_invalidate_cb(_err, _data)
       if _err then
         rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
-        if string.match(_err, 'NOSCRIPT') then
-          load_scripts(rspamd_config, ev_base, nil)
-        end
       elseif type(_data) == 'string' then
         rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
         fanns[id].version = 0
@@ -500,15 +401,10 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
     end
     -- Invalidate ANN
     rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix)
-    rspamd_redis.redis_make_request_taskless(ev_base,
-      rspamd_config,
-      rule.redis,
-      nil,
-      true, -- is write
-      redis_invalidate_cb, --callback
-      'EVALSHA', -- command
-      {redis_maybe_invalidate_sha, 1, prefix}
-    )
+    rspamd_redis.exec_redis_script(redis_maybe_invalidate_id,
+      {ev_base = ev_base, is_write = true},
+      redis_invalidate_cb,
+      {prefix})
   end
 end
 
@@ -589,9 +485,6 @@ local function fann_train_callback(rule, task, score, required_score, id)
       else
         if err then
           rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err)
-          if string.match(err, 'NOSCRIPT') then
-            load_scripts(rspamd_config, task:get_ev_base(), nil)
-          end
         elseif tonumber(data) < 0 then
           rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s",
             fname, k, -tonumber(data))
@@ -599,15 +492,10 @@ local function fann_train_callback(rule, task, score, required_score, id)
       end
     end
 
-    rspamd_redis.rspamd_redis_make_request(task,
-      rule.redis,
-      nil,
-      true, -- is write
-      can_train_cb, --callback
-      'EVALSHA', -- command
-      {redis_can_train_sha, '4', gen_fann_prefix(rule, nil),
-        suffix, k, tostring(train_opts.max_trains)} -- arguments
-    )
+    rspamd_redis.exec_redis_script(redis_can_train_id,
+      {task = task, is_write = true},
+      can_train_cb,
+      {gen_fann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)})
   end
 end
 
@@ -637,9 +525,6 @@ local function train_fann(rule, _, ev_base, elt, worker)
         'DEL', -- command
         {prefix .. '_locked'}
       )
-      if string.match(err, 'NOSCRIPT') then
-        load_scripts(rspamd_config, ev_base, nil)
-      end
     else
       rspamd_logger.infox(rspamd_config, 'saved ANN %s, key: %s_data', elt, prefix)
     end
@@ -674,15 +559,10 @@ local function train_fann(rule, _, ev_base, elt, worker)
       fanns[elt].version = fanns[elt].version + 1
       fanns[elt].fann = fanns[elt].fann_train
       fanns[elt].fann_train = nil
-      rspamd_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        true, -- is write
-        redis_save_cb, --callback
-        'EVALSHA', -- command
-        {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)}
-      )
+      rspamd_redis.exec_redis_script(redis_save_unlock_id,
+        {ev_base = ev_base, is_write = true},
+        redis_save_cb,
+        {prefix, tostring(ann_data), tostring(rule.ann_expire)})
     end
   end
 
@@ -711,15 +591,10 @@ local function train_fann(rule, _, ev_base, elt, worker)
       fanns[elt].version = fanns[elt].version + 1
       fanns[elt].fann = fanns[elt].fann_train
       fanns[elt].fann_train = nil
-      rspamd_redis.redis_make_request_taskless(ev_base,
-        rspamd_config,
-        rule.redis,
-        nil,
-        true, -- is write
-        redis_save_cb, --callback
-        'EVALSHA', -- command
-        {redis_save_unlock_sha, '3', prefix, tostring(ann_data), tostring(rule.ann_expire)}
-      )
+      rspamd_redis.exec_redis_script(redis_save_unlock_id,
+        {ev_base = ev_base, is_write = true},
+        redis_save_cb,
+        {prefix, tostring(ann_data), tostring(rule.ann_expire)})
     end
   end
 
@@ -768,15 +643,10 @@ local function train_fann(rule, _, ev_base, elt, worker)
         end
         -- Invalidate ANN
         rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix)
-        rspamd_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          true, -- is write
-          redis_invalidate_cb, --callback
-          'EVALSHA', -- command
-          {redis_locked_invalidate_sha, 1, prefix}
-        )
+        rspamd_redis.exec_redis_script(redis_locked_invalidate_id,
+          {ev_base = ev_base, is_write = true},
+          redis_invalidate_cb,
+          {prefix})
       else
         if use_torch then
           -- For torch we do not need to mix samples as they would be flushed
@@ -874,9 +744,6 @@ local function train_fann(rule, _, ev_base, elt, worker)
     if err then
       rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
         prefix, err)
-      if string.match(err, 'NOSCRIPT') then
-        load_scripts(rspamd_config, ev_base, nil)
-      end
     elseif type(data) == 'number' then
       -- Can train ANN
       rspamd_redis.redis_make_request_taskless(ev_base,
@@ -926,16 +793,10 @@ local function train_fann(rule, _, ev_base, elt, worker)
     rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix)
     return
   end
-  rspamd_redis.redis_make_request_taskless(ev_base,
-    rspamd_config,
-    rule.redis,
-    nil,
-    true, -- is write
-    redis_lock_cb, --callback
-    'EVALSHA', -- command
-    {redis_maybe_lock_sha, '4', prefix, tostring(os.time()),
-      tostring(rule.lock_expire), rspamd_util.get_hostname()}
-  )
+  rspamd_redis.exec_redis_script(redis_maybe_lock_id,
+    {ev_base = ev_base, is_write = true},
+    redis_lock_cb,
+    {prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()})
 end
 
 local function maybe_train_fanns(rule, cfg, ev_base, worker)
@@ -979,10 +840,6 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
     end
   end
 
-  if not redis_maybe_load_sha then
-    -- Plan new event early
-    return 1.0
-  end
   -- First we need to get all fanns stored in our Redis
   rspamd_redis.redis_make_request_taskless(ev_base,
     rspamd_config,
@@ -1009,9 +866,6 @@ local function check_fanns(rule, _, ev_base)
           if _err then
             rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s',
               elt, _err)
-            if string.match(_err, 'NOSCRIPT') then
-              load_scripts(rspamd_config, ev_base, nil)
-            end
           elseif _data and type(_data) == 'table' then
             load_or_invalidate_fann(rule, _data, elt, ev_base)
           else
@@ -1028,24 +882,15 @@ local function check_fanns(rule, _, ev_base)
             local_ver = fanns[elt].version
           end
         end
-        rspamd_redis.redis_make_request_taskless(ev_base,
-          rspamd_config,
-          rule.redis,
-          nil,
-          false, -- is write
-          redis_update_cb, --callback
-          'EVALSHA', -- command
-          {redis_maybe_load_sha, 2, gen_fann_prefix(rule, elt), tostring(local_ver)}
-        )
+        rspamd_redis.exec_redis_script(redis_maybe_load_id,
+          {ev_base = ev_base, is_write = false},
+          redis_update_cb,
+          {gen_fann_prefix(rule, elt), tostring(local_ver)})
       end,
       data)
     end
   end
 
-  if not redis_maybe_load_sha then
-    -- Plan new event early
-    return 1.0
-  end
   -- First we need to get all fanns stored in our Redis
   rspamd_redis.redis_make_request_taskless(ev_base,
     rspamd_config,
@@ -1187,10 +1032,9 @@ else
 
   -- Add training scripts
   for _,rule in pairs(settings.rules) do
+    load_scripts(rule.redis)
     rspamd_config:add_on_load(function(cfg, ev_base, worker)
-      load_scripts(cfg, ev_base, function(_, _)
-          return check_fanns(rule, cfg, ev_base)
-      end)
+      check_fanns(rule, cfg, ev_base)
 
       if worker:get_name() == 'controller' and worker:get_index() == 0 then
         -- We also want to train neural nets when they have enough data