From a22fbdc1ae0a368f47682188b55e6a7fd17fffb4 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 30 Jul 2025 10:01:10 +0100 Subject: [PATCH] [Fix] Use a more straightforward approach for learn cache --- .overcommit.yml | 3 +- lualib/lua_bayes_redis.lua | 16 +---- lualib/redis_scripts/bayes_cache_learn.lua | 21 ++---- src/libstat/learn_cache/redis_cache.cxx | 84 ++++++++++++++++++---- 4 files changed, 78 insertions(+), 46 deletions(-) diff --git a/.overcommit.yml b/.overcommit.yml index d26d3de520..9212c33b32 100644 --- a/.overcommit.yml +++ b/.overcommit.yml @@ -29,7 +29,8 @@ PreCommit: command: ['luacheck', 'lualib', 'src/plugins/lua'] ClangFormat: enabled: true - command: ['git', 'clang-format', '--diff'] + command: ['sh', '-c', 'git clang-format --diff --quiet || (echo "Running clang-format to fix issues..." && git clang-format && git add -u && echo "Files formatted and staged.")'] + on_warn: fail #PostCheckout: # ALL: # Special hook name that customizes all hooks of this type # quiet: true # Change all post-checkout hooks to only display output on failure diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua index 53ee06b0ef..a7af80bf14 100644 --- a/lualib/lua_bayes_redis.lua +++ b/lualib/lua_bayes_redis.lua @@ -233,26 +233,16 @@ end local function gen_cache_learn_functor(redis_params, learn_script_id, conf) local packed_conf = ucl.to_format(conf, 'msgpack') - return function(task, cache_id, class_name) + return function(task, cache_id, class_name, class_id) local function learn_redis_cb(err, data) lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data) end - -- Handle backward compatibility for boolean values - local cache_class_name = class_name - if type(class_name) == "boolean" then - cache_class_name = class_name and "spam" or "ham" - elseif class_name == true or class_name == "true" then - cache_class_name = "spam" - elseif class_name == false or class_name == "false" then - cache_class_name = "ham" - end - - lua_util.debugm(N, task, 'try to learn cache: %s as %s', cache_id, cache_class_name) + lua_util.debugm(N, task, 'try to learn cache: %s as %s (id=%s)', cache_id, class_name, class_id) lua_redis.exec_redis_script(learn_script_id, { task = task, is_write = true, key = cache_id }, learn_redis_cb, - { cache_id, cache_class_name, packed_conf }) + { cache_id, tostring(class_id), packed_conf }) end end diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua index f18b29c06d..a7c9ac443c 100644 --- a/lualib/redis_scripts/bayes_cache_learn.lua +++ b/lualib/redis_scripts/bayes_cache_learn.lua @@ -1,28 +1,15 @@ -- Lua script to perform cache checking for bayes classification (multi-class) -- This script accepts the following parameters: -- key1 - cache id --- key2 - class name (e.g. "spam", "ham", "transactional") +-- key2 - class_id (numeric hash of class name, computed by C side) -- key3 - configuration table in message pack local cache_id = KEYS[1] -local class_name = KEYS[2] +local class_id = KEYS[2] local conf = cmsgpack.unpack(KEYS[3]) --- Convert class names to numeric cache values for consistency -local cache_value -if class_name == "1" or class_name == "spam" or class_name == "S" then - cache_value = "1" -- spam -elseif class_name == "0" or class_name == "ham" or class_name == "H" then - cache_value = "0" -- ham -else - -- For other classes, use a simple hash to get a consistent numeric value - -- This ensures cache check can return a number while preserving class info - local hash = 0 - for i = 1, #class_name do - hash = hash + string.byte(class_name, i) - end - cache_value = tostring(2 + (hash % 1000)) -- Start from 2, avoid 0/1 -end +-- Use class_id directly as cache value +local cache_value = tostring(class_id) cache_id = string.sub(cache_id, 1, conf.cache_elt_len) -- Try each prefix that is in Redis (as some other instance might have set it) diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx index 0de5cd0941..afefeadcda 100644 --- a/src/libstat/learn_cache/redis_cache.cxx +++ b/src/libstat/learn_cache/redis_cache.cxx @@ -152,6 +152,33 @@ rspamd_stat_cache_redis_runtime(struct rspamd_task *task, return (void *) ctx; } +/* Get class ID using rspamd_cryptobox_fast_hash */ +static uint64_t +rspamd_stat_cache_get_class_id(const char *class_name) +{ + if (!class_name) { + return 0; + } + + if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) { + return 1; + } + else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) { + return 0; + } + else { + /* For other classes, use rspamd_cryptobox_fast_hash */ + uint64_t hash = rspamd_cryptobox_fast_hash(class_name, strlen(class_name), 0); + + /* Ensure we don't get 0 or 1 (reserved for ham/spam) */ + if (hash == 0 || hash == 1) { + hash += 2; + } + + return hash; + } +} + static int rspamd_stat_cache_checked(lua_State *L) { @@ -161,23 +188,39 @@ rspamd_stat_cache_checked(lua_State *L) if (res) { auto val = lua_tointeger(L, 3); - if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || - (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { - /* Already learned */ - msg_info_task("<%s> has been already " - "learned as %s, ignore it", - MESSAGE_FIELD(task, message_id), - (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); - task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + /* Get the class being learned */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (!autolearn_class) { + /* Fallback to binary flags for backward compatibility */ + if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) { + autolearn_class = "spam"; + } + else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + autolearn_class = "ham"; + } } - else { - /* Unlearn flag */ - task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + + if (autolearn_class) { + uint64_t expected_id = rspamd_stat_cache_get_class_id(autolearn_class); + + if ((uint64_t) val == expected_id) { + /* Already learned */ + msg_info_task("<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + autolearn_class); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + } + else { + /* Different class learned, unlearn flag */ + msg_debug_task("<%s> cached value %ld != expected %lu for class %s, will unlearn", + MESSAGE_FIELD(task, message_id), + val, expected_id, autolearn_class); + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + } } } - /* Ignore errors for now, as we can do nothing about them at the moment */ - return 0; } @@ -235,9 +278,20 @@ int rspamd_stat_cache_redis_learn(struct rspamd_task *task, lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref); rspamd_lua_task_push(L, task); lua_pushstring(L, h); - lua_pushboolean(L, is_spam); - if (lua_pcall(L, 3, 0, err_idx) != 0) { + /* Get the class being learned - prefer multiclass over binary */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (!autolearn_class) { + /* Fallback to binary flag for backward compatibility */ + autolearn_class = is_spam ? "spam" : "ham"; + } + + /* Push class name and class ID */ + lua_pushstring(L, autolearn_class); + uint64_t class_id = rspamd_stat_cache_get_class_id(autolearn_class); + lua_pushinteger(L, class_id); + + if (lua_pcall(L, 4, 0, err_idx) != 0) { msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); lua_settop(L, err_idx - 1); return RSPAMD_LEARN_IGNORE; -- 2.47.3