git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Fix] Use a more straightforward approach for learn cache 5547/head
author Vsevolod Stakhov <vsevolod@rspamd.com>
Wed, 30 Jul 2025 09:01:10 +0000 (10:01 +0100)
committer Vsevolod Stakhov <vsevolod@rspamd.com>
Wed, 30 Jul 2025 09:01:10 +0000 (10:01 +0100)
.overcommit.yml
lualib/lua_bayes_redis.lua
lualib/redis_scripts/bayes_cache_learn.lua
src/libstat/learn_cache/redis_cache.cxx

diff --git a/.overcommit.yml b/.overcommit.yml
index d26d3de5203c8b5a52a242492c31cb4c703b2b43..9212c33b32944e44fa636af802bb8fe0414767aa 100644
--- a/.overcommit.yml
+++ b/.overcommit.yml
@@ -29,7 +29,8 @@ PreCommit:
     command: ['luacheck', 'lualib', 'src/plugins/lua']
   ClangFormat:
     enabled: true
-    command: ['git', 'clang-format', '--diff']
+    command: ['sh', '-c', 'git clang-format --diff --quiet || (echo "Running clang-format to fix issues..." && git clang-format && git add -u && echo "Files formatted and staged.")']
+    on_warn: fail
 #PostCheckout:
 #  ALL: # Special hook name that customizes all hooks of this type
 #    quiet: true # Change all post-checkout hooks to only display output on failure
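diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua
index 53ee06b0efa822f16466ddf112bd787e45dd9291..a7af80bf14bf5c76b1082885214004d354a229f8 100644
--- a/lualib/lua_bayes_redis.lua
+++ b/lualib/lua_bayes_redis.lua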
@@ -233,26 +233,16 @@ end
 
 local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
   local packed_conf = ucl.to_format(conf, 'msgpack')
-  return function(task, cache_id, class_name)
+  return function(task, cache_id, class_name, class_id)
     local function learn_redis_cb(err, data)
       lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
     end
 
-    -- Handle backward compatibility for boolean values
-    local cache_class_name = class_name
-    if type(class_name) == "boolean" then
-      cache_class_name = class_name and "spam" or "ham"
-    elseif class_name == true or class_name == "true" then
-      cache_class_name = "spam"
-    elseif class_name == false or class_name == "false" then
-      cache_class_name = "ham"
-    end
-
-    lua_util.debugm(N, task, 'try to learn cache: %s as %s', cache_id, cache_class_name)
+    lua_util.debugm(N, task, 'try to learn cache: %s as %s (id=%s)', cache_id, class_name, class_id)
     lua_redis.exec_redis_script(learn_script_id,
         { task = task, is_write = true, key = cache_id },
         learn_redis_cb,
-        { cache_id, cache_class_name, packed_conf })
+        { cache_id, tostring(class_id), packed_conf })
   end
 end
 
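diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua
index f18b29c06def4e4b88a4b2e1726f2d7401b9e3de..a7c9ac443cafb343ebd1135d22fbf829e59738a0 100644
--- a/lualib/redis_scripts/bayes_cache_learn.lua
+++ b/lualib/redis_scripts/bayes_cache_learn.lua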
@@ -1,28 +1,15 @@
 -- Lua script to perform cache checking for bayes classification (multi-class)
 -- This script accepts the following parameters:
 -- key1 - cache id
--- key2 - class name (e.g. "spam", "ham", "transactional")
+-- key2 - class_id (numeric hash of class name, computed by C side)
 -- key3 - configuration table in message pack
 
 local cache_id = KEYS[1]
-local class_name = KEYS[2]
+local class_id = KEYS[2]
 local conf = cmsgpack.unpack(KEYS[3])
 
--- Convert class names to numeric cache values for consistency
-local cache_value
-if class_name == "1" or class_name == "spam" or class_name == "S" then
-  cache_value = "1" -- spam
-elseif class_name == "0" or class_name == "ham" or class_name == "H" then
-  cache_value = "0" -- ham
-else
-  -- For other classes, use a simple hash to get a consistent numeric value
-  -- This ensures cache check can return a number while preserving class info
-  local hash = 0
-  for i = 1, #class_name do
-    hash = hash + string.byte(class_name, i)
-  end
-  cache_value = tostring(2 + (hash % 1000)) -- Start from 2, avoid 0/1
-end
+-- Use class_id directly as cache value
+local cache_value = tostring(class_id)
 cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
 
 -- Try each prefix that is in Redis (as some other instance might have set it)
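diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx
index 0de5cd094111d882f1db45868ef6cc3920dc1a72..afefeadcdab3604630d097993cc3db9c6c13940a 100644
--- a/src/libstat/learn_cache/redis_cache.cxx
+++ b/src/libstat/learn_cache/redis_cache.cxx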
@@ -152,6 +152,33 @@ rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
        return (void *) ctx;
 }
 
+/* Get class ID using rspamd_cryptobox_fast_hash */
+static uint64_t
+rspamd_stat_cache_get_class_id(const char *class_name)
+{
+       if (!class_name) {
+               return 0;
+       }
+
+       if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+               return 1;
+       }
+       else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+               return 0;
+       }
+       else {
+               /* For other classes, use rspamd_cryptobox_fast_hash */
+               uint64_t hash = rspamd_cryptobox_fast_hash(class_name, strlen(class_name), 0);
+
+               /* Ensure we don't get 0 or 1 (reserved for ham/spam) */
+               if (hash == 0 || hash == 1) {
+                       hash += 2;
+               }
+
+               return hash;
+       }
+}
+
 static int
 rspamd_stat_cache_checked(lua_State *L)
 {
@@ -161,23 +188,39 @@ rspamd_stat_cache_checked(lua_State *L)
        if (res) {
                auto val = lua_tointeger(L, 3);
 
-               if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
-                       (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
-                       /* Already learned */
-                       msg_info_task("<%s> has been already "
-                                                 "learned as %s, ignore it",
-                                                 MESSAGE_FIELD(task, message_id),
-                                                 (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
-                       task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+               /* Get the class being learned */
+               const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+               if (!autolearn_class) {
+                       /* Fallback to binary flags for backward compatibility */
+                       if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+                               autolearn_class = "spam";
+                       }
+                       else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+                               autolearn_class = "ham";
+                       }
                }
-               else {
-                       /* Unlearn flag */
-                       task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+
+               if (autolearn_class) {
+                       uint64_t expected_id = rspamd_stat_cache_get_class_id(autolearn_class);
+
+                       if ((uint64_t) val == expected_id) {
+                               /* Already learned */
+                               msg_info_task("<%s> has been already "
+                                                         "learned as %s, ignore it",
+                                                         MESSAGE_FIELD(task, message_id),
+                                                         autolearn_class);
+                               task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+                       }
+                       else {
+                               /* Different class learned, unlearn flag */
+                               msg_debug_task("<%s> cached value %ld != expected %lu for class %s, will unlearn",
+                                                          MESSAGE_FIELD(task, message_id),
+                                                          val, expected_id, autolearn_class);
+                               task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+                       }
                }
        }
 
-       /* Ignore errors for now, as we can do nothing about them at the moment */
-
        return 0;
 }
 
@@ -235,9 +278,20 @@ int rspamd_stat_cache_redis_learn(struct rspamd_task *task,
        lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
        rspamd_lua_task_push(L, task);
        lua_pushstring(L, h);
-       lua_pushboolean(L, is_spam);
 
-       if (lua_pcall(L, 3, 0, err_idx) != 0) {
+       /* Get the class being learned - prefer multiclass over binary */
+       const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+       if (!autolearn_class) {
+               /* Fallback to binary flag for backward compatibility */
+               autolearn_class = is_spam ? "spam" : "ham";
+       }
+
+       /* Push class name and class ID */
+       lua_pushstring(L, autolearn_class);
+       uint64_t class_id = rspamd_stat_cache_get_class_id(autolearn_class);
+       lua_pushinteger(L, class_id);
+
+       if (lua_pcall(L, 4, 0, err_idx) != 0) {
                msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
                lua_settop(L, err_idx - 1);
                return RSPAMD_LEARN_IGNORE;