From: Vsevolod Stakhov Date: Thu, 5 Apr 2018 19:22:20 +0000 (+0100) Subject: [Rework] Restore leaky bucket model in ratelimit plugin X-Git-Tag: 1.7.3~27 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=1f0732d5bcd916c8d005800eeef98c2257d3e822;p=thirdparty%2Frspamd.git [Rework] Restore leaky bucket model in ratelimit plugin --- diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index 8e5cab328d..54c40dc771 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -21,27 +21,124 @@ end -- A plugin that implements ratelimits using redis -local E, settings = {}, {} +local E = {} local N = 'ratelimit' +local redis_params -- Senders that are considered as bounce -local bounce_senders = {'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon'} +local settings = { + bounce_senders = { 'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon' }, -- Do not check ratelimits for these recipients -local whitelisted_rcpts = {'postmaster', 'mailer-daemon'} -local whitelisted_ip -local whitelisted_user -local max_rcpt = 5 -local redis_params -local ratelimit_symbol -local info_symbol --- Do not delay mail after 1 day -local use_ip_score = false -local rl_prefix = 'RL' -local ip_score_lower_bound = 10 -local ip_score_ham_multiplier = 1.1 -local ip_score_spam_divisor = 1.1 -local limits_hash - -local message_func = function(_, limit_type) + whitelisted_rcpts = { 'postmaster', 'mailer-daemon' }, + whitelisted_ip = {}, + whitelisted_user = {}, + prefix = 'RL', + ham_factor_rate = 1.01, + spam_factor_rate = 0.99, + ham_factor_burst = 1.02, + spam_factor_burst = 0.98, + max_rate_mult = 5, + max_bucket_mult = 10, + expire = 60 * 60 * 24 * 2, -- 2 days by default + limits = {}, + allow_local = false, +} + +-- Checks bucket, updating it if needed +-- KEYS[1] - prefix to update, e.g. RL__ +-- KEYS[2] - current time in milliseconds +-- KEYS[3] - bucket leak rate (messages per millisecond) +-- KEYS[4] - bucket burst +-- KEYS[5] - expire for a bucket +-- return 1 if message should be ratelimited and 0 if not +-- Redis keys used: +-- l - last hit +-- b - current burst +-- dr - current dynamic rate multiplier (*10000) +-- db - current dynamic burst multiplier (*10000) +local bucket_check_script = [[ + local last = redis.call('HGET', KEYS[1], 'l') + local now = tonumber(KEYS[2]) + if not last then + -- New bucket + redis.call('HSET', KEYS[1], 'l', KEYS[2]) + redis.call('HSET', KEYS[1], 'b', '0') + redis.call('HSET', KEYS[1], 'dr', '10000') + redis.call('HSET', KEYS[1], 'db', '10000') + redis.call('EXPIRE', KEYS[1], KEYS[5]) + return 0 + end + + last = tonumber(last) + local burst = tonumber(redis.call('HGET', KEYS[1], 'b')) + -- Perform leak + if burst > 0 then + if last < tonumber(KEYS[2]) then + local rate = tonumber(KEYS[3]) + local dyn = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0 + rate = rate * dyn + redis.call('HINCRBYFLOAT', KEYS[1], 'b', -((now - last) * rate)) + end + local dyn = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0 + + if tonumber(burst) * tonumber(dyn) > tonumber(KEYS[4]) then + return 1 + end + else + redis.call('HSET', KEYS[1], 'b', '0') + end + + return 0 +]] +local bucket_check_id + + +-- Updates a bucket +-- KEYS[1] - prefix to update, e.g. RL__ +-- KEYS[2] - current time in milliseconds +-- KEYS[3] - dynamic rate multiplier +-- KEYS[4] - dynamic burst multiplier +-- KEYS[5] - max dyn rate (min: 1/x) +-- KEYS[6] - max burst rate (min: 1/x) +-- KEYS[7] - expire for a bucket +-- Redis keys used: +-- l - last hit +-- b - current burst +-- dr - current dynamic rate multiplier +-- db - current dynamic burst multiplier +local bucket_update_script = [[ + local last = redis.call('HGET', KEYS[1], 'l') + local now = tonumber(KEYS[2]) + if not last then + -- New bucket + redis.call('HSET', KEYS[1], 'l', KEYS[2]) + redis.call('HSET', KEYS[1], 'b', '1') + redis.call('HSET', KEYS[1], 'dr', '10000') + redis.call('HSET', KEYS[1], 'db', '10000') + redis.call('EXPIRE', KEYS[1], KEYS[7]) + return + end + + local burst = tonumber(redis.call('HGET', KEYS[1], 'b')) + local db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000 + local dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000 + + if dr < tonumber(KEYS[5]) and dr > 1.0 / tonumber(KEYS[5]) then + dr = dr * tonumber(KEYS[3]) + redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000))) + end + + if db < tonumber(KEYS[6]) and db > 1.0 / tonumber(KEYS[6]) then + db = db * tonumber(KEYS[4]) + redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000))) + end + + redis.call('HINCRBYFLOAT', KEYS[1], 'b', 1) + redis.call('HSET', KEYS[1], 'l', KEYS[2]) + redis.call('EXPIRE', KEYS[1], KEYS[7]) +]] +local bucket_update_id + +local message_func = function(_, limit_type, _) return string.format('Ratelimit "%s" exceeded', limit_type) end @@ -53,120 +150,11 @@ local fun = require "fun" local lua_maps = require "lua_maps" local lua_util = require "lua_util" -local user_keywords = {'user'} - -local redis_script_id -local redis_script = [[local bucket -local limited = false -local buckets = {} -local queue_id = table.remove(ARGV) -local now = table.remove(ARGV) - -local argi = 0 -for i = 1, #KEYS do - local key = KEYS[i] - local period = tonumber(ARGV[argi+1]) - local limit = tonumber(ARGV[argi+2]) - if not buckets[key] then - buckets[key] = { - max_period = period, - limits = { {period, limit} }, - } - else - table.insert(buckets[key].limits, {period, limit}) - if period > buckets[key].max_period then - buckets[key].max_period = period - end - end - argi = argi + 2 -end -for k, v in pairs(buckets) do - local maxp = v.max_period - redis.call('ZREMRANGEBYSCORE', k, '-inf', now - maxp) - for _, lim in ipairs(v.limits) do - local period = lim[1] - local limit = lim[2] - local rate - if period == maxp then - rate = redis.call('ZCARD', k) - else - rate = redis.call('ZCOUNT', k, now - period, '+inf') - end - if rate and rate >= limit then - limited = true - bucket = k - end - end - redis.call('EXPIRE', k, maxp) - if limited then break end -end - -if not limited then - for k in pairs(buckets) do - redis.call('ZADD', k, now, queue_id) - end -end - -return {limited, bucket}]] - -local redis_script_symbol = [[local limited = false -local buckets, results = {}, {} -local queue_id = table.remove(ARGV) -local now = table.remove(ARGV) - -local argi = 0 -for i = 1, #KEYS do - local key = KEYS[i] - local period = tonumber(ARGV[argi+1]) - local limit = tonumber(ARGV[argi+2]) - if not buckets[key] then - buckets[key] = { - max_period = period, - limits = { {period, limit} }, - } - else - table.insert(buckets[key].limits, {period, limit}) - if period > buckets[key].max_period then - buckets[key].max_period = period - end - end - argi = argi + 2 -end - -for k, v in pairs(buckets) do - local maxp = v.max_period - redis.call('ZREMRANGEBYSCORE', k, '-inf', now - maxp) - for _, lim in ipairs(v.limits) do - local period = lim[1] - local limit = lim[2] - local rate - if period == maxp then - rate = redis.call('ZCARD', k) - else - rate = redis.call('ZCOUNT', k, now - period, '+inf') - end - if rate then - local mult = 2 * math.tanh(rate / (limit * 2)) - if mult >= 0.5 then - table.insert(results, {k, tostring(mult)}) - end - end - end - redis.call('ZADD', k, now, queue_id) - redis.call('EXPIRE', k, maxp) -end - -return results]] local function load_scripts(cfg, ev_base) - local script - if ratelimit_symbol then - script = redis_script_symbol - else - script = redis_script - end - redis_script_id = lua_redis.add_redis_script(script, redis_params) + bucket_check_id = lua_redis.add_redis_script(bucket_check_script, redis_params) + bucket_update_id = lua_redis.add_redis_script(bucket_update_script, redis_params) end local limit_parser @@ -233,43 +221,23 @@ local function parse_string_limit(lim, no_error) return nil end -local function resize_element(x_score, x_total, element) - local x_ip_score - if not x_total then x_total = 0 end - if x_total < ip_score_lower_bound or x_total <= 0 then - x_score = 1 - else - x_score = x_score / x_total - end - if x_score > 0 then - x_ip_score = x_score / ip_score_spam_divisor - element = element * rspamd_util.tanh(2.718281 * x_ip_score) - elseif x_score < 0 then - x_ip_score = ((1 + (x_score * -1)) * ip_score_ham_multiplier) - element = element * x_ip_score - end - return element -end - --- Check whether this addr is bounce local function check_bounce(from) - return fun.any(function(b) return b == from end, bounce_senders) + return fun.any(function(b) return b == from end, settings.bounce_senders) end -local custom_keywords = {} - local keywords = { ['ip'] = { - ['get_value'] = function(task) + get_value = function(task) local ip = task:get_ip() - if ip and ip:is_valid() then return ip end + if ip and ip:is_valid() then return tostring(ip) end return nil end, }, ['rip'] = { ['get_value'] = function(task) local ip = task:get_ip() - if ip and ip:is_valid() and not ip:is_local() then return ip end + if ip and ip:is_valid() and not ip:is_local() then return tostring(ip) end return nil end, }, @@ -312,140 +280,44 @@ local keywords = { end, }, ['to'] = { - ['get_value'] = function() - return '%s' -- 'to' is special + ['get_value'] = function(task) + return task:get_principal_recipient() end, }, } -local function dynamic_rate_key(task, rtype) - local key_t = {rl_prefix, rtype} - local key_keywords = rspamd_str_split(rtype, '_') - local have_to, have_user = false, false +local function gen_rate_key(task, rtype, bucket) + local key_t = {settings.prefix, tostring(lua_util.round(100000.0 / bucket[1]))} + local key_keywords = lua_util.str_split(rtype, '_') + local have_user = false + for _, v in ipairs(key_keywords) do - if (custom_keywords[v] and type(custom_keywords[v]['condition']) == 'function') then - if not custom_keywords[v]['condition']() then return nil end - end local ret - if custom_keywords[v] and type(custom_keywords[v]['get_value']) == 'function' then - ret = custom_keywords[v]['get_value'](task) - elseif keywords[v] and type(keywords[v]['get_value']) == 'function' then + + if keywords[v] and type(keywords[v]['get_value']) == 'function' then ret = keywords[v]['get_value'](task) end if not ret then return nil end - for _, uk in ipairs(user_keywords) do - if v == uk then have_user = true end - if have_user then break end - end - if v == 'to' then have_to = true end + if v == 'user' then have_user = true end if type(ret) ~= 'string' then ret = tostring(ret) end table.insert(key_t, ret) end - if (not have_user) and task:get_user() then + + if have_user and not task:get_user() then return nil end - if not have_to then - return table.concat(key_t, ":") - else - local rate_keys = {} - local rcpts = task:get_recipients(0) - if not ((rcpts or E)[1] or E).addr then - return nil - end - local key_s = table.concat(key_t, ":") - local total_rcpt = 0 - for _, r in ipairs(rcpts) do - if r['addr'] and total_rcpt < max_rcpt then - local key_f = string.format(key_s, string.lower(r['addr'])) - table.insert(rate_keys, key_f) - total_rcpt = total_rcpt + 1 - end - end - return rate_keys - end -end -local function process_buckets(task, buckets) - if not buckets then return end - local function rl_redis_cb(err, data) - if err then - rspamd_logger.infox(task, 'got error while setting limit: %1', err) - end - if not data then return end - if data[1] == 1 then - if info_symbol then - task:insert_result(info_symbol, 1.0, data[2]) - end - rspamd_logger.infox(task, - 'ratelimit "%s" exceeded', - data[2]) - task:set_pre_result('soft reject', - message_func(task, data[2])) - end - end - local function rl_symbol_redis_cb(err, data) - if err then - rspamd_logger.infox(task, 'got error while setting limit: %1', err) - end - if not data then return end - for i, b in ipairs(data) do - task:insert_result(ratelimit_symbol, b[2], string.format('%s:%s:%s', i, b[1], b[2])) - end - end - local redis_cb = rl_redis_cb - if ratelimit_symbol then redis_cb = rl_symbol_redis_cb end - local kwargs, args = {}, {} - for _, bucket in ipairs(buckets) do - table.insert(kwargs, bucket[2]) - end - for _, bucket in ipairs(buckets) do - if use_ip_score then - local asn_score,total_asn, - country_score,total_country, - ipnet_score,total_ipnet, - ip_score, total_ip = task:get_mempool():get_variable('ip_score', - 'double,double,double,double,double,double,double,double') - local key_keywords = rspamd_str_split(bucket[2], '_') - local has_asn, has_ip = false, false - for _, v in ipairs(key_keywords) do - if v == "asn" then has_asn = true end - if v == "ip" then has_ip = true end - if has_ip and has_asn then break end - end - if has_asn and not has_ip then - bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2]) - elseif has_ip then - if total_ip and total_ip > ip_score_lower_bound then - bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2]) - elseif total_ipnet and total_ipnet > ip_score_lower_bound then - bucket[1][2] = resize_element(ipnet_score, total_ipnet, bucket[1][2]) - elseif total_asn and total_asn > ip_score_lower_bound then - bucket[1][2] = resize_element(asn_score, total_asn, bucket[1][2]) - elseif total_country and total_country > ip_score_lower_bound then - bucket[1][2] = resize_element(country_score, total_country, bucket[1][2]) - else - bucket[1][2] = resize_element(ip_score, total_ip, bucket[1][2]) - end - end - end - table.insert(args, bucket[1][1]) - table.insert(args, bucket[1][2]) - end - table.insert(args, rspamd_util.get_time()) - table.insert(args, task:get_queue_id() or task:get_uid()) - local ret = lua_redis.exec_redis_script(redis_script_id, {task = task, is_write = true}, redis_cb, kwargs, args) - if not ret then - rspamd_logger.errx(task, 'got error connecting to redis') - end + return table.concat(key_t, ":") end local function ratelimit_cb(task) - if rspamd_lua_utils.is_rspamc_or_controller(task) then return end - local args = {} + if not settings.allow_local and + rspamd_lua_utils.is_rspamc_or_controller(task) then return end + -- Get initial task data local ip = task:get_from_ip() - if ip and ip:is_valid() and whitelisted_ip then - if whitelisted_ip:get_key(ip) then + if ip and ip:is_valid() and settings.whitelisted_ip then + if settings.whitelisted_ip:get_key(ip) then -- Do not check whitelisted ip rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP') return @@ -459,131 +331,121 @@ local function ratelimit_cb(task) fun.each(function(type) table.insert(rcpts_user, r[type]) end, {'user', 'addr'}) end, rcpts) - if fun.any(function(r) return whitelisted_rcpts:get_key(r) end, rcpts_user) then + if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient') return end end -- Get user (authuser) - if whitelisted_user then + if settings.whitelisted_user then local auser = task:get_user() - if whitelisted_user:get_key(auser) then + if settings.whitelisted_user:get_key(auser) then rspamd_logger.infox(task, 'skip ratelimit for whitelisted user') return end end + -- Now create all ratelimit prefixes + local prefixes = {} + local nprefixes = 0 + + for k,v in pairs(settings.limits) do + for _,bucket in ipairs(v) do + local prefix = gen_rate_key(task, k, bucket) + + if prefix then + prefixes[prefix] = bucket + nprefixes = nprefixes + 1 + end + end + end - local redis_keys = {} - local redis_keys_rev = {} - local function collect_redis_keys() - local function collect_cb(err, data) + local function gen_check_cb(prefix, bucket) + return function(err, data) if err then - rspamd_logger.errx(task, 'redis error: %1', err) - else - for i, d in ipairs(data) do - if type(d) == 'string' then - local plim, size = parse_string_limit(d) - if plim then - table.insert(args, {{plim, size}, redis_keys_rev[i]}) - end - end + rspamd_logger.errx('cannot check limit %s: %s %s', prefix, err, data) + end + if data and data == 1 then + if settings.info_symbol then + task:insert_result(settings.info_symbol, 1.0, prefix) end - return process_buckets(task, args) + rspamd_logger.infox(task, + 'ratelimit "%s" exceeded, (%s / %s)', + prefix, bucket[2], bucket[1]) + task:set_pre_result('soft reject', + message_func(task, prefix, bucket)) end end - local params, method - if limits_hash then - params = {limits_hash, rspamd_lua_utils.unpack(redis_keys)} - method = 'HMGET' - else - method = 'MGET' - params = redis_keys - end - local requested_keys = rspamd_redis_make_request(task, - redis_params, -- connect params - nil, -- hash key - true, -- is write - collect_cb, --callback - method, -- command - params -- arguments - ) - if not requested_keys then - rspamd_logger.errx(task, 'got error connecting to redis') - return process_buckets(task, args) + end + + if nprefixes > 0 then + -- Save prefixes to the cache to allow update + task:cache_set('ratelimit_prefixes', prefixes) + local now = rspamd_util.get_time() + now = lua_util.round(now * 1000.0) -- Get milliseconds + -- Now call check script for all defined prefixes + + for pr,bucket in pairs(prefixes) do + local rate = (1.0 / bucket[1]) / 1000.0 -- Leak rate in messages/ms + rspamd_logger.debugm(N, task, "check limit %s (%s/%s)", + pr, bucket[2], bucket[1]) + lua_redis.exec_redis_script(bucket_check_id, + {task = task, is_write = true}, + gen_check_cb(pr, bucket), + {pr, tostring(now), tostring(rate), tostring(bucket[2]), + tostring(settings.expire)}) end end +end - local rate_key - for k in pairs(settings) do - rate_key = dynamic_rate_key(task, k) - if rate_key then - if type(rate_key) == 'table' then - for _, rk in ipairs(rate_key) do - if type(settings[k]) == 'string' and - (custom_keywords[settings[k]] and type(custom_keywords[settings[k]]['get_limit']) == 'function') then - local res = custom_keywords[settings[k]]['get_limit'](task) - if type(res) == 'string' then res = {res} end - for _, r in ipairs(res) do - local plim, size = parse_string_limit(r, true) - if plim then - table.insert(args, {{plim, size}, rk}) - else - local rkey = string.match(settings[k], 'redis:(.*)') - if rkey then - table.insert(redis_keys, rkey) - redis_keys_rev[#redis_keys] = rk - else - rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k]) - end - end - end - end - end - else - if type(settings[k]) == 'string' and - (custom_keywords[settings[k]] and type(custom_keywords[settings[k]]['get_limit']) == 'function') then - local res = custom_keywords[settings[k]]['get_limit'](task) - if type(res) == 'string' then res = {res} end - for _, r in ipairs(res) do - local plim, size = parse_string_limit(r, true) - if plim then - table.insert(args, {{plim, size}, rate_key}) - else - local rkey = string.match(r, 'redis:(.*)') - if rkey then - table.insert(redis_keys, rkey) - redis_keys_rev[#redis_keys] = rate_key - else - rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k]) - end - end - end - elseif type(settings[k]) == 'table' then - for _, rl in ipairs(settings[k]) do - table.insert(args, {{rl[1], rl[2]}, rate_key}) - end - elseif type(settings[k]) == 'string' then - local rkey = string.match(settings[k], 'redis:(.*)') - if rkey then - table.insert(redis_keys, rkey) - redis_keys_rev[#redis_keys] = rate_key - else - rspamd_logger.infox(task, "Don't know what to do with limit: %1", settings[k]) - end +local function ratelimit_update_cb(task) + local prefixes = task:cache_get('ratelimit_prefixes') + + if prefixes then + local action = task:get_metric_action() + local is_spam = true + + if action == 'soft reject' then + -- Already rate limited/greylisted, do nothing + rspamd_logger.debugm(N, task, 'already soft rejected, do not update') + elseif action == 'no action' then + is_spam = false + end + + local mult_burst = settings.ham_factor_burst + local mult_rate = settings.ham_factor_burst + + if is_spam then + mult_burst = settings.spam_factor_burst + mult_rate = settings.spam_factor_rate + end + + -- Update each bucket + for k, v in pairs(prefixes) do + local function update_bucket_cb(err, _) + if err then + rspamd_logger.errx(task, 'cannot update rate bucket %s: %s', + k, err) end end + local now = rspamd_util.get_time() + now = lua_util.round(now * 1000.0) -- Get milliseconds + rspamd_logger.debugm(N, task, "update limit %s (%s/%s)", + k, v[2], v[1]) + lua_redis.exec_redis_script(bucket_update_id, + {task = task, is_write = true}, + update_bucket_cb, + {k, tostring(now), tostring(mult_rate), tostring(mult_burst), + tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult), + tostring(settings.expire)}) end end - - if redis_keys[1] then - return collect_redis_keys() - else - return process_buckets(task, args) - end end local opts = rspamd_config:get_all_opt(N) if opts then + + settings = lua_util.override_defaults(settings, opts) + if opts['limit'] then rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported') end @@ -592,12 +454,12 @@ if opts then -- new way of setting limits fun.each(function(t, lim) if type(lim) == 'table' then - settings[t] = {} + settings.limits[t] = {} if #lim == 2 and tonumber(lim[1]) and tonumber(lim[2]) then -- Old style ratelimit rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', t) if tonumber(lim[1]) > 0 and tonumber(lim[2]) > 0 then - table.insert(settings[t], {1.0/lim[2], lim[1]}) + table.insert(settings.limits[t], {1.0/lim[2], lim[1]}) elseif lim[1] ~= 0 then rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', t) else @@ -607,31 +469,24 @@ if opts then fun.each(function(l) local plim, size = parse_string_limit(l) if plim then - table.insert(settings[t], {plim, size}) + table.insert(settings.limits[t], {plim, size}) end end, lim) end elseif type(lim) == 'string' then local plim, size = parse_string_limit(lim) if plim then - settings[t] = { {plim, size} } + settings.limits[t] = { {plim, size} } end end end, opts['rates']) end - if opts['dynamic_rates'] and type(opts['dynamic_rates']) == 'table' then - fun.each(function(t, lim) - if type(lim) == 'string' then - settings[t] = lim - end - end, opts['dynamic_rates']) - end - local enabled_limits = fun.totable(fun.map(function(t) return t end, settings)) - rspamd_logger.infox(rspamd_config, 'enabled rate buckets: [%1]', table.concat(enabled_limits, ',')) + rspamd_logger.infox(rspamd_config, + 'enabled rate buckets: [%1]', table.concat(enabled_limits, ',')) -- Ret, ret, ret: stupid legacy stuff: -- If we have a string with commas then load it as as static map @@ -640,70 +495,41 @@ if opts then local wrcpts = opts['whitelisted_rcpts'] if type(wrcpts) == 'string' then if string.find(wrcpts, ',') then - whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl( + settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl( lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts') else - whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set', + settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set', 'Ratelimit whitelisted rcpts') end elseif type(opts['whitelisted_rcpts']) == 'table' then - whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set', + settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set', 'Ratelimit whitelisted rcpts') else -- Stupid default... - whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(whitelisted_rcpts, 'set', - 'Ratelimit whitelisted rcpts') + settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl( + opts.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts') end if opts['whitelisted_ip'] then - whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix', + settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix', 'Ratelimit whitelist ip map') end if opts['whitelisted_user'] then - whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set', + settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set', 'Ratelimit whitelist user map') end - if opts['symbol'] then - -- We want symbol instead of pre-result - ratelimit_symbol = opts['symbol'] - end - - if opts['info_symbol'] then - -- We want symbol in addition to pre-result - info_symbol = opts['info_symbol'] - end - - if opts['max_rcpt'] then - max_rcpt = tonumber(opts['max_rcpt']) - end - - if opts['use_ip_score'] then - use_ip_score = true - local ip_score_opts = rspamd_config:get_all_opt('ip_score') - if ip_score_opts and ip_score_opts['lower_bound'] then - ip_score_lower_bound = ip_score_opts['lower_bound'] - end - end - if opts['custom_keywords'] then - custom_keywords = dofile(opts['custom_keywords']) - end - - if opts['user_keywords'] then - user_keywords = opts['user_keywords'] + settings.custom_keywords = dofile(opts['custom_keywords']) end if opts['message_func'] then message_func = assert(load(opts['message_func']))() end - if opts['limits_hash'] then - limits_hash = opts['limits_hash'] - end + redis_params = lua_redis.parse_redis_server('ratelimit') - redis_params = rspamd_parse_redis_server('ratelimit') if not redis_params then rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module') lua_util.disable_module(N, "redis") @@ -715,25 +541,30 @@ if opts then callback = ratelimit_cb, flags = 'empty', } - if use_ip_score then - s.type = 'normal' - end - if ratelimit_symbol then - s.name = ratelimit_symbol - elseif info_symbol then - s.name = info_symbol + + if settings.symbol then + s.name = settings.symbol + elseif settings.info_symbol then + s.name = settings.info_symbol end + rspamd_config:register_symbol(s) - if use_ip_score then - rspamd_config:register_dependency(s.name, 'IP_SCORE') - end - for _, v in pairs(custom_keywords) do - if type(v) == 'table' and type(v['init']) == 'function' then - v['init']() + rspamd_config:register_symbol { + type = 'idempotent', + name = 'RATELIMIT_UPDATE', + callback = ratelimit_update_cb, + } + + if settings.custom_keywords then + for _, v in pairs(settings.custom_keywords) do + if type(v) == 'table' and type(v['init']) == 'function' then + v['init']() + end end end end end + rspamd_config:add_on_load(function(cfg, ev_base, worker) load_scripts(cfg, ev_base) end)