]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add framework to manage Redis scripts
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Dec 2017 15:40:37 +0000 (15:40 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 16 Dec 2017 15:40:37 +0000 (15:40 +0000)
lualib/lua_redis.lua

index 25f0078baee9c0248848881e99971d5ed9a970fd..d88d174894e2da72130076ccedce815c5ee37e79 100644 (file)
@@ -606,4 +606,161 @@ end
 exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
 exports.redis_make_request_taskless = redis_make_request_taskless
 
+local redis_scripts = {
+}
+
+local function load_redis_script(script, cfg, ev_base, _)
+  local function merge_tables(t1, t2)
+    for k,v in pairs(t2) do t1[k] = v end
+  end
+
+  local function set_loaded()
+    if script.sha then
+      script.loaded = true
+    end
+
+    local wait_table = {}
+    for _,s in ipairs(script.waitq) do
+      table.insert(wait_table, s)
+    end
+
+    script.waitq = {}
+
+    for _,s in ipairs(wait_table) do
+      s(script)
+    end
+  end
+  local servers = {}
+
+  if script.redis_params.read_servers then
+    merge_tables(servers, script.redis_params.read_servers:all_upstreams())
+  end
+  if script.redis_params.write_servers then
+    merge_tables(servers, script.redis_params.write_servers:all_upstreams())
+  end
+
+  -- Call load script on each server, set loaded flag
+  script.in_flight = #servers
+  for _,s in ipairs(servers) do
+    local function script_cb(err, data)
+      if err then
+        s:fail()
+      else
+        s:ok()
+        script.sha = data -- We assume that sha is the same on all servers
+      end
+      script.in_flight = script.in_flight - 1
+
+      if script.in_flight == 0 then
+        set_loaded(script)
+      end
+    end
+
+    local rspamd_redis = require "rspamd_redis"
+
+    local options = {
+      ev_base = ev_base,
+      config = cfg,
+      callback = script_cb,
+      host = s:get_addr(),
+      timeout = script.redis_params['timeout'],
+      cmd = 'SCRIPT',
+      args = {'LOAD', script.script}
+    }
+
+    if script.redis_params['password'] then
+      options['password'] = script.redis_params['password']
+    end
+
+    if script.redis_params['db'] then
+      options['dbname'] = script.redis_params['db']
+    end
+
+    local ret = rspamd_redis.make_request(options)
+    if not ret then
+      logger.errx('cannot execute redis request to load script')
+      script.in_flight = script.in_flight - 1
+    end
+  end
+
+  if script.in_flight == 0 then
+    set_loaded(script)
+  end
+end
+
+local function add_redis_script(script, redis_params)
+  local new_script = {
+    loaded = false,
+    redis_params = redis_params,
+    script = script,
+    waitq = {}, -- callbacks pending for script being loaded
+    id = #redis_scripts + 1
+  }
+
+  -- Register on load function
+  rspamd_config:add_on_load(function(cfg, ev_base, worker)
+    load_redis_script(new_script, cfg, ev_base, worker)
+  end)
+
+  table.insert(redis_scripts, new_script)
+
+  return #redis_scripts
+end
+exports.add_redis_script = add_redis_script
+
+local function exec_redis_script(id, params, callback, args)
+  if not redis_scripts[id] then
+    return false
+  end
+
+  local script = redis_scripts[id]
+
+  local function do_call()
+    local function redis_cb(err, data)
+      if not err then
+        callback(err, data)
+      elseif err == 'NOSCRIPT' then
+        -- Schedule restart
+        table.insert(script.waitq, do_call)
+        if script.in_flight ~= 0 then
+          -- Reload scripts if this has not been initiated yet
+          if params.task then
+            load_redis_script(script, rspamd_config,
+              params.task:get_ev_base(), nil)
+          else
+            load_redis_script(script, rspamd_config,
+              params.ev_base, nil)
+          end
+        end
+      else
+        callback(err, data)
+      end
+    end
+
+    if params.task then
+      if not rspamd_redis_make_request(params.task, script.redis_params,
+        params.key, params.is_write, redis_cb, 'EVALSHA', args) then
+        callback('Cannot make redis request', nil)
+      end
+    else
+      if not redis_make_request_taskless(params.ev_base, rspamd_config,
+        script.redis_params,
+        params.key, params.is_write, redis_cb, 'EVALSHA', args) then
+        callback('Cannot make redis request', nil)
+      end
+    end
+  end
+
+  if not script.loaded then
+    do_call()
+  else
+    -- Delayed until scripts are loaded
+    table.insert(script.waitq, do_call)
+  end
+
+  return true
+end
+
+exports.exec_redis_script = exec_redis_script
+
 return exports