]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Move ratelimit parsing stuff to a common library
authorVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 11 Sep 2024 13:16:23 +0000 (14:16 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 11 Sep 2024 13:16:23 +0000 (14:16 +0100)
lualib/plugins/ratelimit.lua [new file with mode: 0644]
src/fuzzy_storage.c
src/plugins/lua/ratelimit.lua

diff --git a/lualib/plugins/ratelimit.lua b/lualib/plugins/ratelimit.lua
new file mode 100644 (file)
index 0000000..24afed1
--- /dev/null
@@ -0,0 +1,155 @@
+--[[
+Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local ts = require("tableshape").types
+
+local exports = {}
+
+local limit_parser
+local function parse_string_limit(lim, no_error)
+  local function parse_time_suffix(s)
+    if s == 's' then
+      return 1
+    elseif s == 'm' then
+      return 60
+    elseif s == 'h' then
+      return 3600
+    elseif s == 'd' then
+      return 86400
+    end
+  end
+  local function parse_num_suffix(s)
+    if s == '' then
+      return 1
+    elseif s == 'k' then
+      return 1000
+    elseif s == 'm' then
+      return 1000000
+    elseif s == 'g' then
+      return 1000000000
+    end
+  end
+  local lpeg = require "lpeg"
+
+  if not limit_parser then
+    local digit = lpeg.R("09")
+    limit_parser = {}
+    limit_parser.integer = (lpeg.S("+-") ^ -1) *
+        (digit ^ 1)
+    limit_parser.fractional = (lpeg.P(".")) *
+        (digit ^ 1)
+    limit_parser.number = (limit_parser.integer *
+        (limit_parser.fractional ^ -1)) +
+        (lpeg.S("+-") * limit_parser.fractional)
+    limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
+        (limit_parser.number / tonumber) *
+        ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
+        function(acc, val)
+          return acc * val
+        end)
+    limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
+        (limit_parser.number / tonumber) *
+        ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
+        function(acc, val)
+          return acc * val
+        end)
+    limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
+        (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
+        limit_parser.time)
+  end
+  local t = lpeg.match(limit_parser.limit, lim)
+
+  if t and t[1] and t[2] and t[2] ~= 0 then
+    return t[2], t[1]
+  end
+
+  if not no_error then
+    rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
+  end
+
+  return nil
+end
+
+local function str_to_rate(str)
+  local divider, divisor = parse_string_limit(str, false)
+
+  if not divisor then
+    rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str)
+
+    return nil
+  end
+
+  return divisor / divider
+end
+
+local bucket_schema = ts.shape {
+  burst = ts.number + ts.string / lua_util.dehumanize_number,
+  rate = ts.number + ts.string / str_to_rate,
+  skip_recipients = ts.boolean:is_optional(),
+  symbol = ts.string:is_optional(),
+  message = ts.string:is_optional(),
+  skip_soft_reject = ts.boolean:is_optional(),
+}
+
+exports.parse_limit = function(name, data)
+  if type(data) == 'table' then
+    -- 2 cases here:
+    --  * old limit in format [burst, rate]
+    --  * vector of strings in Andrew's string format (removed from 1.8.2)
+    --  * proper bucket table
+    if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then
+      -- Old style ratelimit
+      rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
+      if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then
+        return {
+          burst = data[1],
+          rate = data[2]
+        }
+      elseif data[1] ~= 0 then
+        rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
+      else
+        rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
+      end
+
+      return nil
+    else
+      local parsed_bucket, err = bucket_schema:transform(data)
+
+      if not parsed_bucket or err then
+        rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s',
+            name, err, data)
+      else
+        return parsed_bucket
+      end
+    end
+  elseif type(data) == 'string' then
+    local rep_rate, burst = parse_string_limit(data)
+    rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s',
+        name, data)
+    if rep_rate and burst then
+      return {
+        burst = burst,
+        rate = burst / rep_rate -- reciprocal
+      }
+    end
+  end
+
+  return nil
+end
+
+return exports
\ No newline at end of file
index 2f889c75918a33ea781912fde7185a4e3a9d2f20..257d22bcd820503c85924916420067cb7b62daa8 100644 (file)
@@ -698,6 +698,11 @@ fuzzy_key_dtor(gpointer p)
                        kh_destroy(fuzzy_key_ids_set, key->forbidden_ids);
                }
 
+               if (key->rl_bucket) {
+                       /* TODO: save bucket stats */
+                       g_free(key->rl_bucket);
+               }
+
                g_free(key);
        }
 }
@@ -2827,6 +2832,10 @@ fuzzy_add_keypair_from_ucl(const ucl_object_t *obj, khash_t(rspamd_fuzzy_keys_ha
                                                                                                 rspamd_inet_address_hash, rspamd_inet_address_equal);
        key->stat = keystat;
        key->flags_stat = kh_init(fuzzy_key_flag_stat);
+       key->burst = NAN;
+       key->rate = NAN;
+       key->expire = NAN;
+       key->rl_bucket = NULL;
        /* Preallocate some space for flags */
        kh_resize(fuzzy_key_flag_stat, key->flags_stat, 8);
        const unsigned char *pk = rspamd_keypair_component(kp, RSPAMD_KEYPAIR_COMPONENT_PK,
@@ -2874,6 +2883,20 @@ fuzzy_add_keypair_from_ucl(const ucl_object_t *obj, khash_t(rspamd_fuzzy_keys_ha
                                }
                        }
                }
+
+               /*
+                * TODO: parse ratelimit using Lua code from `ratelimit` plugin to
+                * have unified form of settings
+                */
+               const ucl_object_t *ratelimit = ucl_object_lookup(extensions, "ratelimit");
+
+               if (ratelimit && ucl_object_type(ratelimit) == UCL_STRING) {
+               }
+
+               const ucl_object_t *expire = ucl_object_lookup(extensions, "expire");
+               if (expire && ucl_object_type(expire) == UCL_STRING) {
+                       struct tm tm;
+               }
        }
 
        msg_debug("loaded keypair %*bs", crypto_box_publickeybytes(), pk);
index f3331e850e289b6fb9b8750ac60ef4d7f0f8b313..168d8d63ae35208a8652f33a29c1e7ca22731102 100644 (file)
@@ -29,8 +29,7 @@ local lua_util = require "lua_util"
 local lua_verdict = require "lua_verdict"
 local rspamd_hash = require "rspamd_cryptobox_hash"
 local lua_selectors = require "lua_selectors"
-local ts = require("tableshape").types
-
+local ratelimit_common = require "plugins/ratelimit"
 -- A plugin that implements ratelimits using redis
 
 local E = {}
@@ -76,138 +75,6 @@ local function load_scripts(_, _)
   bucket_cleanup_id = lua_redis.load_redis_script_from_file(bucket_cleanup_script, redis_params)
 end
 
-local limit_parser
-local function parse_string_limit(lim, no_error)
-  local function parse_time_suffix(s)
-    if s == 's' then
-      return 1
-    elseif s == 'm' then
-      return 60
-    elseif s == 'h' then
-      return 3600
-    elseif s == 'd' then
-      return 86400
-    end
-  end
-  local function parse_num_suffix(s)
-    if s == '' then
-      return 1
-    elseif s == 'k' then
-      return 1000
-    elseif s == 'm' then
-      return 1000000
-    elseif s == 'g' then
-      return 1000000000
-    end
-  end
-  local lpeg = require "lpeg"
-
-  if not limit_parser then
-    local digit = lpeg.R("09")
-    limit_parser = {}
-    limit_parser.integer = (lpeg.S("+-") ^ -1) *
-        (digit ^ 1)
-    limit_parser.fractional = (lpeg.P(".")) *
-        (digit ^ 1)
-    limit_parser.number = (limit_parser.integer *
-        (limit_parser.fractional ^ -1)) +
-        (lpeg.S("+-") * limit_parser.fractional)
-    limit_parser.time = lpeg.Cf(lpeg.Cc(1) *
-        (limit_parser.number / tonumber) *
-        ((lpeg.S("smhd") / parse_time_suffix) ^ -1),
-        function(acc, val)
-          return acc * val
-        end)
-    limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) *
-        (limit_parser.number / tonumber) *
-        ((lpeg.S("kmg") / parse_num_suffix) ^ -1),
-        function(acc, val)
-          return acc * val
-        end)
-    limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number *
-        (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) *
-        limit_parser.time)
-  end
-  local t = lpeg.match(limit_parser.limit, lim)
-
-  if t and t[1] and t[2] and t[2] ~= 0 then
-    return t[2], t[1]
-  end
-
-  if not no_error then
-    rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
-  end
-
-  return nil
-end
-
-local function str_to_rate(str)
-  local divider, divisor = parse_string_limit(str, false)
-
-  if not divisor then
-    rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str)
-
-    return nil
-  end
-
-  return divisor / divider
-end
-
-local bucket_schema = ts.shape {
-  burst = ts.number + ts.string / lua_util.dehumanize_number,
-  rate = ts.number + ts.string / str_to_rate,
-  skip_recipients = ts.boolean:is_optional(),
-  symbol = ts.string:is_optional(),
-  message = ts.string:is_optional(),
-  skip_soft_reject = ts.boolean:is_optional(),
-}
-
-local function parse_limit(name, data)
-  if type(data) == 'table' then
-    -- 2 cases here:
-    --  * old limit in format [burst, rate]
-    --  * vector of strings in Andrew's string format (removed from 1.8.2)
-    --  * proper bucket table
-    if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then
-      -- Old style ratelimit
-      rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
-      if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then
-        return {
-          burst = data[1],
-          rate = data[2]
-        }
-      elseif data[1] ~= 0 then
-        rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
-      else
-        rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
-      end
-
-      return nil
-    else
-      local parsed_bucket, err = bucket_schema:transform(data)
-
-      if not parsed_bucket or err then
-        rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s',
-            name, err, data)
-      else
-        return parsed_bucket
-      end
-    end
-  elseif type(data) == 'string' then
-    local rep_rate, burst = parse_string_limit(data)
-    rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s',
-        name, data)
-    if rep_rate and burst then
-      return {
-        burst = burst,
-        rate = burst / rep_rate -- reciprocal
-      }
-    end
-  end
-
-  return nil
-end
-
 --- Check whether this addr is bounce
 local function check_bounce(from)
   return fun.any(function(b)
@@ -490,7 +357,7 @@ local function ratelimit_cb(task)
     local ret, redis_key, bd = pcall(hdl, task)
 
     if ret then
-      local bucket = parse_limit(k, bd)
+      local bucket = ratelimit_common.parse_limit(k, bd)
       if bucket then
         prefixes[redis_key] = make_prefix(redis_key, k, bucket)
       end
@@ -718,7 +585,7 @@ if opts then
 
         if lim.bucket[1] then
           for _, bucket in ipairs(lim.bucket) do
-            local b = parse_limit(t, bucket)
+            local b = ratelimit_common.parse_limit(t, bucket)
 
             if not b then
               rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
@@ -729,7 +596,7 @@ if opts then
             table.insert(buckets, b)
           end
         else
-          local bucket = parse_limit(t, lim.bucket)
+          local bucket = ratelimit_common.parse_limit(t, lim.bucket)
 
           if not bucket then
             rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"',
@@ -757,7 +624,7 @@ if opts then
         end
       else
         rspamd_logger.warnx(rspamd_config, 'old syntax for ratelimits: %s', lim)
-        buckets = parse_limit(t, lim)
+        buckets = ratelimit_common.parse_limit(t, lim)
         if buckets then
           settings.limits[t] = {
             buckets = { buckets }