From: Vsevolod Stakhov Date: Mon, 21 Jul 2025 09:55:59 +0000 (+0100) Subject: [Project] Fix other classification and learning issues X-Git-Tag: 3.13.0~38^2~21 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=148411f0d3bed4264283b8647c46e1fce1749c20;p=thirdparty%2Frspamd.git [Project] Fix other classification and learning issues --- diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua index c086669fed..4de7126c7f 100644 --- a/lualib/lua_bayes_redis.lua +++ b/lualib/lua_bayes_redis.lua @@ -40,7 +40,8 @@ local function gen_classify_functor(redis_params, classify_script_id) -- Determine class labels to send to Redis script local script_class_labels if type(class_labels) == "table" then - script_class_labels = class_labels + -- Use simple comma-separated string instead of messagepack + script_class_labels = "TABLE:" .. table.concat(class_labels, ",") else -- Single class label or boolean compatibility if class_labels == true or class_labels == "true" then diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua index b3e15a9bc3..f18b29c06d 100644 --- a/lualib/redis_scripts/bayes_cache_learn.lua +++ b/lualib/redis_scripts/bayes_cache_learn.lua @@ -8,11 +8,20 @@ local cache_id = KEYS[1] local class_name = KEYS[2] local conf = cmsgpack.unpack(KEYS[3]) --- Handle backward compatibility for binary values -if class_name == "1" then - class_name = "spam" -elseif class_name == "0" then - class_name = "ham" +-- Convert class names to numeric cache values for consistency +local cache_value +if class_name == "1" or class_name == "spam" or class_name == "S" then + cache_value = "1" -- spam +elseif class_name == "0" or class_name == "ham" or class_name == "H" then + cache_value = "0" -- ham +else + -- For other classes, use a simple hash to get a consistent numeric value + -- This ensures cache check can return a number while preserving class info + local hash = 0 + for i = 1, #class_name do + hash = hash + string.byte(class_name, i) + end + cache_value = tostring(2 + (hash % 1000)) -- Start from 2, avoid 0/1 end cache_id = string.sub(cache_id, 1, conf.cache_elt_len) @@ -22,8 +31,8 @@ for i = 0, conf.cache_max_keys do local have = redis.call('HGET', prefix, cache_id) if have then - -- Already in cache, but class_name changes when relearning - redis.call('HSET', prefix, cache_id, class_name) + -- Already in cache, but cache_value changes when relearning + redis.call('HSET', prefix, cache_id, cache_value) return false end end @@ -37,7 +46,7 @@ for i = 0, conf.cache_max_keys do if count < lim then -- We can add it to this prefix - redis.call('HSET', prefix, cache_id, class_name) + redis.call('HSET', prefix, cache_id, cache_value) added = true end end @@ -53,7 +62,7 @@ if not added then if exists then if not expired then redis.call('DEL', prefix) - redis.call('HSET', prefix, cache_id, class_name) + redis.call('HSET', prefix, cache_id, cache_value) -- Do not expire anything else expired = true diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua index 8e6feb32f8..e07b9a7956 100644 --- a/lualib/redis_scripts/bayes_classify.lua +++ b/lualib/redis_scripts/bayes_classify.lua @@ -10,14 +10,20 @@ local input_tokens = cmsgpack.unpack(KEYS[3]) -- Determine if this is multi-class (table) or binary (string) local class_labels = {} -if type(class_labels_arg) == "table" then - class_labels = class_labels_arg + +-- Check if this is a table serialized as "TABLE:label1,label2,..." +if string.match(class_labels_arg, "^TABLE:") then + local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix + -- Split by comma + for label in string.gmatch(labels_str, "([^,]+)") do + table.insert(class_labels, label) + end else -- Binary compatibility: handle old boolean or single string format if class_labels_arg == "true" then - class_labels = { "S" } -- spam + class_labels = { "S" } -- spam elseif class_labels_arg == "false" then - class_labels = { "H" } -- ham + class_labels = { "H" } -- ham else class_labels = { class_labels_arg } -- single class label end diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 0fe738de50..0c663123cb 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -926,15 +926,29 @@ rspamd_redis_classified(lua_State *L) return 0; } - if (rt->stcf->is_spam) { - filler_func(rt, L, lua_tointeger(L, 4), 6); - filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5); + /* Extract values from the result table at position 3 */ + lua_rawgeti(L, 3, 1); /* learned_ham -> position 4 */ + lua_rawgeti(L, 3, 2); /* learned_spam -> position 5 */ + lua_rawgeti(L, 3, 3); /* ham_tokens -> position 6 */ + lua_rawgeti(L, 3, 4); /* spam_tokens -> position 7 */ + + unsigned learned_ham = lua_tointeger(L, 4); + unsigned learned_spam = lua_tointeger(L, 5); + + if (rt->stcf->is_spam || (rt->stcf->class_name && strcmp(get_class_label(rt->stcf), "S") == 0)) { + /* Current runtime is spam, use spam data */ + filler_func(rt, L, learned_spam, 7); /* spam_tokens at position 7 */ + filler_func(opposite_rt_maybe.value(), L, learned_ham, 6); /* ham_tokens at position 6 */ } else { - filler_func(rt, L, lua_tointeger(L, 3), 5); - filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6); + /* Current runtime is ham, use ham data */ + filler_func(rt, L, learned_ham, 6); /* ham_tokens at position 6 */ + filler_func(opposite_rt_maybe.value(), L, learned_spam, 7); /* spam_tokens at position 7 */ } + /* Clean up the stack - pop the 4 extracted values */ + lua_pop(L, 4); + /* Process all tokens */ g_assert(rt->tokens != nullptr); rt->process_tokens(rt->tokens); @@ -1048,9 +1062,10 @@ rspamd_redis_process_tokens(struct rspamd_task *task, lua_pushinteger(L, id); /* Send all class labels for multi-class support */ - if (rt->stcf->clcf && rt->stcf->clcf->class_labels && g_hash_table_size(rt->stcf->clcf->class_labels) > 0) { + if (rt->stcf->clcf && rt->stcf->clcf->class_labels && + g_hash_table_size(rt->stcf->clcf->class_labels) > 0) { /* Multi-class: send array of all class labels */ - lua_newtable(L); + lua_createtable(L, g_hash_table_size(rt->stcf->clcf->class_labels), 0); GHashTableIter iter; gpointer key, value; int idx = 1; @@ -1061,8 +1076,12 @@ rspamd_redis_process_tokens(struct rspamd_task *task, } } else { - /* Binary compatibility: send current class label as single string */ - lua_pushstring(L, get_class_label(rt->stcf)); + /* Binary classification: send both spam and ham labels for optimization */ + lua_createtable(L, 2, 0); + lua_pushstring(L, "H"); /* ham */ + lua_rawseti(L, -2, 1); + lua_pushstring(L, "S"); /* spam */ + lua_rawseti(L, -2, 2); } lua_new_text(L, tokens_buf, tokens_len, false); @@ -1114,7 +1133,16 @@ rspamd_redis_learned(lua_State *L) bool result = lua_toboolean(L, 2); if (result) { - /* TODO: write it */ + /* Learning successful - no complex data to process like in classification */ + msg_debug_bayes("learned tokens successfully in Redis for symbol %s, class %s", + rt->stcf->symbol, get_class_label(rt->stcf)); + + /* Clear any previous error state */ + rt->err = std::nullopt; + + /* Learning operations don't return data structures to process, + * they just update Redis state. Success means the Redis script + * completed without errors. */ } else { /* Error message is on index 3 */