]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Fix other classification and learning issues
authorVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 21 Jul 2025 09:55:59 +0000 (10:55 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Mon, 21 Jul 2025 09:55:59 +0000 (10:55 +0100)
lualib/lua_bayes_redis.lua
lualib/redis_scripts/bayes_cache_learn.lua
lualib/redis_scripts/bayes_classify.lua
src/libstat/backends/redis_backend.cxx

index c086669fed13e4bf26398de48d6b16e2c029d439..4de7126c7f48cbdd4b1b8fddd60cf48d8530ef66 100644 (file)
@@ -40,7 +40,8 @@ local function gen_classify_functor(redis_params, classify_script_id)
     -- Determine class labels to send to Redis script
     local script_class_labels
     if type(class_labels) == "table" then
-      script_class_labels = class_labels
+      -- Use simple comma-separated string instead of messagepack
+      script_class_labels = "TABLE:" .. table.concat(class_labels, ",")
     else
       -- Single class label or boolean compatibility
       if class_labels == true or class_labels == "true" then
index b3e15a9bc3fdfd4ffdc92f1307ddf8930f42788f..f18b29c06def4e4b88a4b2e1726f2d7401b9e3de 100644 (file)
@@ -8,11 +8,20 @@ local cache_id = KEYS[1]
 local class_name = KEYS[2]
 local conf = cmsgpack.unpack(KEYS[3])
 
--- Handle backward compatibility for binary values
-if class_name == "1" then
-  class_name = "spam"
-elseif class_name == "0" then
-  class_name = "ham"
+-- Convert class names to numeric cache values for consistency
+local cache_value
+if class_name == "1" or class_name == "spam" or class_name == "S" then
+  cache_value = "1" -- spam
+elseif class_name == "0" or class_name == "ham" or class_name == "H" then
+  cache_value = "0" -- ham
+else
+  -- For other classes, use a simple hash to get a consistent numeric value
+  -- This ensures cache check can return a number while preserving class info
+  local hash = 0
+  for i = 1, #class_name do
+    hash = hash + string.byte(class_name, i)
+  end
+  cache_value = tostring(2 + (hash % 1000)) -- Start from 2, avoid 0/1
 end
 cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
 
@@ -22,8 +31,8 @@ for i = 0, conf.cache_max_keys do
   local have = redis.call('HGET', prefix, cache_id)
 
   if have then
-    -- Already in cache, but class_name changes when relearning
-    redis.call('HSET', prefix, cache_id, class_name)
+    -- Already in cache, but cache_value changes when relearning
+    redis.call('HSET', prefix, cache_id, cache_value)
     return false
   end
 end
@@ -37,7 +46,7 @@ for i = 0, conf.cache_max_keys do
 
     if count < lim then
       -- We can add it to this prefix
-      redis.call('HSET', prefix, cache_id, class_name)
+      redis.call('HSET', prefix, cache_id, cache_value)
       added = true
     end
   end
@@ -53,7 +62,7 @@ if not added then
     if exists then
       if not expired then
         redis.call('DEL', prefix)
-        redis.call('HSET', prefix, cache_id, class_name)
+        redis.call('HSET', prefix, cache_id, cache_value)
 
         -- Do not expire anything else
         expired = true
index 8e6feb32f80be340c015aa0d67f744a6709dbcf3..e07b9a795690a3ddfcc60405b636da0d3003e407 100644 (file)
@@ -10,14 +10,20 @@ local input_tokens = cmsgpack.unpack(KEYS[3])
 
 -- Determine if this is multi-class (table) or binary (string)
 local class_labels = {}
-if type(class_labels_arg) == "table" then
-  class_labels = class_labels_arg
+
+-- Check if this is a table serialized as "TABLE:label1,label2,..."
+if string.match(class_labels_arg, "^TABLE:") then
+  local labels_str = string.sub(class_labels_arg, 7) -- Remove "TABLE:" prefix
+  -- Split by comma
+  for label in string.gmatch(labels_str, "([^,]+)") do
+    table.insert(class_labels, label)
+  end
 else
   -- Binary compatibility: handle old boolean or single string format
   if class_labels_arg == "true" then
-    class_labels = { "S" }            -- spam
+    class_labels = { "S" }              -- spam
   elseif class_labels_arg == "false" then
-    class_labels = { "H" }            -- ham
+    class_labels = { "H" }              -- ham
   else
     class_labels = { class_labels_arg } -- single class label
   end
index 0fe738de50f9e852d08e31f6c0938e8870dd6955..0c663123cbc4d3c1a6c5d1d32436eb2d23781bca 100644 (file)
@@ -926,15 +926,29 @@ rspamd_redis_classified(lua_State *L)
                                return 0;
                        }
 
-                       if (rt->stcf->is_spam) {
-                               filler_func(rt, L, lua_tointeger(L, 4), 6);
-                               filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+                       /* Extract values from the result table at position 3 */
+                       lua_rawgeti(L, 3, 1); /* learned_ham -> position 4 */
+                       lua_rawgeti(L, 3, 2); /* learned_spam -> position 5 */
+                       lua_rawgeti(L, 3, 3); /* ham_tokens -> position 6 */
+                       lua_rawgeti(L, 3, 4); /* spam_tokens -> position 7 */
+
+                       unsigned learned_ham = lua_tointeger(L, 4);
+                       unsigned learned_spam = lua_tointeger(L, 5);
+
+                       if (rt->stcf->is_spam || (rt->stcf->class_name && strcmp(get_class_label(rt->stcf), "S") == 0)) {
+                               /* Current runtime is spam, use spam data */
+                               filler_func(rt, L, learned_spam, 7);                       /* spam_tokens at position 7 */
+                               filler_func(opposite_rt_maybe.value(), L, learned_ham, 6); /* ham_tokens at position 6 */
                        }
                        else {
-                               filler_func(rt, L, lua_tointeger(L, 3), 5);
-                               filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+                               /* Current runtime is ham, use ham data */
+                               filler_func(rt, L, learned_ham, 6);                         /* ham_tokens at position 6 */
+                               filler_func(opposite_rt_maybe.value(), L, learned_spam, 7); /* spam_tokens at position 7 */
                        }
 
+                       /* Clean up the stack - pop the 4 extracted values */
+                       lua_pop(L, 4);
+
                        /* Process all tokens */
                        g_assert(rt->tokens != nullptr);
                        rt->process_tokens(rt->tokens);
@@ -1048,9 +1062,10 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
        lua_pushinteger(L, id);
 
        /* Send all class labels for multi-class support */
-       if (rt->stcf->clcf && rt->stcf->clcf->class_labels && g_hash_table_size(rt->stcf->clcf->class_labels) > 0) {
+       if (rt->stcf->clcf && rt->stcf->clcf->class_labels &&
+               g_hash_table_size(rt->stcf->clcf->class_labels) > 0) {
                /* Multi-class: send array of all class labels */
-               lua_newtable(L);
+               lua_createtable(L, g_hash_table_size(rt->stcf->clcf->class_labels), 0);
                GHashTableIter iter;
                gpointer key, value;
                int idx = 1;
@@ -1061,8 +1076,12 @@ rspamd_redis_process_tokens(struct rspamd_task *task,
                }
        }
        else {
-               /* Binary compatibility: send current class label as single string */
-               lua_pushstring(L, get_class_label(rt->stcf));
+               /* Binary classification: send both spam and ham labels for optimization */
+               lua_createtable(L, 2, 0);
+               lua_pushstring(L, "H"); /* ham */
+               lua_rawseti(L, -2, 1);
+               lua_pushstring(L, "S"); /* spam */
+               lua_rawseti(L, -2, 2);
        }
 
        lua_new_text(L, tokens_buf, tokens_len, false);
@@ -1114,7 +1133,16 @@ rspamd_redis_learned(lua_State *L)
        bool result = lua_toboolean(L, 2);
 
        if (result) {
-               /* TODO: write it */
+               /* Learning successful - no complex data to process like in classification */
+               msg_debug_bayes("learned tokens successfully in Redis for symbol %s, class %s",
+                                               rt->stcf->symbol, get_class_label(rt->stcf));
+
+               /* Clear any previous error state */
+               rt->err = std::nullopt;
+
+               /* Learning operations don't return data structures to process,
+                * they just update Redis state. Success means the Redis script
+                * completed without errors. */
        }
        else {
                /* Error message is on index 3 */